獲得各模型預測字的機率表
資料總筆數
程式碼
import torch
import torch.nn as nn
import os
from dataset import CaptchaData
from torchvision.transforms import Compose, ToTensor
import csv
import copy
data_path = r"C:\Users\Frank\PycharmProjects\practice\mountain\清洗標籤final\train_all"
img_names = os.listdir(data_path)
source = img_names
title = copy.deepcopy(source)
title.append('predict')
title.append('true')
f = open('./densenet201_in800_official_nomask.csv', 'a',newline='')
w = csv.writer(f)
w.writerow(title)
f.close()
alphabet = ''.join(source)
def predict(img_dir):
n = 0
m = 0
transforms = Compose([
ToTensor()
])
dataset = CaptchaData(img_dir, transform=transforms)
model = torch.load('./best_densenet201_8.pth')
if torch.cuda.is_available():
model = model.cuda()
model.eval()
for k, (img, target) in enumerate(dataset):
img = img.view(1, 3 , 80 ,80 ).cuda()
target = target.view(1, 1 * 800).cuda()
output = model(img)
output = output.view(-1, 800)
target = target.view(-1, 800)
output_prob = nn.functional.softmax(output, dim=1)
output_prob_list = output_prob.cpu().detach().numpy().tolist()
output = torch.argmax(output_prob, dim=1)
target = torch.argmax(target, dim=1)
output = output.view(-1, 1)[0]
target = target.view(-1, 1)[0]
print('pred: ' + ''.join([alphabet[i] for i in output.cpu().numpy()]))
print('true: ' + ''.join([alphabet[i] for i in target.cpu().numpy()]))
pred = ''.join([alphabet[i] for i in output.cpu().numpy()])
true = ''.join([alphabet[i] for i in target.cpu().numpy()])
if pred == true:
n += 1
output_prob_list[0].append(pred)
output_prob_list[0].append(true)
# output_prob_list[0].append(1)
f = open('./densenet201_in800_official_nomask.csv', 'a',newline='')
w = csv.writer(f)
w.writerow(output_prob_list[0])
f.close()
else:
m += 1
output_prob_list[0].append(pred)
output_prob_list[0].append(true)
# output_prob_list[0].append(0)
f = open('./densenet201_in800_official_nomask.csv', 'a',newline='')
w = csv.writer(f)
w.writerow(output_prob_list[0])
f.close()
print("pred_acc:", n / (n + m))
print(m)
輸出的800字機率表
圖片來源:https://ithelp.ithome.com.tw/articles/10277916
判斷方法
選擇每個字的閾值
任意選擇奇數個模型組合後,產生模型權重表與利用新模型權重得到的機率表。
如何判斷isnull
之後的作法可參考訓練模型-模型組合與辨識isnull(二)以及訓練模型-模型組合與辨識isnull(三)。
因為同組的關係,我的隊友寫得又比我快比我好,我忍不住要來分享一下他的文,我們做法是這麼做的,在最後會來探討其他組別的作法。
後面判斷完isnull就剩下上GCP架設API,供比賽的時候使用。