iT邦幫忙

2022 iThome 鐵人賽

DAY 27
1
AI & Data

PyTorch 生態鏈實戰運用系列 第 27

[Day27] Weight Decay Regularization

  • 分享至 

  • xImage
  •  

前言

本篇將討論第三個預計介紹的Weight Decay Regularization技術,會先從L2開始講再帶到目前使用Optimizer的Weight Decay實作方式。

L2 Regularization

在講Weight Decay之前,首先先介紹其關係緊密相連的L2 Regularization。是一種利用L2-Norm(也就是所謂很常見的歐基理德距離):

來進行Regularization的方式。

那怎麼利用這個norm去作到Regularization呢?中心私路是這樣,一個overfitting的模型,一般被認為可能具有過度複雜以及不具無意義的pattern。而過度複雜的權重則很可能是這些過於取巧的pattern呈現的方式之一,L2 Regularization則是一種在學習時,利用L2-norm去限制權重的,進而避免overfitting的技術。另外,還有一種很常見的說法會說這個限制項,是對使用過大權重的懲罰(Penalty)項。

具體限制的方式則是把L2-norm混和加入loss再進行學習,可以寫成類似下面的概念式:

  • 一般在考慮到計算上具較佳的性質(例如微分較簡潔),因此會直接使用平方合(差一個根號)

如此一來在模型訓練的過程時,就會考量到權重的大小,避免使用過大,通常可以降低很多無意義的權重。還是不太懂的朋友可以參考Google爸爸的這個視覺化工具

Gradient and weight decay

那有了新的Loss函數後,下一步自然便是計算梯度並優化了。那是否有任何可以簡化梯度計算的方式呢?由於L2所增加的項,具有解析解,因此是有的:

簡化後的式子是原始LOSS所計算出的梯度、第二項則是原先權重的倍數,新增的第二項會讓有點權重逐漸衰變的意味在,也因此被稱作Weight Decay(起碼我是這樣理解的...)。

實作方式

在實作的部分,如果是最單純的SGD,想實作L2 Regularization的話,基本上便等同於在每次迭代時進行梯度的修改。因此可以看到文件中是有一個weight decay可以進行調整,便能直接實作。

但我們最常用的Adam呢?雖然裡頭一樣也有weight decay這個參數,可以進行類似SGD的方式來進行訓練,但Adam所採用的是Adaptive Gradient的方式,實質上與L2 Regularization並不等價,而且其表現也不佳,實驗結果可以參考這篇DECOUPLED WEIGHT DECAY REGULARIZATION

所幸的是,這篇論文提出了一種新的優化器叫AdamW,是一種修正了上述問題的方法。並且也有實作於PyTorch當中,我們也可以直接使用。文件參考,要注意的是這個優化器,預設weight decay rate是0.01,大部分優化器則是0。

實作結果

基本上一樣參考這個commit,我們也寫入新的config(這邊設定weight decay 是 0.02):

  optimizer: 
    name: 'AdamW'
    learning_rate: 0.001
    weight_decay: 0.02
    warmup_epochs: 5

經過訓練後,變可以得到下列的結果:

  • 收斂的結果跟效率都好上不少
  • 主因應該是這個搭配比較不會掉到所謂Adam的local optimal所導致

另外這裡我也有紀錄權重的L2-norm結果:

可以看到確實其L2-norm有不斷在下降,確實是有正常在工作。

本日小節

  • 使用AdamW進行Weight decay
  • 得到更加的結果

上一篇
[Day26] Data Augmentation
下一篇
[Day28] Lp Regularization
系列文
PyTorch 生態鏈實戰運用30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言