#!/usr/bin/ruby -Ku
# encoding: utf-8
# Works with multi-class data; produces CCR, micro F1 and macro F1.
# Input files must be in libsvm format.
require 'fileutils'
# One-vs-all (OAA) linear-SVM cross-validation over the fold files ./fold/*.csv.
#
# Each fold file is read as libsvm-format text ("<label> <feature tokens...>").
# Every fold in turn is used as the test set while all other folds together
# form the training set. For each class label present in the test fold, a
# binary problem is written to ./trains.libsvm and ./tests.libsvm (that label
# becomes "1", every other label "-1"), trained with `svm-train -t 0` and
# scored with `svm-predict`. One CSV row per fold is collected and all rows
# are printed at the end.

# Returns the class label of a libsvm line: its first whitespace-separated token.
def libsvm_label(line)
  line.split(' ')[0]
end

# Rewrites a libsvm line for a one-vs-all problem. The label becomes "1" when
# it equals +positive_label+ and "-1" otherwise; feature tokens are kept.
def binarize_line(line, positive_label)
  tokens = line.split(' ')
  tag = tokens[0] == positive_label ? "1" : "-1"
  ([tag] + tokens.drop(1)).join(' ')
end

# Counts how many labels are exactly "1", "2", "3" and "4" (the class A-D
# columns of the report). The whole label token is compared, so a label such
# as "10" is not mistaken for class "1".
def count_known_classes(labels)
  %w[1 2 3 4].map { |name| labels.count(name) }
end

# Builds the binary confusion matrix from two parallel arrays of booleans
# (true = positive class). Returns [a, b, c, d] where
#   a = actual positive, predicted positive (true positive)
#   b = actual positive, predicted negative (false negative)
#   c = actual negative, predicted positive (false positive)
#   d = actual negative, predicted negative (true negative)
# Only as many pairs as there are predictions are counted.
def confusion_counts(actual, predicted)
  a = 0
  b = 0
  c = 0
  d = 0
  predicted.each_with_index do |guess, i|
    if actual[i]
      if guess
        a += 1
      else
        b += 1
      end
    elsif guess
      c += 1
    else
      d += 1
    end
  end
  [a, b, c, d]
end

# F-measure 2pr/(p+r) with p = tp/(tp+fp) and r = tp/(tp+fn).
# Returns 0 when p+r is not greater than zero; this also covers the NaN that
# results from an empty denominator.
def f_measure(tp, fn, fp)
  precision = tp.to_f / (tp + fp)
  recall = tp.to_f / (tp + fn)
  return 0 unless (precision + recall) > 0
  (2 * precision * recall) / (precision + recall)
end

# Sum of +values+ as a Float (0.0 for an empty list).
def float_total(values)
  values.inject(0) { |sum, value| sum + value }.to_f
end

# Returns the non-blank lines of a fold file. Blank lines carry no label and
# would otherwise put nil into the label list.
def read_fold(path)
  File.readlines(path).reject { |line| line.strip.empty? }
end

# Writes +lines+ to +path+, one per line, replacing any previous content.
def write_lines(path, lines)
  File.open(path, "w") do |out|
    lines.each { |line| out.puts line }
  end
end

# Runs a shell command and raises if it cannot be started or exits non-zero,
# so that a failed training run can never be scored against a stale model file.
def run_command(command)
  return if system(command)
  raise "command failed: #{command}"
end

fold_csv = Dir.glob("./fold/*.csv")
fold_computing_data = Array.new
fold_computing_data << "fold_name,data size,class A,class B,class C,class D,ccr,Micro F1,Macro F1"
fold_csv.sort! # process the folds in name order (fold0 .. fold4)

# Read every fold once instead of once per class per fold.
folds = fold_csv.map { |path| read_fold(path) }

fold_csv.each_with_index do |path, index|
  printf("%s calating...\n", File.basename(path))

  test_lines = folds[index]
  labels = test_lines.map { |line| libsvm_label(line) } # labels of the test data
  per_class = count_known_classes(labels)
  class_label = labels.uniq.sort # the distinct classes present in the test fold

  # Training data: every fold except the one under test.
  train_lines = []
  folds.each_with_index do |lines, i|
    train_lines.concat(lines) unless i == index
  end

  micro_a = 0
  micro_b = 0
  micro_c = 0
  ccr = []
  fmeasure = []

  # One binary (class vs. rest) problem per class label.
  class_label.each do |label|
    actual = test_lines.map { |line| libsvm_label(line) == label }
    write_lines("./tests.libsvm", test_lines.map { |line| binarize_line(line, label) })
    write_lines("./trains.libsvm", train_lines.map { |line| binarize_line(line, label) })

    FileUtils.rm_f("libsvm_result")
    run_command("svm-train -t 0 trains.libsvm libsvm_model")
    run_command("svm-predict tests.libsvm libsvm_model libsvm_result")

    # A prediction line starting with "1" is a positive prediction.
    predicted = File.readlines("./libsvm_result").map { |line| line.start_with?("1") }

    a, b, c, d = confusion_counts(actual, predicted)
    accuracy = 100 * (a + d) / (a + b + c + d).to_f
    printf("[%s],a:%d b:%d c:%d d:%d\n", label, a, b, c, d)
    printf("accuracy:%2.3f\n", accuracy)

    micro_a += a
    micro_b += b
    micro_c += c
    ccr << accuracy
    fmeasure << f_measure(a, b, c)
  end

  # Per-fold summary. "ccr" is the mean of the per-class binary accuracies.
  class_total = class_label.size
  ccr_rate = sprintf("%2.3f", float_total(ccr) / class_total)

  micro_p = micro_a.to_f / (micro_a + micro_c)
  micro_r = micro_a.to_f / (micro_a + micro_b)
  if (micro_p + micro_r) > 0
    micro_f1 = sprintf("%2.3f", 100 * 2 * micro_p * micro_r / (micro_p + micro_r))
  else
    # Same three-decimal format as every other metric column.
    micro_f1 = sprintf("%2.3f", 0)
  end

  macro_f1 = sprintf("%2.3f", 100 * float_total(fmeasure) / class_total)

  # The first column carries the fold's position in the sorted file list.
  row = ["#{index}", test_lines.size] + per_class + [ccr_rate, micro_f1, macro_f1]
  fold_computing_data << row.join(',')

  FileUtils.rm_f("libsvm_result")
end

printf("oaasvm\n")
fold_computing_data.each { |row| puts row }
# trains.libsvm, tests.libsvm and libsvm_model are left on disk.
|