模型在做Compiler時,指定 loss function ,如下:
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
我們可以指定 "sparse_categorical_crossentropy" 這個字串讓框架去找對應的模組來使用。那還有什麼方式可以指定 loss function,或有機會自己定義另外的算法來計算 ?
透過觀察,如果使用以上方式,於模型做 compiler 時,必定會建立與初始化 LossesContainer 物件,會將此物件指定到模型的compiled_loss屬性, 指定的 loss 參數也會包進來此物件中。
當LossesContainer建立完成時,並未執行build的動作,相對也就是其中被包進來的 loss 也未完全建立起來(如果是傳指定的loss function name)。
直到模型訓練的過程(參考模型fit文章),在某一輪訓練每層的計算完成後,會需要 loss function,這時會到 keras.engine.training.compute_loss 去找 loss function,因為偵測到 LossesContainer 並未執行過 build 的動作,所以這個時機點會先做 LossesContainer 的 build 。
途中會透過以下過程:
keras.engine.compile_utils._get_loss_object
-> keras.losses.get
-> keras.losses.deserialize
-> keras.saving.legacy.serialization.deserialize_keras_object
會於框架中的模組物件搜尋對應的 loss function 物件。
框架中的loss function 物件也依照傳入的字串 identifier 來找出對應的loss物件。其mapping的內容如下:
function :
'squared_hinge' :keras.losses.squared_hinge
'hinge' :keras.losses.hinge
'categorical_hinge' :keras.losses.categorical_hinge
'huber' :keras.losses.huber
'log_cosh' :keras.losses.log_cosh
'categorical_crossentropy' :keras.losses.categorical_crossentropy
'categorical_focal_crossentropy' :keras.losses.categorical_focal_crossentropy
'sparse_categorical_crossentropy' :keras.losses.sparse_categorical_crossentropy
'binary_crossentropy' :keras.losses.binary_crossentropy
'binary_focal_crossentropy' :keras.losses.binary_focal_crossentropy
'kl_divergence' :keras.losses.kl_divergence
'poisson' :keras.losses.poisson
'cosine_similarity' :keras.losses.cosine_similarity
'bce' :keras.losses.binary_crossentropy
'BCE' :keras.losses.binary_crossentropy
'mse' :keras.losses.mean_squared_error
'MSE' :keras.losses.mean_squared_error
'mae' :keras.losses.mean_absolute_error
'MAE' :keras.losses.mean_absolute_error
'mape' :keras.losses.mean_absolute_percentage_error
'MAPE' :keras.losses.mean_absolute_percentage_error
'msle' :keras.losses.mean_squared_logarithmic_error
'MSLE' :keras.losses.mean_squared_logarithmic_error
'kld' :keras.losses.kullback_leibler_divergence
'KLD' :keras.losses.kullback_leibler_divergence
'kullback_leibler_divergence' :keras.losses.kullback_leibler_divergence
'logcosh' :keras.losses.logcosh
'huber_loss' :keras.losses.huber_loss
class:
'Loss' :'keras.losses.Loss'
'LossFunctionWrapper' :'keras.losses.LossFunctionWrapper'
'MeanSquaredError' :'keras.losses.MeanSquaredError'
'MeanAbsoluteError' :'keras.losses.MeanAbsoluteError'
'MeanAbsolutePercentageError' :'keras.losses.MeanAbsolutePercentageError'
'MeanSquaredLogarithmicError' :'keras.losses.MeanSquaredLogarithmicError'
'BinaryCrossentropy' :'keras.losses.BinaryCrossentropy'
'BinaryFocalCrossentropy' :'keras.losses.BinaryFocalCrossentropy'
'CategoricalCrossentropy' :'keras.losses.CategoricalCrossentropy'
'CategoricalFocalCrossentropy' :'keras.losses.CategoricalFocalCrossentropy'
'SparseCategoricalCrossentropy' :'keras.losses.SparseCategoricalCrossentropy'
'CosineSimilarity' :'keras.losses.CosineSimilarity'
'Hinge' :'keras.losses.Hinge'
'SquaredHinge' :'keras.losses.SquaredHinge'
'CategoricalHinge' :'keras.losses.CategoricalHinge'
'Poisson' :'keras.losses.Poisson'
'LogCosh' :'keras.losses.LogCosh'
'KLDivergence' :'keras.losses.KLDivergence'
'Huber' :'keras.losses.Huber'
如果找回的是 function 物件,則會透過 keras.losses.LossFunctionWrapper 類別將 loss function 物件融入進來,讓之後計算可以透過keras.losses.LossFunctionWrapper.call 來叫用。
自定義 loss function 類別
那如果自定義 loss function 要如何實作類別? 以下做簡單的範例來實現:
(1) 繼承keras.losses.LossFunctionWrapper ,並傳入loss function物件(或自己客製):
class MySparseCategoricalCrossentropy(LossFunctionWrapper):
def __init__(
self, fn , from_logits=False, ignore_class=None, reduction=losses_utils.ReductionV2.AUTO,
name="sparse_categorical_crossentropy",
):
super().__init__(
fn,
name=name,
reduction=reduction,
from_logits=from_logits,
ignore_class=ignore_class,
)
(2) 繼承 keras.losses.Loss , 並實作 call 函式:
class MySparseCategoricalCrossentropy(keras.losses.Loss):
def call(self, y_true, y_pred):
return keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=False,
axis=-1, ignore_class=None)
所以這些自定義的方式,還可以加上客製的處理,做一些需要額外的運算、更有效率的公式運作等等。
以上為觀察並記錄於此。