iT邦幫忙

2021 iThome 鐵人賽

DAY 12
0
AI & Data

Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型系列 第 12

Day-11 Backpropagation 介紹

  • 我們前面提到過深度學習就是模仿神經網路建構一個龐大的訓練模型,來達到特徵的選取(調整 weight 的數值來達到決定輸入特徵的權重),那我們看過 Gradient Descent 的數值更新狀況概念很簡單,但實際上我們可以想像當結構變複雜之時,我們可以預期 Gradient Descent 的計算將會變得太過複雜
  • Baskpropagation(反向傳遞法),就是希望讓 neural network 的 training 變得更加有效率
  • 回顧一下 Gradient Descent
    • Network parameters https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctheta%20%3D%20%7Bw_1%2C%20w_2%2C%20...%2C%20b_1%2C%20b_2%2C%20...%7D%24
    • 先選擇一個初始的參數 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctheta%5E0%24 ,然後計算這個 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctheta%5E0%24 對於我們的 loss function 的 Gradient https://chart.googleapis.com/chart?cht=tx&chl=%24%5Cnabla%20L(%5Ctheta%5E0)%24 ,也就是計算每一個 network 裡面的參數對於 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Cnabla%20L(%5Ctheta)%24 的偏微分 https://chart.googleapis.com/chart?cht=tx&chl=%24%24%5Cnabla%20L(%5Ctheta)%20%3D%20%5Cleft%20%5B%20%5Cbegin%7Barray%7D%7Bcc%7D%20%5Cpartial%20L(%5Ctheta)%20%2F%20%5Cpartial%20w_1%20%5C%5C%20%5Cpartial%20L(%5Ctheta)%20%2F%20%5Cpartial%20w_2%20%5C%5C%20%5Ccdots%20%5C%5C%20%5Cpartial%20L(%5Ctheta)%20%2F%20%5Cpartial%20b_1%20%5C%5C%20%5Ccdots%20%5Cend%7Barray%7D%20%5Cright%20%5D%24%24
    • 那我們就會拿到 Gradient,這個 Gradient 會是一個 Vector,就可以利用 Vector 來更新我們的參數 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctheta%5E1%20%3D%20%5Ctheta%5E0%20-%20%5Ceta%20%5Cnabla%20L(%5Ctheta%5E0)%24
  • 那我們會重複這個流程直到我們的期望次數,所以可以發現在一般的 Logistic Regression 或是 Linear Regression 在這邊的操作是沒太多區別的,唯一的問題是 Neural network 的參數非常的多,我們的 Gradient Vector 會變得非常巨大,所以如何有效地去計算這個 Vector,就是 Backpropagation 在做的事情
  • 所以 Backpropagation 並不是一個全新的方法,他說白了就是 Gradient Descent,只是它是一個更有效率的演算法,目的在於更有效率地去取得 Gradient Vector,這也是為什麼之後提到的 PyTorch Gradient Calculation 會交給 Backpropagation 做計算

About Backpropagation

  • 我們提到過 Backpropagation 可以想成一個更有效率的 Gradient Descent 了,那 Backpropagation 有沒有特別需要注意的部分呢?
  • 對於 Backpropagation 最重要的的觀念就是 Chain Rule(連鎖律)

Chain Rule

  • Chain Rule 連鎖律其實就是在強調數值之間的關係,那這邊為甚麼會這麼重要是因為回顧一下神經網路傳遞的方式,他們是一層一層的往下傳遞,因此就最終結果而言,其實是受到初始參數的影響一路往下層層變化的,那這些參數之間對於結果的關係是什麼?其實就會受到連鎖律的影響,因此基本的連鎖律概念我們在這裡簡單的幫大家 Summarize 一下
  • Case 1:
    • https://chart.googleapis.com/chart?cht=tx&chl=%24y%20%3D%20g(x)%2C%20z%20%3D%20h(y)%24 的話,如果 x 受到影響,會影響到 y ,進而影響到 z,也就是 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctriangle%20x%20%5Cto%20%5Ctriangle%20y%20%5Cto%20%5Ctriangle%20z%24
    • 所以如果我們今天要計算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7Bdz%20%5Cover%20dx%7D%24 ,可以先把它轉換成 https://chart.googleapis.com/chart?cht=tx&chl=%24%7Bdz%20%5Cover%20dy%7D%20%7Bdy%20%5Cover%20dx%7D%24
  • Case 2:
    • https://chart.googleapis.com/chart?cht=tx&chl=%24x%20%3D%20g(s)%2C%20y%20%3D%20h(s)%2C%20z%20%3D%20k(x%2C%20y)%24
    • 也就是說 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctriangle%20s%20%5Cto%20%5Ctriangle%20x%20%5Cto%20%5Ctriangle%20z%24 ,還有 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctriangle%20s%20%5Cto%20%5Ctriangle%20y%20%5Cto%20%5Ctriangle%20z%24 ,s 透過了兩個路徑去影響到了 z
    • 所以如果我們今天要計算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7Bdz%20%5Cover%20ds%7D%24 ,可以先把它轉換成 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20x%7D%20%7Bdx%20%5Cover%20ds%7D%20%2B%20%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20y%7D%20%7Bdy%20%5Cover%20ds%7D%24
  • 我們已經回顧了基本的 Chain Rule 在微分時會需要注意的部分,讓我們回到 Nueral Network

