iT邦幫忙

2021 iThome 鐵人賽

DAY 21
0
自我挑戰組

資料分析及AI深度學習-簡單基礎實作系列 第 21

DAY21:優化器(上)

  • 分享至 

  • xImage
  •  

優化器

  • 優化器演算法

    • 在反向傳播的過程中,優化器的用途在於最小化損失函數的loss值,期望找到全局的最佳解。

    • 有可能會遇到昨天我們所說的鞍點,這時搭配好的學習率,是有機會避開鞍點的。

    • 沒有一個優化器是最好也最厲害的,所以選擇優化器的方式,小弟認為是多嘗試,等到經驗夠多,可能就可以知道哪些的資料集種類適合那些優化器。

    • 不同優化器的收斂軌跡動態圖

      圖片來源:https://lonepatient.top/2018/09/25/Cyclical_Learning_Rate


比較各種優化器事項

  1. 模型選定:用輕量型的預訓練模型densenet201做訓練,訓練時間也較快。

  2. 資料集:因中文手寫字圖片過多,決定節省時間,選定之前介紹過的驗證碼辨識的圖片共12000張訓練集,測試集3000張,驗證集1000張。

    • dataset建立
    import os
     from PIL import Image
     import torch
     from torch.utils.data import Dataset
     import pandas as pd
    
    
     alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    
    
     def img_loader(img_path):
         img = Image.open(img_path)
         return img.convert("RGB")
     # data_path = r"C:/Users/Frank/PycharmProjects/practice/captcha_recognition/picture/"
     # ans_path = r'C:\Users\Frank\PycharmProjects\practice\captcha_recognition\answer\answer2.csv'
     def make_dataset(data_path,ans_path,alphabet):
         img_names = os.listdir(data_path)
         img_names.sort(key=lambda x: int(x.split(".")[0]))
         df_ans = pd.read_csv(ans_path)
         ans_list = list(df_ans["code"].values)
         samples = []
    
         for ans, img_name in zip(ans_list, img_names):
             if len(str(ans)) == 5  :#num_char:
                 img_path = os.path.join(data_path, img_name)
                 target = []
                 for char in str(ans):
                     vec = [0] * 36 # num_class
                     vec[alphabet.find(char)] = 1
                     target += vec
                 samples.append((img_path, target))
             else:
                 print(img_name)
         return samples
    
    
     class CaptchaData(Dataset):
         def __init__(self, data_path,ans_path,
                      transform=None, target_transform=None, alphabet=alphabet):
             super(Dataset, self).__init__()
             self.data_path = data_path
             self.ans_path = ans_path
             # self.num_class = num_class
             # self.num_char = num_char
             self.transform = transform
             self.target_transform = target_transform
             self.alphabet = alphabet
             self.samples = make_dataset(self.data_path,self.ans_path,self.alphabet
                                         )
    
         def __len__(self):
             return len(self.samples)
    
         def __getitem__(self, index):
             img_path, target = self.samples[index]
             img = img_loader(img_path)
             if self.transform is not None:
                 img = self.transform(img)
             if self.target_transform is not None:
                 target = self.target_transform(target)
             return img, torch.Tensor(target)
    

上一篇
DAY20:學習率(下)
下一篇
DAY22:優化器(中)
系列文
資料分析及AI深度學習-簡單基礎實作30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言