PyTorch研究團隊宣布2022/12/02要推出PyTorch 2.0,2023/03/15 正式推出,主要訴求特點:
根據『PyTorch 2.0 release explained』一文說明,主要的變動包括:
綜觀變動項目,都是涉及核心架構的變動,與開發者無關,除了『編譯』(compile)模型,原有程式都不需修改,即完全向後相容(Backward Compatibility)。
Windows作業系統的安裝指令如下:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 -U
Linux/Mac作業系統的安裝指令如下:
pip install torch torchvision torchaudio -U
PyTorch 2.0最重要的是『編譯』(compile)模型,因為PyTorch採用動態運算圖(Computation Graph),效能較差,因此,新增『編譯』功能,可以將模型轉換為部分或完整的靜態運算圖,以縮短訓練、推論時間,內部作法如下圖,將使用者的程式改寫為運算圖區塊。
圖二. 『編譯』利用 TorchDynamo 改寫程式碼為運算圖
程式只要在訓練之前加一行程式碼 torch.compile(model) 即可。
import torch
import torchvision.models as models
model = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)
x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()
筆者在下列平台上測試以上程式:
1.Windows作業系統:尚不支援。
2.Colaboratory:torch.compile執行無誤,但多花了44秒。
3.WSL(Windows subsyatem for Linux):CUDA capability需>7.0。
觀察較大型的程式是否因使用torch.compile,使訓練時間縮短。
Windows作業系統:筆者使用『YOLO v8 模型訓練實測』一文訓練YOLO模型,未修改程式,結果記憶體爆掉了,因此,PyTorch 2.0並不一定會減少記憶體的使用,甚至將批量縮小4倍,結果也是一樣。
Colaboratory:修改 ultralytics\yolo\engine\trainer.py,加入281行如下:
結果還是出錯。
KeyError: 'model.0.conv.weight'
PyTorch 2.0在Windows作業系統功能還未完善,記憶體的使用也增加,而torch.compile在Windows/Linux作業系統也還有問題,讀者可以等待後續版本再升級。或許是筆者功力不夠,還請大家不吝指正。
測試程式可自這裡下載。
以下為工商廣告:)。
深度學習PyTorch入門到實戰應用影音課程:
Scikit-learn 詳解與企業應用:
內容包括:
PyTorch:
開發者傳授 PyTorch 秘笈
TensorFlow:
深度學習 -- 最佳入門邁向 AI 專題實戰。