Basic Nueral Network

  • 我們回到基本的訓練過程去做思考,今天我們的 Nueral Network 在做訓練的過程是怎麼做訓練的?就是我們傳遞了一筆資料,經過神經網路的計算之後,會得到一個答案,那這個答案可能跟我們的預期答案有所落差,因此我們就可以利用這個落差的總和得到我們的 total loss

    • 所以這邊的 https://chart.googleapis.com/chart?cht=tx&chl=%24C%5En%24 就代表著 https://chart.googleapis.com/chart?cht=tx&chl=%24y%5En%24https://chart.googleapis.com/chart?cht=tx&chl=%24%5Chat%20y%5En%24 之間的落差
    • 那如果我們對 loss 和某一個 w 去做偏微分,我們可以發現就等於我把每個參數的 loss 對特定參數 w 的微分加總,就是 loss 對指定的 w 做偏微分了,因此我們之後就可以不用考慮去計算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20L(%5Ctheta)%20%5Cover%20%5Cpartial%20w%7D%24 ,而改思考對某一筆 data 的 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%5En(%5Ctheta)%20%5Cover%20%5Cpartial%20w%7D%24 就可以了
    • https://chart.googleapis.com/chart?cht=tx&chl=%24L(%5Ctheta)%20%3D%20%5Csum%5Climits_%7Bn%3D1%7D%5EN%20C%5En(%5Ctheta)%20%5Cto%20%7B%5Cpartial%20L(%5Ctheta)%20%5Cover%20%5Cpartial%20w%7D%20%3D%20%5Csum%5Climits_%7Bn%3D1%7D%5EN%20%7B%5Cpartial%20C%5En(%5Ctheta)%20%5Cover%20%5Cpartial%20w%7D%24
  • 那我們從一個簡單的 Neural network 來看看,假設我們有一個 network 長下面這樣

  • 那我們從某一個 neuron 來看看



    https://chart.googleapis.com/chart?cht=tx&chl=%24z%20%3D%20x_1w_1%20%2B%20x_2w_2%20%2B%20b%24

  • 那我們今天要算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20w%7D%24 要怎麼算,依照 Chain Rule 我們可以拆成兩項,也就是 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w%7D%20%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%24
  • 那計算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w%7D%24 其實是非常簡單的,我們稱為 Forward pass,那計算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%24 我們則稱為 Backward pass,那為啥要叫 forward 跟 backward 我們等等就知道了

Forward pass

  • 先來看看怎麼計算 Forward pass,我們前面有說我們的 https://chart.googleapis.com/chart?cht=tx&chl=%24z%20%3D%20x_1w_1%20%2B%20x_2w_2%20%2B%20b%24 了,所以如果我們希望計算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w_1%7D%24 ,其實就是 https://chart.googleapis.com/chart?cht=tx&chl=%24x_1%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w_2%7D%24 ,其實就是 https://chart.googleapis.com/chart?cht=tx&chl=%24x_2%24
  • 所以我們可以發現一個規律,當我們想找 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w%7D%24 ,事實上就是去看那個 w 前面接的參數,也就是這個神經元的輸入
  • 因此如果我們希望找到所有的 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w%7D%24 ,就必須先就算正向的參數,也就是我們 input 參數進入之後,一路往下到輸出的所有一層一層傳遞的參數,這也是為甚麼我們稱其為 forward pass,因為就是我們一般求輸出的正向運算
  • 那這邊也是為甚麼我們說找 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%20%5Cover%20%5Cpartial%20w%7D%24 是非常簡單的,因為根本就是輸入參數

