iT邦幫忙

2021 iThome 鐵人賽

DAY 21
0
AI & Data

手寫中文字之影像辨識系列 第 21

【第21天】訓練模型-模型組合與辨識isnull(二)

  • 分享至 

  • xImage
  •  

摘要

  1. 作業流程
  2. 設定資料集路徑
  3. 找出每個中文字的閾值
  4. 任意選擇奇數個模型組合後,產生組合權重表,並利用模型權重得到新的機率表。
  5. 判斷isnull

內容

  1. 作業流程(今日進度請參閱紅框處)

  2. 設定資料集路徑

    2.1 我們有7個模型,每個模型輸出3個機率表(官方800字內、官方800字外、測試賽),共21個。

    2.2 機率表中有803個欄位,分別是1~800字機率、預測值、實際值及是否正確預測。

    2.3 程式碼

    # Load the probability tables for the official samples whose label is inside
    # the 800-character set: one CSV per model (seven models in total).
    # NOTE(review): the ex4 and ex6 files are read without fileEncoding, unlike
    # the other five -- confirm those CSVs carry no UTF-8 BOM.
    offical_in800_1 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/densenet201_v2_2/official_in_800.csv",fileEncoding="UTF-8-BOM")
    offical_in800_2 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/resnet152V2_v1_2/official_in_800.csv",fileEncoding="UTF-8-BOM")
    offical_in800_3 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/xception_v2_2/official_in_800.csv",fileEncoding="UTF-8-BOM")
    offical_in800_ex3 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/inceptionResNetV2_v1_2/official_in_800.csv",fileEncoding="UTF-8-BOM")
    offical_in800_ex4 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/densenet201_in800_official_韋智.csv")
    offical_in800_ex5 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/prob炫斐/official_in_800.csv",fileEncoding="UTF-8-BOM")
    offical_in800_ex6 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/swa_v2/swa_v2_in800_official_韋智.csv")
    
    # Give columns 801 and 802 of every in-800 table the same names:
    # the predicted character and the true character.
    # Later code accesses these columns as $predict_word and $origin_word.
    names(offical_in800_1)[801:802] = c('predict_word',"origin_word")
    names(offical_in800_2)[801:802] = c('predict_word',"origin_word")
    names(offical_in800_3)[801:802] = c('predict_word',"origin_word")
    names(offical_in800_ex3)[801:802] = c('predict_word',"origin_word")
    names(offical_in800_ex4)[801:802] = c('predict_word',"origin_word")
    names(offical_in800_ex5)[801:802] = c('predict_word',"origin_word")
    names(offical_in800_ex6)[801:802] = c('predict_word',"origin_word")
    
    # Load the probability tables for the official samples whose label is outside
    # the 800-character set: one CSV per model.
    # Columns 801:802 of these tables are not renamed here.
    offical_noin800_1 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/densenet201_v2_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_2 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/resnet152V2_v1_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_3 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/xception_v2_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_ex3 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/inceptionResNetV2_v1_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_ex4 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/densenet201_notin800_official_韋智.csv")
    offical_noin800_ex5 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/prob炫斐/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_ex6 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/swa_v2/swa_v2_notin800_official_韋智.csv")
    
    # Test round.
    # NOTE(review): the statements below read the same official_notin_800 files
    # into the same offical_noin800_* variables as the section above, so nothing
    # test-specific is loaded and no test_data_* object is created here, although
    # get_new_model() has a branch for dataset == 'test_data'. This looks like a
    # copy-paste slip -- confirm the intended test-round file paths and variable
    # names before relying on this section.
    offical_noin800_1 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/densenet201_v2_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_2 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/resnet152V2_v1_2/official_notin_800.csv",   fileEncoding="UTF-8-BOM")
    offical_noin800_3 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/xception_v2_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_ex3 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/inceptionResNetV2_v1_2/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_ex4 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/densenet201_notin800_official_韋智.csv")
    offical_noin800_ex5 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/prob炫斐/official_notin_800.csv",fileEncoding="UTF-8-BOM")
    offical_noin800_ex6 = read.csv(file = "C:/Users/wooden/Desktop/dl/probCSV/swa_v2/swa_v2_notin800_official_韋智.csv")
    
  3. 找出每個中文字的閾值

    3.1 定義function:找出閾值最小值&平均機率

    #' Per-character accuracy and probability thresholds for one model.
    #'
    #' For every character column of `data_prob`, looks at the rows whose
    #' `origin_word` equals that character and reports how many there are, the
    #' share that were predicted correctly, and the minimum / mean probability
    #' the model assigned to the character on the correctly predicted rows.
    #'
    #' @param data_prob Data frame whose first `n_words` columns are named after
    #'   the characters and hold their probabilities, plus columns
    #'   `predict_word` and `origin_word`.
    #' @param n_words Number of leading probability columns (default 800).
    #' @return Data frame with columns `word`, `acc` (rounded to 4 digits; NaN
    #'   when the character has no rows), `min_prob`, `n`, `mean_prob`.
    #'   `min_prob` and `mean_prob` are 0 when no row was predicted correctly.
    get_acc_min <- function(data_prob, n_words = 800){
      stopifnot(
        ncol(data_prob) >= n_words,
        all(c("predict_word", "origin_word") %in% names(data_prob))
      )
      word <- unique(names(data_prob)[seq_len(n_words)])
      n_word <- length(word)

      # Preallocate instead of growing the vectors inside the loop.
      n <- integer(n_word)
      acc <- numeric(n_word)
      mean_prob <- numeric(n_word)
      min_prob <- numeric(n_word)

      for (i in seq_along(word)) {
        # which() drops NA comparisons, so rows with a missing origin_word are
        # not pulled in as phantom all-NA rows.
        rows <- which(data_prob$origin_word == word[i])
        # Within these rows "predicted == origin" is "predicted == word[i]";
        # %in% treats a missing prediction as a miss instead of yielding NA
        # (which previously made the if() below fail).
        hit <- rows[data_prob$predict_word[rows] %in% word[i]]

        n[i] <- length(rows)
        acc[i] <- round(length(hit) / n[i], 4)
        if (length(hit) > 0) {
          hit_prob <- as.numeric(data_prob[[word[i]]][hit])
          min_prob[i] <- min(hit_prob)
          mean_prob[i] <- mean(hit_prob)
        }
        # otherwise min_prob[i] / mean_prob[i] keep their preallocated 0
      }
      data.frame(word = word, acc = acc, min_prob = min_prob, n = n, mean_prob = mean_prob)
    }
    

    3.2 找出800個字的閾值,彙整後儲存CSV檔案

    # Per-character accuracy, minimum probability and mean probability for each
    # of the seven models (in-800 official data).
    offical_in800_1_summary = get_acc_min(offical_in800_1)
    offical_in800_2_summary = get_acc_min(offical_in800_2)
    offical_in800_3_summary = get_acc_min(offical_in800_3)
    offical_in800_ex3_summary = get_acc_min(offical_in800_ex3)
    offical_in800_ex4_summary = get_acc_min(offical_in800_ex4)
    offical_in800_ex5_summary = get_acc_min(offical_in800_ex5)
    offical_in800_ex6_summary = get_acc_min(offical_in800_ex6)
    
    # Assemble one wide table: word, n, then acc_<id>, min_prob_<id> and
    # mean_prob_<id> for every model id. `word` and `n` are taken from the first
    # model's summary. Building the columns in a loop keeps names and order in
    # one place instead of a separate add / rename / reorder sequence.
    # local() keeps the helper variables out of the global environment.
    final = local({
      model_ids = c("1","2","3","ex3","ex4","ex5","ex6")
      summaries = list(offical_in800_1_summary,
                       offical_in800_2_summary,
                       offical_in800_3_summary,
                       offical_in800_ex3_summary,
                       offical_in800_ex4_summary,
                       offical_in800_ex5_summary,
                       offical_in800_ex6_summary)
      wide = offical_in800_1_summary[,c("word","n")]
      for(stat_name in c("acc","min_prob","mean_prob")){
        for(k in seq_along(model_ids)){
          wide[[paste0(stat_name,"_",model_ids[k])]] = summaries[[k]][[stat_name]]
        }
      }
      wide
    })
    
    # Save: character label + sample count n + acc, min_prob and mean_prob for
    # each of the 7 models.
    write.csv(final,file = "C:/Users/wooden/Desktop/dl/model/model_weight_V3.csv",row.names = FALSE)
    

    3.3 輸出結果(以CSV檔顯示)

    • 欄位:中文字標籤、該字出現n次、7個模型ACC、7個模型min_prob、7個模型mean_prob。
    • 表格內容
  4. 任意選擇奇數個模型組合後,產生組合權重表,並利用模型權重得到新的機率表。

    4.1 定義function:任意組合模型(奇數個)。

    #' Enumerate every 0/1 inclusion pattern over `n` items.
    #'
    #' Row r (r = 0, ..., 2^n - 1) is the binary representation of r, most
    #' significant bit in column 1, so each row marks one subset of the n models
    #' (1 = included). All subsets are returned, including the empty one; keeping
    #' only odd-sized combinations is not done here.
    #'
    #' @param n Number of items (models).
    #' @return Numeric matrix with 2^n rows and n columns of 0/1.
    BitMatrix <- function(n){
      set <- 0:(2^n - 1)
      rst <- matrix(0, ncol = n, nrow = 2^n)
      # Column i holds the bit of weight 2^(n - i); extract it directly rather
      # than subtracting the running row sums of the earlier columns.
      # seq_len() makes n = 0 return a 1 x 0 matrix instead of failing.
      for (i in seq_len(n)) {
        rst[, i] <- (set %/% 2^(n - i)) %% 2
      }
      rst
    }
    

    4.2 定義function:以官方800字內資料集機率表組合模型後,產出權重表,並利用模型權重得到新的機率表。

    #' Combine several models into one weighted model.
    #'
    #' Reads the global summary table `final` and the global probability tables
    #' named `<dataset>_<model id>`. Each model's per-character weight is its
    #' statistic (accuracy or mean probability) divided by the sum of that
    #' statistic over the selected models. The combined probability table is the
    #' weighted sum of the models' first `n_words` columns.
    #'
    #' @param namesmodel Model ids, e.g. `c(1, 2, "ex3")`; at least one.
    #' @param stat `'acc'` weights by `acc_<id>`; any other value weights by
    #'   `mean_prob_<id>`.
    #' @param dataset Prefix of the probability tables to combine.
    #' @param n_words Number of leading probability columns (default 800).
    #' @return Unnamed list of two data frames:
    #'   1. the combined probability table; unless `dataset` is
    #'      `"offical_noin800"` it gets a column `acc` (1 if the top character
    #'      equals `origin_word` of the last listed model's table, else 0), and
    #'      for `dataset == 'test_data'` also a column `origin_word`;
    #'   2. a copy of `final` extended with `wei_<id>` columns plus
    #'      `min_prob_new` and `mean_prob_new`, the weighted sums of the
    #'      per-model `min_prob_<id>` / `mean_prob_<id>` columns.
    get_new_model <- function(namesmodel = c(1), stat = 'acc', dataset = "offical_in800", n_words = 800){
      stopifnot(length(namesmodel) >= 1)

      stat_prefix <- if (stat == 'acc') "acc_" else "mean_prob_"
      new_stat <- final  # local copy; the global table is not modified

      # Sum of the chosen statistic over the selected models (normaliser).
      stat_total <- Reduce(`+`, lapply(paste0(stat_prefix, namesmodel),
                                       function(col) new_stat[[col]]))

      result <- NULL
      model_data <- NULL
      for (k in namesmodel) {
        weight <- new_stat[[paste0(stat_prefix, k)]] / stat_total
        new_stat[[paste0("wei_", k)]] <- weight

        # Look the table up by name instead of building code with eval(parse()).
        model_data <- get(paste0(dataset, "_", k))
        # One row of weights per sample so the product below is elementwise.
        wei_matrix <- matrix(weight, nrow = nrow(model_data), ncol = n_words, byrow = TRUE)
        weighted <- model_data[, seq_len(n_words), drop = FALSE] * wei_matrix
        result <- if (is.null(result)) weighted else result + weighted
      }

      # After the loop `model_data` is the table of the last listed model; its
      # origin_word column supplies the true labels.
      if (dataset != "offical_noin800") {
        maxindex <- apply(as.matrix(result), 1, which.max)
        result$acc <- ifelse(model_data$origin_word == new_stat$word[maxindex], 1, 0)
      }
      if (dataset == 'test_data') {
        result$origin_word <- model_data$origin_word
      }

      weighted_sum <- function(prefix) {
        Reduce(`+`, lapply(namesmodel, function(k) {
          new_stat[[paste0(prefix, k)]] * new_stat[[paste0("wei_", k)]]
        }))
      }
      new_stat$min_prob_new <- weighted_sum("min_prob_")
      new_stat$mean_prob_new <- weighted_sum("mean_prob_")

      list(result, new_stat)
    }
    

    4.3 輸出結果(以CSV檔顯示)

    • 模型組合權重表(紅框處為組合權重)

    • 新機率表(紅框處代表是否正確預測,正確預測為1;錯誤預測為0)

  5. 判斷isnull

    5.1 定義function

    #' Flag samples as "isnull" (1) or not (0) by thresholding the top probability.
    #'
    #' For each row, the largest probability among the first `n_words` columns is
    #' compared with the threshold of that same column: 0 if it reaches the
    #' threshold, 1 if it falls below.
    #'
    #' Two modes:
    #' * `new_data` and `new_stat` both NULL: every model in `namesmodel` is
    #'   checked on its own global table `<dataset>_<id>` against the global
    #'   `final$<stat>_<id>` thresholds; the flags are averaged and a sample is
    #'   flagged when at least half of the models flag it.
    #' * both supplied: `new_data` is checked once against
    #'   `new_stat$min_prob_new` or `new_stat$mean_prob_new`.
    #'
    #' @param namesmodel Model ids (used only in the per-model mode).
    #' @param stat `'min_prob'` uses the minimum-probability thresholds; any other
    #'   value uses the mean-probability thresholds.
    #' @param dataset Prefix of the global probability tables.
    #' @param new_data Combined probability table, or NULL.
    #' @param new_stat Table holding `min_prob_new` / `mean_prob_new`, or NULL.
    #' @param n_words Number of leading probability columns (default 800).
    #' @return Numeric vector of 0/1 flags, one per row.
    get_min01 <- function(namesmodel = c(1), stat = 'min_prob', dataset = "offical_in800", new_data = NULL, new_stat = NULL, n_words = 800){
      # Supplying only one of the pair previously fell through to a comparison
      # against a NULL threshold; fail loudly instead.
      if (is.null(new_data) != is.null(new_stat)) {
        stop("new_data and new_stat must be supplied together", call. = FALSE)
      }

      stat_prefix <- if (stat == 'min_prob') "min_prob_" else "mean_prob_"

      # 0/1 flag per row: is the winning probability below its column threshold?
      # ifelse() is kept so an NA threshold yields NA rather than an error.
      flag_rows <- function(prob_table, threshold) {
        apply(as.matrix(prob_table), 1, function(x) {
          top <- which.max(x)
          ifelse(x[top] >= threshold[top], 0, 1)
        })
      }

      if (is.null(new_data)) {
        stopifnot(length(namesmodel) >= 1)
        votes <- 0
        for (k in namesmodel) {
          # Look the table up by name instead of building code with eval(parse()).
          probs <- get(paste0(dataset, "_", k))[, seq_len(n_words), drop = FALSE]
          votes <- votes + flag_rows(probs, final[[paste0(stat_prefix, k)]])
        }
        # Majority vote; an exact tie (0.5) counts as flagged.
        result <- ifelse(votes / length(namesmodel) >= 0.5, 1, 0)
      } else {
        result <- flag_rows(new_data[, seq_len(n_words), drop = FALSE],
                            new_stat[[paste0(stat_prefix, "new")]])
      }
      result
    }
    

小結

  1. 今天成功取得閾值、奇數的模型組合的權重表,並定義如何判斷isnull的function。
  2. 下一章的目標是:「交叉比對不同的模型組合方法,並選出其中最佳的」。

讓我們繼續看下去...


上一篇
【第20天】訓練模型-模型組合與辨識isnull(一)
下一篇
【第22天】訓練模型-模型組合與辨識isnull(三)
系列文
手寫中文字之影像辨識31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言