iT邦幫忙

1

PyTorch 2.0 發布與新功能測試

  • 分享至 

  • xImage
  •  

前言

PyTorch研究團隊宣布2022/12/02要推出PyTorch 2.0,2023/03/15 正式推出,主要訴求特點:

  1. 速度更快。
  2. 與Python整合更好。
  3. 依舊使用動態運算圖(Eager execution),保持彈性。
    https://ithelp.ithome.com.tw/upload/images/20230321/20001976bcrIQucH4d.png
    圖一. PyTorch 2.0 標榜的口號

主要變動

根據『PyTorch 2.0 release explained』一文說明,主要的變動包括:

  1. 訓練速度增快30%,使用較少的記憶體(筆者實測恰巧相反,見後文)。
  2. 運算子集合縮小為250個,後端運算核心縮小,運算較快。
  3. 增強分散式處理功能。

綜觀變動項目,都是涉及核心架構的變動,與開發者無關,除了『編譯』(compile)模型,原有程式都不需修改,即完全向後相容(Backward Compatibility)。

PyTorch 安裝或升級

Windows作業系統的安裝指令如下:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 -U

Linux/Mac作業系統的安裝指令如下:

pip install torch torchvision torchaudio -U

編譯測試

PyTorch 2.0最重要的是『編譯』(compile)模型,因為PyTorch採用動態運算圖(Computation Graph),效能較差,因此,新增『編譯』功能,可以將模型轉換為部分或完整的靜態運算圖,以縮短訓練、推論時間,內部作法如下圖,將使用者的程式改寫為運算圖區塊。
https://ithelp.ithome.com.tw/upload/images/20230321/200019761h4ZwA7YlE.png
圖二. 『編譯』利用 TorchDynamo 改寫程式碼為運算圖

程式只要在訓練之前加一行程式碼 torch.compile(model) 即可。

import torch
import torchvision.models as models

model = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)

x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()

筆者在下列平台上測試以上程式:
1.Windows作業系統:尚不支援。
2.Colaboratory:torch.compile執行無誤,但多花了44秒。
3.WSL(Windows subsyatem for Linux):CUDA capability需>7.0。

YOLOv8 測試

觀察較大型的程式是否因使用torch.compile,使訓練時間縮短。

  1. Windows作業系統:筆者使用『YOLO v8 模型訓練實測』一文訓練YOLO模型,未修改程式,結果記憶體爆掉了,因此,PyTorch 2.0並不一定會減少記憶體的使用,甚至將批量縮小4倍,結果也是一樣。

  2. Colaboratory:修改 ultralytics\yolo\engine\trainer.py,加入281行如下:
    https://ithelp.ithome.com.tw/upload/images/20230321/20001976R9yd30T5JS.png
    結果還是出錯。

KeyError: 'model.0.conv.weight'

結語

PyTorch 2.0在Windows作業系統功能還未完善,記憶體的使用也增加,而torch.compile在Windows/Linux作業系統也還有問題,讀者可以等待後續版本再升級。或許是筆者功力不夠,還請大家不吝指正。
測試程式可自這裡下載。

以下為工商廣告:)。
深度學習PyTorch入門到實戰應用影音課程:
https://ithelp.ithome.com.tw/upload/images/20230323/20001976OkmtrFnLWL.jpg

Scikit-learn 詳解與企業應用
https://ithelp.ithome.com.tw/upload/images/20230303/20001976OOxNMKpSLR.jpg
內容包括:

  1. 機器學習開發流程詳解。
  2. 各種演算法原理
  3. 自行開發演算法
  4. 眾多應用範例
  5. 還包括推薦、時間序列、半監督學習、MLOps...等。

PyTorch:
開發者傳授 PyTorch 秘笈
https://ithelp.ithome.com.tw/upload/images/20220531/20001976MhL9K2rsgO.png

TensorFlow:
深度學習 -- 最佳入門邁向 AI 專題實戰
https://ithelp.ithome.com.tw/upload/images/20220531/20001976ZOxC7BHyN3.jpg


圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

1 則留言

1
增廣建文
iT邦研究生 5 級 ‧ 2023-03-21 23:16:35

Yolov8已經有好幾個針對pytorch 2.0的PR了 看來可以再等等

讚,謝謝分享。

我要留言

立即登入留言