iT邦幫忙

2021 iThome 鐵人賽

DAY 16
0
  • 那我們之前看過了 Python 的 Easy Regression 實作,昨天也看過了 Pytorch 如何做到 Gradient Calculation,那我們今天就拿一樣的 Example 來看看如果事 Pytorch 會長怎樣吧~
  • 本篇範例是對應 Day-05 的 Easy Regression Example 去做 Framework 上面的比較

直接上 Code

  • 在這邊我們做了一個大更新,就是把 Graient 的計算交給了 Pytorch,可以發現複雜的微分工作已經交給了 Backpropagation 來處理
  • 這次的程式跟 Day-05 的最大差異就差在我們少了手工微分 Gradient function 了
import torch

# f = w * x
# f = 2 * x, we set w as 2
x = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float32)
y = torch.tensor([2, 4, 6, 8, 10, 12], dtype=torch.float32)

# init weight
# 這邊要注意,我們希望 Pytorch 幫我們計算更新的 Gradient 變數是 w,所以一定要開 requires_grad 在這個變數上
w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

# model prediction
def forward(x):
    
    return w * x
    
# set up loss function as mean square error
def loss(y, y_predicted):

    return ((y_predicted-y) ** 2).mean()

# Training
learning_rate = 0.01
n_iters = 10

for epoch in range(n_iters):
    # perdiction = forward pass
    y_pred = forward(x)

    # loss
    l = loss(y, y_pred)

    # gradient descent is where calculate gradient and update parameters
    # so gradient descent here includes gradients and update weights
    # 原本在 Python 的 example 還需要自己建立 Gradient 函式
    # gradients = backward pass
    l.backward() # calculate dl/dw

    # update weights
    with torch.no_grad():
        w -= learning_rate * w.grad

    if epoch % 1 == 0:
        print(f'epoch {epoch + 1}: w = {w:.3f}, loss = {l:.8f}')
    
    # zero gradients,要記得歸零每次運算的 gradients,否則會累加
    w.grad.zero_()

print(f'Prediction after training: f(5) = {forward(5): .3f}')

每日小結

  • 從上述的 Code 可以發現大體上概念並沒有變化,但是這邊的 Gradient 計算交給了 Backpropagation 的反向傳遞的概念,因此省略了我們自己微分 loss function 的部分
  • 從上面的程式可以發現,利用 Pytorch Framework 確實在程式的撰寫上變得更加簡潔,我們需要自己特別操作運算建立的函式也變少了,但是基本元素和運算概念是和純 Python code 並無區別,這也是為甚麼我們前面要花那麼多的時間在介紹基本概念,因為就算是 Framework 也沒有跳脫這個框架,因此基本概念仍然是非常重要的,我們後面會示範從頭開始建立一個類神經網路的,可以敬請期待,我們明天會在示範 Pytorch 實作 Backpropagation

上一篇
Day-14 Pytorch 的 Gradient 計算
下一篇
Day-16 Pytorch 的 Training 流程
系列文
Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言