iT邦幫忙

2021 iThome 鐵人賽

DAY 6
0
  • 昨天我們提過了 Regression 的流程就是有一個初始目標 -> 檢查"糟糕程度" -> 找到最好
  • 上面的流程也就是 Model -> Loss function -> Gradient Descent -> Find the best function
  • 那我們要如何實作呢?

情境假設

  • 我們先假定我們的目標函數是 https://chart.googleapis.com/chart?cht=tx&chl=%24y%20%3D%202%20*%20x%24 ,也就是我們的答案
  • 那在不知道答案的情況下,我們可以先假設 Model 會長 https://chart.googleapis.com/chart?cht=tx&chl=%24y%20%3D%20w%20*%20x%24 ,我們今天要解決的問題就是找出正確的 w
  • 那回到訓練的過程,一定需要資料對吧,因此我們就人為生成一些資料 https://chart.googleapis.com/chart?cht=tx&chl=%24x%20%3D%20%5B1%2C%202%2C%203%2C%204%2C%205%2C%206%5D%24 ,答案的部分則是 https://chart.googleapis.com/chart?cht=tx&chl=%24y%20%3D%20%5B2%2C%204%2C%206%2C%208%2C%2010%2C%2012%5D%24
  • 那就基於這樣的題目的話,我們要怎麼實作解決呢?

實作

  • 下面就不另外寫出來解釋了,有 python 基礎的 + 註解應該能夠理解
import numpy as np

# f = w * x
# f = 2 * x, we set w as 2
x = np.array([1, 2, 3, 4, 5, 6], dtype=np.float32)
y = np.array([2, 4, 6, 8, 10, 12], dtype=np.float32)

# init weight
w = 0.0

# model prediction
# 這邊之後會解釋為啥叫做 forward,可以先視為計算函數而已
def forward(x):
    
    return w * x
    
# set up loss function as mean square error
def loss(y, y_predicted):

    return ((y_predicted-y) ** 2).mean()
    
# gradient
def gradient(x, y, y_predicted):

    return np.dot(2*x, y_predicted-y).mean()
    
print(f'Prediction before training: f(5) = {forward(5): .3f}')

# Training
learning_rate = 0.01
n_iters = 10

for epoch in range(n_iters):
    # perdiction = forward pass
    y_pred = forward(x)

    # loss
    l = loss(y, y_pred)

    # we know that gradient descent is where 
    # calculate gradient and update parameters
    
    # calculation of gradients
    dw = gradient(x, y, y_pred)

    # update weights
    w -= learning_rate * dw

    if epoch % 1 == 0:
        print(f'epoch {epoch + 1}: w = {w:.3f}, loss = {l:.8f}')

print(f'Prediction after training: f(5) = {forward(5): .3f}')

每日小結

  • 從上面的 Example 應該可以發現,其實整體的計算和概念真的沒有很難,機器學習並沒有我們想像的複雜
  • 那今天我們學過了 Regression,明天讓我們聊聊 Classification 吧~

上一篇
Day-04 Python 的 Gradient 計算
下一篇
Day-06 Classification
系列文
Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言