在模型做fit的章節,我們可以看到在訓練前會將Callback實體放入一個Container,然後於真正訓練迭代迴圈時使用。這邊會自定義Callback類別來看如何運作。
from tensorflow.keras import callbacks as callbacks_module
class MyCallbackWithBatchHooks(callbacks_module.Callback):
def __init__(self):
self.train_batches = 0
self.test_batches = 0
self.predict_batches = 0
def on_train_batch_end(self, batch, logs=None):
self.train_batches += 1
print('MyCallbackWithBatchHooks',' on_train_batch_end')
def on_test_batch_end(self, batch, logs=None):
self.test_batches += 1
print('MyCallbackWithBatchHooks',' on_test_batch_end')
def on_predict_batch_end(self, batch, logs=None):
self.predict_batches += 1
print('MyCallbackWithBatchHooks',' on_predict_batch_end')
class MyCallbackWithBatchHooks2(callbacks_module.Callback):
def __init__(self):
self.train_batches = 0
self.test_batches = 0
self.predict_batches = 0
def on_train_batch_end(self, batch, logs=None):
self.train_batches += 1
print('MyCallbackWithBatchHooks2',' on_train_batch_end')
def on_test_batch_end(self, batch, logs=None):
self.test_batches += 1
print('MyCallbackWithBatchHooks2',' on_test_batch_end')
def on_predict_batch_end(self, batch, logs=None):
self.predict_batches += 1
print('MyCallbackWithBatchHooks2',' on_predict_batch_end')
my_cb = MyCallbackWithBatchHooks()
my_cb2 = MyCallbackWithBatchHooks2()
cb_list = callbacks_module.CallbackList([my_cb2,my_cb], verbose=0)
cb_list.on_train_batch_end(0)
cb_list.on_test_batch_end(0)
print(my_cb2.train_batches, my_cb.train_batches)
print(my_cb2.test_batches, my_cb.test_batches)
print(my_cb2.predict_batches , my_cb.predict_batches)
依照範例,先宣告一個繼承 keras.callbacks.Callback 的MyCallbackWithBatchHooks實體,送入callback的Container類別,即 keras.callbacks.CallbackList。 Container實體的建立過程,先至keras.callbacks.CallbackList.__init__依照傳入的參數初始化相關屬性。
Container收集好指定傳入的callback物件後,會依照自己提供的function規格來逐一偵測執行所有的callback物件相對應的函式,如果callback物件沒有定義,則會往父類別keras.callbacks.Callback找,而父類別通常只是建立規格,沒有任何需要執行的程式內容。
換句話說,Callback Container提供了幾個介面讓callback物件可以執行自己內部的對應函式,這些介面可以針對不同的時機點呼叫 。
以下列出這些介面:
on_batch_begin
on_batch_end
on_epoch_begin
on_epoch_end
on_train_batch_begin
on_train_batch_end
on_test_batch_begin
on_test_batch_end
on_predict_batch_begin
on_predict_batch_end
on_train_begin
on_train_end
on_test_begin
on_test_end
on_predict_begin
on_predict_end
根據這些提供的介面,容器 "批次" 執行所有已註冊進來的callback物件。
再看此範例,當容器執行 cb_list.on_train_batch_end,就會去看所註冊的callback,my_cb2與my_cb,此二個的 on_train_batch_end 內容都被執行。
要特別注意的是,註冊callback的順序。本例特別將 my_cb2 物件擺在 my_cb 前面,而執行的順序也如註冊順序一樣。
以上是觀察callback Container大致上的運作。