iT邦幫忙

0

梯度下降法(9) -- 損失函數

  • 分享至 

  • xImage
  •  

個人認為,損失函數是神經網路最重要的核心,當我們定義好一個損失函數,梯度下降法就會想辦法最小化損失函數,求得最佳解,例如第五篇的聯立方程式,定義每個方程式的損失平方和愈小愈好,梯度下降法就可求出近似解,又例如風格轉換(Style Transfer)演算法,定義總損失函數為輸出圖與原圖的差異加上輸出圖與風格圖的差異,即可生成最接近原圖且具有特殊風格的圖像。
https://ithelp.ithome.com.tw/upload/images/20250618/20001976zVoGbyWIHc.png
圖一. 左邊為原圖,中間為風格圖,右邊為輸出圖

TensorFlow/Keras提供的損失函數

TensorFlow/Keras內建許多損失函數,概分為3類:

  1. 迴歸損失(Regression losses):適用於預測連續型變數,最常用的是均方誤差(Mean Squared Error, MSE),為Σ(實際值 - 預測值) ** 2/n,可用於許多案例,包括迴歸、聯立方程式求解...等,其他損失函數還有Mean Absolute Error(MAE)、Log MSE、Huber loss(降低離群值的影響)以及衡量相似度的Cosine Similarity。
    https://ithelp.ithome.com.tw/upload/images/20250619/20001976vqKSTPTCEx.png
    圖二. 簡單線性迴歸

  2. 機率損失(Probabilistic losses):適用於分類(Classification),例如手寫阿拉伯數字辨識,假設要辨識0,A模型預測為8,B模型預測為9,我們不能以MSE衡量準確度,(8-0) ** 2 < (9-0) ** 2,認定A模型預測準確度高於B模型,應為兩模型均預測錯誤。SparseCategoricalCrossentropy是一個較特別的損失函數,它會將訓練資料的標記(Y)進行One-hot encoding,再與預測值作比較。

  3. Hinge losses:適用於以最大間隔(Maximum margin)為目標的演算法,例如支援向量機(SVM)。

如果使用錯誤類別的損失函數,訓練出來的模型準確度會很離譜。

實測

第一篇的手寫阿拉伯數字辨識為例(01_tf1.py)。
範例1. 手寫阿拉伯數字辨識,改採MSE的實測。

  1. 修改01_tf1.py的模型:最後一層的輸出神經元個數為1,表預測一個連續型變數(Y)。
model = tf.keras.models.Sequential([
  tf.keras.layers.Input((28, 28)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(1) 
])
  1. 修改損失函數為MeanSquaredError,即MSE。
model.compile(optimizer='adam',
  loss=tf.keras.losses.MeanSquaredError(), # MSE
  metrics=['accuracy'])
  1. 另存為26_MNIST_MSE.py。
  2. 執行:python 26_MNIST_MSE.py。
  3. 執行結果:測試資料準確率只有16%。
    https://ithelp.ithome.com.tw/upload/images/20250619/20001976eMT5ZoGftj.png

風格轉換(Style Transfer)

除了TensorFlow/Keras內建許多損失函數外,要發明一個新穎的演算法常會自訂損失函數,例如生成式AI,包括風格轉換(Style Transfer)、生成對抗網路(GAN)、擴散模型(Diffusion Model)...等。其中風格轉換是將一張內容圖或稱原圖(Content image)轉換為具有特殊風格(Style)的合成圖,演算法定義的損失函數如下:
總損失函數 = 內容損失(Content loss) + 風格損失(Style loss)
其中:
內容損失 = 輸出圖與內容圖的差異,類似MSE,輸出圖與內容圖特徵的差異平方和,下列公式的P為內容圖特徵,F為輸出圖特徵。
https://ithelp.ithome.com.tw/upload/images/20250619/20001976ARsHDGBFm0.png
風格損失 = 輸出圖與風格圖的差異,演算法利用Gram matrix定義兩個特徵的關聯性,藉以表達風格的差異性,Gram定義如下:
https://ithelp.ithome.com.tw/upload/images/20250619/20001976I91Dkq2hUu.png
風格損失如下:
https://ithelp.ithome.com.tw/upload/images/20250619/20001976OOS2dbsnQW.png
其中E:
https://ithelp.ithome.com.tw/upload/images/20250619/20001976mZ5GdpKQ9k.png
G為輸出圖的Gram,A為風格圖的Gram。

相關的內容可詳閱『A Neural Algorithm of Artistic Style』,而TensorFlow官方教學網頁也提供一個完整的範例『Neural style transfer』,讀者可研讀及執行該程式,資料均在程式中下載。另一個範例『Fast Style Transfer for Arbitrary Styles』可以自TF Hub下載訓練好的模型,體驗各種轉換的效果。

生成對抗網路(GAN)

生成對抗網路(Generative Adversarial Networ, GAN)也是一個非常有意思的神經網路,它含有兩個子網路--生成模型(Generative model)及判別模型(Discriminative model),兩者互相對抗,好比一個遊戲有兩個角色,一個是偽造者(counterfeiter),他不斷製造假畫,另一個角色是警察,不斷從偽造者那邊拿到假畫,判斷是真或假,然後,偽造者就根據警察判斷結果的回饋,不斷改良,最後假畫變成真假難辨。

判別模型及生成模型的損失函數如下:
https://ithelp.ithome.com.tw/upload/images/20250619/20001976sHzckFVsQJ.png

同樣以梯度下降法求解,程序如下:
https://ithelp.ithome.com.tw/upload/images/20250619/20001976aTupfg3bbu.png

以上的模型使用卷積神經網路,故稱DCGAN(Deep Convolutional Generative Adversarial Network),相關的程式碼可參閱『Deep Convolutional Generative Adversarial Network』

GAN的變形非常多,他們稍微改變了損失函數及模型結構,就可生成不同風格的圖像,包括深度偽造(Deep fake)。
https://ithelp.ithome.com.tw/upload/images/20250619/200019766BJMvDSsmv.png
圖三. 各式GAN變形的損失函數,部分截圖來自『hwalsuklee/tensorflow-generative-model-collections』

結語

談到這裡,我們可以回顧手寫阿拉伯數字辨識的程式碼,透過一系列的說明,每一行程式及參數是否都已詳細說明?如有疏漏,還請各位讀者提醒及指正。

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

下一篇,我們將再整理一份資料,說明各種常見的神經網路發展的動機及關聯,希望能幫助讀者快速掌握各式神經網路的重點。

工商廣告:)

《深度學習最佳入門與專題實戰》導讀講座 2025/07/11 歡迎報名

書籍:

  1. 深度學習最佳入門與專題實戰:理論基礎與影像篇, 2025/05 再版
  2. 深度學習最佳入門與專題實戰:自然語言處理、大型語言模型與強化學習篇, 2025/05 再版
  3. 開發者傳授 PyTorch 秘笈
  4. Scikit-learn 詳解與企業應用

影音課程:

深度學習PyTorch入門到實戰應用

企業包班

系列文章目錄

徹底理解神經網路的核心 -- 梯度下降法 (1)
徹底理解神經網路的核心 -- 梯度下降法 (2)
徹底理解神經網路的核心 -- 梯度下降法 (3)
徹底理解神經網路的核心 -- 梯度下降法 (4)
徹底理解神經網路的核心 -- 梯度下降法的應用 (5)
梯度下降法(6) -- 學習率動態調整
梯度下降法(7) -- 優化器(Optimizer)
梯度下降法(8) -- Activation Function
梯度下降法(9) -- 損失函數
梯度下降法(10) -- 總結


圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言