iT邦幫忙

2021 iThome 鐵人賽

DAY 24
1
AI & Data

Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型系列 第 24

Day-23 Model 可以重複使用嗎? 儲存和讀取 Model

  • 總算,我們已經會建立自己獨一無二的神經網路了~但,你有沒有發現一個問題,我們的該不會每次要使用模型之前,都要全部重頭來一遍吧?今天小小的資料跟小小的模型我都能夠接受這樣操作,但是...如果今天是百萬個神經元級別的,我總不能還是每次要使用之前都從頭來一次吧?
  • 當然不用,我們回頭思考一下,對於我們訓練模型來說,最重要的東西是什麼?就是我們模型中的變數對吧?只要我們記得最後訓練完取得的參數,其實就等於是我們訓練完的模型結果了阿~那我們在使用上也就是使用這些參數在操作,因此我們有沒有個辦法去記錄這些東西呢?只要我們能記錄一個模型的狀況,這樣有幾個好處,
    • 訓練完的模型可以直接在需要時做讀取使用
    • 訓練過程中如果持續有做紀錄,如果訓練不小心中斷,可以直接從中斷的地方開始訓練
    • 要從中間做調整也可以利用中間的資料開始訓練
  • 簡單來說,當我們成功取得儲存讀取心法,我們可以說我們就掌控了整個訓練,對於整個流程都更加的靈活強大了,所以就讓我們來聊聊如何對模型做儲存和讀取吧~

Save & Load Model

  • 模型的讀寫有分成兩種方式,一種方式我們稱為懶人法,另一種則是比較推薦正規的方式,我們會分別聊到,也會解釋差異

Lazy Way

  • Pytorch 提供了一個偷懶的方式,就是把整個 Model 儲存起來,那我們直接拿一個例子做舉例
import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    
    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        self.linear = nn.Linear(input_size, 1)
    
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        
        return y_pred
        

model = ExampleModel(input_size=6)
  • 我們在初始宣告一個 Model 的時候,其實就會有初始的參數了,因此我們在這個時候去輸出我們的參數的話會變成這樣
print('Before saveing: ')
for parm in model.parameters():
    print(parm)
    
# before save
# Parameter containing:
# tensor([[-0.2966, -0.2289, -0.3195, -0.2210, -0.2217,  0.1012]],
#        requires_grad=True)
# Parameter containing:
# tensor([0.1014], requires_grad=True)
  • 我們可以看到這個時候我們已經有初始的 weights 跟 bias 了,那依照一般的訓練過程就是會拿這組參數去驗證資料,然後看 loss 的狀況等等一路往下做訓練
  • 所以讓我們把現在的模型狀況儲存起來,會用到的工具叫做 torch.save(arg, PATH),會需要兩個參數,方別是我們的 model 跟要儲存的位置 PATH,因此範例會長下面這樣
# save whole model
FILE = 'model_all.pt'
torch.save(model, FILE)
  • 那我們如果要使用這個儲存起來的模型,我們要怎麼去讀取呢?這時會利用到另一個函式 torch.load(PATH),只要給儲存的位置,就會自動處理讀取,我們看範例
# load model
model = torch.load(FILE)
  • 那這邊要注意一件事情,模型在讀取進來時,我們如果要使用評估模式(確保固定的推理狀況),或是訓練模式(確保可以有完整的訓練過程),需要宣告不同的 model 狀態,也就 model.eval()(評估模式)、model.train()(訓練模式)
  • 那我們今天要驗證資料,因此用 model.eval() 的評估模式來做資料檢查
model.eval()

print('whole model load')
for parm in model.parameters():
    print(parm)
    
# whole model load
# Parameter containing:
# tensor([[-0.2966, -0.2289, -0.3195, -0.2210, -0.2217,  0.1012]],
#        requires_grad=True)
# Parameter containing:
# tensor([0.1014], requires_grad=True)
  • 我們可以發現讀取進來的資料跟保存時的狀態一毛毛一樣樣,這就是我們希望的效果
  • 但是,將整個模型保存下來的方式是不被推薦的,詳細原因可以參考官方文件Pytroch saveing and loading model,簡單來說就是這樣的做法是比較不穩定的,容易造成模型的損毀
  • 那讓我們來看看被推薦的做法

Recommended way

  • 今天我們主要目標其實是當前模型的參數,模型本身的結構那些其實不是那麼重要,因為我們隨時可以自己重新建立這個結構,因此 Pytorch 提供了一個函式叫做 state_dict
  • state_dict 是一個簡單的Python字典對象,每個層映射到其參數張量。我們來看看範例
import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    
    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        self.linear = nn.Linear(input_size, 1)
    
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        
        return y_pred
        

model = ExampleModel(input_size=6)
print('before save')
print(model.state_dict())

# before save
# OrderedDict([('linear.weight', tensor([[-0.0637,  0.2750, -0.3998,  0.2837, -0.2839,  0.3845]])), ('linear.bias', tensor([0.1257]))])
  • 那既然 state_dict 中已經儲存了我們足夠需要的資料了,那我們是不是可以只儲存 state_dict() 的資料?當然可以,所以就讓我們這麼做吧~
FILE = 'model_state_dict.pt'
model.save(model.state_dict(), FILE)
  • 那這邊要注意,我們已經沒有儲存整個模型的結構狀況了,因此在讀取資料時,方式有點不同,我們首先還是要宣告 model,但我們要改用儲存的 state_dict 參數代替原本初始的參數,來做剩下的行為,因此程式會變成這樣
model = ExampleModel(input_size=6)
model.load_state_dict(torch.load(FILE))
  • 那這樣就可以達到儲存參數的效果了~

Checkpoint Design

  • 那我們有提過如果我們能適當的儲存我們訓練的過程作為記錄點,會有助於我們在不管是
    • 中斷還原
    • 更改訓練狀況
      等等其他訓練的狀況的使用
  • 因此如何建立好的 Checkpoint 是一個很好的問題,那這邊我們就示範怎麼建立 Checkpoint(檢查點)
  • 常見的 Checkpoint 會包含
    • epoch
    • model_state_dict
    • optimizer_state_dict
    • loss
    • ...
  • 還有其他東西,都可以視情況做添加,因此如果我們要儲存這些資訊,我們要怎麼去儲存和讀取?
  • 儲存
model = TheModelClass(*args, **kwargs)
loss = LossFunctionClass()
optimizer = TheOptimizerClass(*args, **kwargs)

# traning loop
for epoch in range(num_epochs):
    ...
    
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, PATH)
  • 讀取
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

每日小結

  • 模型的狀態儲存是一個非常重要的議題,這決定了我們訓練的模型狀況可否被沿用,訓練狀況的保存挑整等等
  • Pytorch 提供了非常方便儲存訓練參數的方式,但是要記得裡面仍然有些許限制,建議大家去官網好好看看,這邊只是入門
  • 我們總算是把所有 PyTorch 的心法都說明清楚了~可喜可賀可喜可賀,明天讓我們聊聊在深度學習領域都會遇到的好朋友 CNN 之後,就可以開始我們的手寫辨識訓練了~

上一篇
Day-22 更加靈活的神經網路,我們可以做哪些變化
下一篇
Day-24 一定會見面,Convolutional Neural Network (CNN)
系列文
Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言