Backward pass

  • 那如果我們已經知道 Forward pass 就是順向/正向運算,那 Backward pass 顧名思義應該就是反向運算了,但是要怎麼做呢?
  • 我們現在要算 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%24 ,我們知道 z 好取得,但是 C 就是要繼續往下看一路運算到最後結果,這是非常複雜的,那怎麼辦呢?那我們試著再用 Chain rule 拆解看看這一項


    from: ML Lecture 7: Backpropagation

  • 我們先假設接在 Z 後的 activation function(我們之後再解釋 QQ) 是 sigmoid function https://chart.googleapis.com/chart?cht=tx&chl=%24a%20%3D%20%5Csigma(z)%24 ,然後輸出了一個結果 https://chart.googleapis.com/chart?cht=tx&chl=%24a%24 ,那我們先不管後面的部分,我們在多了一個變數 https://chart.googleapis.com/chart?cht=tx&chl=%24a%24 之後,就可以利用 Chain rule 再把式子拆分成 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%20%3D%20%7B%5Cpartial%20a%20%5Cover%20%5Cpartial%20z%7D%20%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20a%7D%24
  • 那我們先來看 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20a%20%5Cover%20%5Cpartial%20z%7D%24 是什麼,我們已經知道 https://chart.googleapis.com/chart?cht=tx&chl=%24a%20%3D%20%5Csigma(z)%24 ,所以 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20a%20%5Cover%20%5Cpartial%20z%7D%24 其實就是 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Csigma%5E%7B'%7D(z)%24 ,也就是 sigmoid function 的微分
  • https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20a%7D%24 應該長怎樣呢? 應該長 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20a%7D%20%3D%20%7B%5Cpartial%20z%5E%7B'%7D%20%5Cover%20%5Cpartial%20a%7D%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%20%2B%20%7B%5Cpartial%20z%5E%7B''%7D%20%5Cover%20%5Cpartial%20a%7D%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%24

  • 那我們看上圖可以發現 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%5E%7B'%7D%20%5Cover%20%5Cpartial%20a%7D%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20z%5E%7B''%7D%20%5Cover%20%5Cpartial%20a%7D%24 其實就是後面的 https://chart.googleapis.com/chart?cht=tx&chl=%24w_3%24https://chart.googleapis.com/chart?cht=tx&chl=%24w_4%24 ,那 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%24 呢?怎麼感覺又繞回來一圈了?
  • 我們先整理一下現在 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%24 會長怎樣? https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%20%3D%20%5Csigma%5E%7B'%7D(z)%5Bw_3%20%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%20%2B%20w_4%20%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%5D%24 ,換句話說我們其實只差最後一個步驟了,也就是我們只差知道 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%24 整個問題就結束了,但是怎麼解?我們換個方向想
  • 如果我們從後面往前推,也就是我們把目標先放在答案那邊,從 output layer 往前推

  • 我們可以得到 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%20%3D%20%7B%5Cpartial%20y_1%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20y_1%7D%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%20%3D%20%7B%5Cpartial%20y_2%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20y_2%7D%24 ,我們會發現因為 https://chart.googleapis.com/chart?cht=tx&chl=%24y_1%24https://chart.googleapis.com/chart?cht=tx&chl=%24y_2%24 都是已知了,因為我們正向運算一定會算出一個答案,我們的 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20y_1%7D%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20y_2%7D%24 就可以利用 Cost function 來決定(例如 MSE),然後 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20y_1%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%24https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20y_2%20%5Cover%20%5Cpartial%20z%5E%7B''%7D%7D%24 也可以運算了
  • 那如果現在不是在 output layer 呢?其實就是一直往下推一路到 output layer 就可以了,因為只有在 output layer ,我們才有辦法把 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%5E%7B'%7D%7D%24 這種部份算出來
  • 所以概念上我們就是完全倒過來,從結果一路回推所有的 https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cpartial%20C%20%5Cover%20%5Cpartial%20z%7D%24

每日小結

  • Backpropagation 可以說是深度學習裡面最重要的觀念了,神經網路的構造複雜,本來就很難去計算和更新參數,因此普通的 Gradient Descent 會遇到很多計算上的困難, Backpropagation 則是利用 Chain Rule 的方式,將計算複雜度大大的下降,並利用一次 Forward pass 加一次 Backward pass 來達到快速更新參數計算參數的方式
  • 本日課程大量參考 李弘毅老師的開放式課程 ,這份教學非常非常好理解 Backpropagation,因此上面看不懂的部分都可以再去看看,筆者當初在學習的過程中,也深受此系列幫助
  • 到這裡我們已經完成了基本的觀念架設了,明天我們就可以開始介紹 PyTorch Framework 了~

上一篇
Day-10 深度學習的介紹
下一篇
Day-12 Pytorch 介紹
系列文
Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言