#!/usr/bin/ruby -Ku
#encoding: utf-8
# One-vs-all SVM evaluation for multi-class data: reports ccr, Micro F1 and Macro F1 per fold.
# Input fold files must be in libsvm format.
require 'fileutils'
# One-vs-all (OAA) linear-SVM evaluation over k folds.
#
# Every ./fold/*.csv file is one fold in libsvm format ("<label> <index>:<value> ...",
# whitespace separated, despite the .csv extension). Folds are processed in sorted
# filename order; each fold in turn is the test set and the remaining folds form the
# training set. For every label present in the test fold a binary problem
# "this label (1) vs. the rest (-1)" is written to tests.libsvm / trains.libsvm,
# trained with `svm-train -t 0` and scored with `svm-predict`. Both libsvm
# command-line tools must be on PATH.
#
# One CSV row is collected per fold and printed after the "oaasvm" line:
#   fold_name  - fold index (position in sorted filename order)
#   data size  - number of non-blank instances in the test fold
#   class A..D - instances whose label token is "1".."4" in the test fold
#   ccr        - mean of the per-label binary accuracies, in %
#   Micro F1   - F1 over confusion counts summed across labels, in %
#   Macro F1   - mean of the per-label F1 scores, in %

TEST_FILE = "tests.libsvm"
TRAIN_FILE = "trains.libsvm"
MODEL_FILE = "libsvm_model"
RESULT_FILE = "libsvm_result"
CLASS_COLUMNS = %w[1 2 3 4].freeze

# Parses a libsvm-format fold into [label, [feature tokens]] pairs.
# Blank lines are skipped: they used to yield nil labels (crashing the label sort)
# and feature-less instances in the generated train/test files.
def read_fold(path)
  rows = []
  File.foreach(path) do |line|
    tokens = line.split(' ')
    rows << [tokens.first, tokens.drop(1)] unless tokens.empty?
  end
  rows
end

# Relabels rows for a one-vs-all problem: +positive+ becomes "1", every other
# label becomes "-1". Returns the rows as libsvm-format lines.
def binarize(rows, positive)
  rows.map { |label, features| [label == positive ? "1" : "-1", *features].join(' ') }
end

# Overwrites +path+ with +lines+, one per line.
def write_lines(path, lines)
  File.open(path, "w") { |io| lines.each { |line| io.puts line } }
end

# Runs an external command and aborts the script if it cannot be started or exits
# non-zero. Without this check a failed svm-train would let svm-predict reuse the
# model left over from the previous label/fold and silently report wrong scores.
def run_command(command)
  return if system(command)
  abort("#{command} failed (#{$?.inspect})")
end

# Binary confusion counts with "1" as the positive label. Each prediction is paired
# with the true label at the same position. Returns [tp, fn, fp, tn].
def confusion_counts(actual, predicted)
  tp = fn = fp = tn = 0
  predicted.zip(actual) do |guess, truth|
    if truth == "1" && guess == "1"
      tp += 1
    elsif truth == "1"
      fn += 1
    elsif guess == "1"
      fp += 1
    else
      tn += 1
    end
  end
  [tp, fn, fp, tn]
end

# F1 = 2PR / (P + R) with P = tp / (tp + fp) and R = tp / (tp + fn).
# Returns 0.0 when tp is 0, which also covers the 0/0 cases where P or R is undefined.
def f1_score(tp, fp, fn)
  return 0.0 if tp.zero?
  precision = tp.to_f / (tp + fp)
  recall = tp.to_f / (tp + fn)
  (2 * precision * recall) / (precision + recall)
end

fold_csv = Dir.glob("./fold/*.csv").sort
# Each fold is parsed once here instead of being re-read for every label.
folds = fold_csv.map { |path| read_fold(path) }
fold_computing_data = ["fold_name,data size,class A,class B,class C,class D,ccr,Micro F1,Macro F1"]

fold_csv.each_with_index do |path, index|
  printf("%s calculating...\n", File.basename(path))
  test_rows = folds[index]
  train_rows = []
  folds.each_with_index { |rows, i| train_rows.concat(rows) unless i == index }

  labels = test_rows.map(&:first)
  # One binary problem per distinct label that occurs in the test fold.
  class_labels = labels.uniq.sort
  micro_tp = micro_fn = micro_fp = 0
  accuracies = []
  f1_scores = []

  class_labels.each do |positive|
    write_lines(TEST_FILE, binarize(test_rows, positive))
    write_lines(TRAIN_FILE, binarize(train_rows, positive))
    FileUtils.rm_f(RESULT_FILE)
    run_command("svm-train -t 0 #{TRAIN_FILE} #{MODEL_FILE}")
    run_command("svm-predict #{TEST_FILE} #{MODEL_FILE} #{RESULT_FILE}")

    # The result file is read as one predicted label per line, in test-file order.
    predicted = File.readlines(RESULT_FILE).map { |line| line.split(' ').first }
    actual = test_rows.map { |label, _| label == positive ? "1" : "-1" }
    tp, fn, fp, tn = confusion_counts(actual, predicted)

    # Report letters kept from the original output: a=tp, b=fn, c=fp, d=tn.
    printf("[%s],a:%d b:%d c:%d d:%d\n", positive, tp, fn, fp, tn)
    accuracy = 100 * (tp + tn) / (tp + fn + fp + tn).to_f
    printf("accuracy:%2.3f\n", accuracy)

    accuracies << accuracy
    f1_scores << f1_score(tp, fp, fn)
    micro_tp += tp
    micro_fn += fn
    micro_fp += fp
  end

  # NOTE(review): ccr is the mean one-vs-all binary accuracy, not the multi-class
  # correct-classification rate; kept as-is so reported numbers stay comparable.
  ccr_rate = format("%2.3f", accuracies.inject(0.0, :+) / class_labels.size)
  micro_f1 = format("%2.3f", 100 * f1_score(micro_tp, micro_fp, micro_fn))
  macro_f1 = format("%2.3f", 100 * f1_scores.inject(0.0, :+) / class_labels.size)

  class_counts = CLASS_COLUMNS.map { |label| labels.count(label) }
  fold_computing_data << [index, test_rows.size, *class_counts, ccr_rate, micro_f1, macro_f1].join(',')
  FileUtils.rm_f(RESULT_FILE)
end

printf("oaasvm\n")
fold_computing_data.each { |row| puts row }
#FileUtils.rm Dir.glob("trains.libsvm")
#FileUtils.rm Dir.glob("libsvm_model")
|