本篇將討論第三個預計介紹的Weight Decay Regularization技術,會先從L2開始講再帶到目前使用Optimizer的Weight Decay實作方式。
在講Weight Decay之前,首先先介紹其關係緊密相連的L2 Regularization。是一種利用L2-Norm(也就是所謂很常見的歐基理德距離):
來進行Regularization的方式。
那怎麼利用這個norm去作到Regularization呢?中心私路是這樣,一個overfitting的模型,一般被認為可能具有過度複雜以及不具無意義的pattern。而過度複雜的權重則很可能是這些過於取巧的pattern呈現的方式之一,L2 Regularization則是一種在學習時,利用L2-norm去限制權重的,進而避免overfitting的技術。另外,還有一種很常見的說法會說這個限制項,是對使用過大權重的懲罰(Penalty)項。
具體限制的方式則是把L2-norm混和加入loss再進行學習,可以寫成類似下面的概念式:
如此一來在模型訓練的過程時,就會考量到權重的大小,避免使用過大,通常可以降低很多無意義的權重。還是不太懂的朋友可以參考Google爸爸的這個視覺化工具。
那有了新的Loss函數後,下一步自然便是計算梯度並優化了。那是否有任何可以簡化梯度計算的方式呢?由於L2所增加的項,具有解析解,因此是有的:
簡化後的式子是原始LOSS所計算出的梯度、第二項則是原先權重的倍數,新增的第二項會讓有點權重逐漸衰變的意味在,也因此被稱作Weight Decay(起碼我是這樣理解的...)。
在實作的部分,如果是最單純的SGD,想實作L2 Regularization的話,基本上便等同於在每次迭代時進行梯度的修改。因此可以看到文件中是有一個weight decay
可以進行調整,便能直接實作。
但我們最常用的Adam呢?雖然裡頭一樣也有weight decay
這個參數,可以進行類似SGD的方式來進行訓練,但Adam所採用的是Adaptive Gradient的方式,實質上與L2 Regularization並不等價,而且其表現也不佳,實驗結果可以參考這篇DECOUPLED WEIGHT DECAY REGULARIZATION。
所幸的是,這篇論文提出了一種新的優化器叫AdamW,是一種修正了上述問題的方法。並且也有實作於PyTorch當中,我們也可以直接使用。文件參考,要注意的是這個優化器,預設weight decay rate是0.01,大部分優化器則是0。
基本上一樣參考這個commit,我們也寫入新的config(這邊設定weight decay 是 0.02):
optimizer:
name: 'AdamW'
learning_rate: 0.001
weight_decay: 0.02
warmup_epochs: 5
經過訓練後,變可以得到下列的結果:
另外這裡我也有紀錄權重的L2-norm結果:
可以看到確實其L2-norm有不斷在下降,確實是有正常在工作。