OAASVM: CCR, Micro F1, Macro F1


#!/usr/bin/ruby -Ku
#encoding: utf-8
# Works for multi-class data; produces CCR, micro F1 and macro F1 (適用多類別, 產生ccr, micro f1, macro f1)

# Input files must be in libsvm format (讀取格式需為 libsvm format)


require 'fileutils'


# One-against-all (OAA) SVM evaluation with k-fold cross validation.
#
# Each ./fold/*.csv file holds one fold in libsvm format ("label idx:val ...").
# Folds are processed in sorted file-name order. For every fold used as the
# test set (all remaining folds form the training set) and for every class
# label present in that test fold, labels are rewritten to 1 (this class) /
# -1 (any other class). `svm-train -t 0` then builds a model and
# `svm-predict` scores the test fold. Per-fold report columns:
#   ccr      - mean of the per-class binary accuracies, in percent
#   Micro F1 - F1 of the TP/FN/FP counts pooled over all classes, in percent
#   Macro F1 - mean of the per-class F1 scores, in percent
# The working files tests.libsvm, trains.libsvm, libsvm_model and
# libsvm_result are (re)written in the current directory.

# Labels counted in the "class A".."class D" columns of the report.
CLASS_COLUMN_LABELS = %w[1 2 3 4].freeze

# Reads a libsvm-format file into [label, features] pairs, where features is
# the array of remaining whitespace-separated tokens. Blank lines are skipped
# so a trailing empty line neither crashes the label sort nor becomes a sample.
def read_libsvm(path)
  File.readlines(path)
      .map { |line| line.split(' ') }
      .reject(&:empty?)
      .map { |label, *features| [label, features] }
end

# Turns samples into one-vs-all libsvm lines: "1 ..." when the label equals
# positive_label, "-1 ..." otherwise; the feature tokens are kept unchanged.
def one_vs_all_lines(samples, positive_label)
  samples.map do |label, features|
    [label == positive_label ? '1' : '-1', *features].join(' ')
  end
end

# Writes each line to path, replacing any existing file.
def write_lines(path, lines)
  File.open(path, 'w') { |io| lines.each { |line| io.puts(line) } }
end

# Runs an external command and aborts the evaluation if it fails, so a stale
# libsvm_model or libsvm_result from an earlier class is never silently reused.
def run_or_abort(*command)
  return if system(*command)

  abort("command failed: #{command.join(' ')}")
end

# Binary confusion counts for the positive label "1", returned as
# [tp, fn, fp, tn]:
#   tp = truth 1 and predicted 1,   fn = truth 1 and predicted other,
#   fp = truth other and predicted 1, tn = truth other and predicted other.
def binary_confusion(truth, predicted)
  tp = fn = fp = tn = 0
  truth.zip(predicted).each do |actual, guess|
    if actual == '1'
      if guess == '1'
        tp += 1
      else
        fn += 1
      end
    elsif guess == '1'
      fp += 1
    else
      tn += 1
    end
  end
  [tp, fn, fp, tn]
end

# F1 = 2PR / (P + R) with P = tp / (tp + fp) and R = tp / (tp + fn).
# Returns 0.0 when tp is zero (P or R is then 0 or undefined).
def f1_score(tp, fn, fp)
  return 0.0 if tp.zero?

  precision = tp.to_f / (tp + fp)
  recall = tp.to_f / (tp + fn)
  2 * precision * recall / (precision + recall)
end

# Arithmetic mean; 0.0 for an empty list (e.g. an empty test fold).
def mean(values)
  return 0.0 if values.empty?

  values.inject(0.0, :+) / values.size
end

# Formats a report number with three decimals.
def percent(value)
  format('%2.3f', value)
end

fold_csv = Dir.glob('./fold/*.csv').sort
# Parse every fold once; each is reused as test data or as training data.
folds = fold_csv.map { |path| read_libsvm(path) }
fold_computing_data = ['fold_name,data size,class A,class B,class C,class D,ccr,Micro F1,Macro F1']

folds.each_with_index do |test_samples, index|
  printf("%s calating...\n", File.basename(fold_csv[index]))
  labels = test_samples.map(&:first)
  class_label = labels.uniq.sort # the classes evaluated in this fold
  train_samples = folds.reject.with_index { |_, i| i == index }.flatten(1)

  ccr = []      # per-class binary accuracy, percent
  fmeasure = [] # per-class F1
  micro_tp = micro_fn = micro_fp = 0

  class_label.each do |positive|
    write_lines('./tests.libsvm', one_vs_all_lines(test_samples, positive))
    write_lines('./trains.libsvm', one_vs_all_lines(train_samples, positive))
    FileUtils.rm_f('libsvm_result')
    run_or_abort('svm-train', '-t', '0', 'trains.libsvm', 'libsvm_model')
    run_or_abort('svm-predict', 'tests.libsvm', 'libsvm_model', 'libsvm_result')

    # Ground truth is known in memory; no need to re-read tests.libsvm.
    truth = test_samples.map { |label, _| label == positive ? '1' : '-1' }
    predicted = File.readlines('./libsvm_result').map(&:strip)
    if predicted.size != truth.size
      abort("libsvm_result has #{predicted.size} predictions for #{truth.size} test samples")
    end

    tp, fn, fp, tn = binary_confusion(truth, predicted)
    accuracy = 100.0 * (tp + tn) / truth.size
    # Report columns: a = tp, b = fn, c = fp, d = tn.
    printf("[%s],a:%d b:%d c:%d d:%d\n", positive, tp, fn, fp, tn)
    printf("accuracy:%2.3f\n", accuracy)
    ccr << accuracy
    fmeasure << f1_score(tp, fn, fp)
    micro_tp += tp
    micro_fn += fn
    micro_fp += fp
  end

  fold_computing_data << [
    index,
    test_samples.size,
    *CLASS_COLUMN_LABELS.map { |label| labels.count(label) },
    percent(mean(ccr)),
    percent(100 * f1_score(micro_tp, micro_fn, micro_fp)),
    percent(100 * mean(fmeasure))
  ].join(',')
  FileUtils.rm_f('libsvm_result')
end

printf("oaasvm\n")
puts fold_computing_data
# FileUtils.rm Dir.glob("trains.libsvm")
# FileUtils.rm Dir.glob("libsvm_model")





創作者介紹
創作者 igogo 的頭像
igogo

牛大叔.生活隨筆

igogo 發表在 痞客邦 留言(0) 人氣()