從上一節可以看出,在模型做compiler時,可以指定 loss function,也能自定義客製的 loss function。 一旦 loss function 決定好後,之後的模型訓練就可以拿此損失函數內容來計算。 拿損失函數計算的運作,是位於模型的 compute_loss 函式; 而內建原本的內容,就是拿上一章節提到的 LossesContainer 來執行 call 函式。
compute_loss 函式式可以自訂客製的,也可以自行增加一些追蹤的屬性來協助觀察模型訓練的效果。
我們來看keras 提供的範例:
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import tensorflow as tf
class MyModel(Model):
def __init__(self, *args, **kwargs):
super(MyModel, self).__init__(*args, **kwargs)
self.loss_tracker = tf.keras.metrics.Mean(name='loss')
def compute_loss(self, x, y, y_pred, sample_weight):
loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y))
loss += tf.add_n(self.losses)
self.loss_tracker.update_state(loss)
return loss
def reset_metrics(self):
self.loss_tracker.reset_states()
@property
def metrics(self):
return [self.loss_tracker]
tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,))
dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1)
inputs = tf.keras.layers.Input(shape=(10,), name='my_input')
outputs = tf.keras.layers.Dense(10)(inputs)
model = MyModel(inputs, outputs)
model.add_loss(tf.reduce_sum(outputs))
optimizer = tf.keras.optimizers.SGD()
model.compile(optimizer, loss='mse', steps_per_execution=10)
model.fit(dataset, epochs=2, steps_per_epoch=10)
print('My custom loss: ', model.loss_tracker.result().numpy())
自製模型這邊初始時,新增了loss_tracker 屬性物件,此物件為 keras.metrics.base_metric.Mean 的實體,
繼承了keras.metrics.base_metric.Reduce、keras.metrics.base_metric.Metric。
在模型實作 compute_loss 中, 計算了自己的 loss value, 在將結果傳入 model.loss_tracker 物件的 update_state 函式, 會找到keras.metrics.base_metric.Reduce.update_state 去存結果到 loss_tracker.total 中。
模型在做fit時,每次epoch迴圈都會執行模型的reset_metrics,這邊會針對此模型使用到的 Metric 物件,也就是 model.loss_tracker 物件,針對此 Metric 執行自身的 reset_states 函式。 而真正執行 model.train_function時, 每層的預測值計算完後,會利用 model.compute_loss來計算其損失函數計算結果,這邊透過自定義計算完成後,在利用 model.loss_tracker 物件的 update_state 函式儲存結果。
透過這個範例,我們可以得知模型的 compute_loss 會與 metrics 有關,實際上真正做模型訓練時,compute_loss 會取得設定的 metrics 資料集,依據每種Metric物件, 迭代逐一執行update_state。
以上是透過keras提供的範例,透過觀察將其運作紀錄於此。