iT邦幫忙

2023 iThome 鐵人賽

DAY 17
0
AI & Data

從Keras框架與數學概念了解機器學習系列 第 17

[從Keras框架與數學概念了解機器學習] - 17. 自定義Callback與如何應用

  • 分享至 

  • xImage
  •  

https://ithelp.ithome.com.tw/upload/images/20230917/20144614x6oE4N21Q1.jpg

在模型做fit的章節,我們可以看到在訓練前會將Callback實體放入一個Container,然後於真正訓練迭代迴圈時使用。這邊會自定義Callback類別來看如何運作。

from tensorflow.keras import callbacks as callbacks_module

class MyCallbackWithBatchHooks(callbacks_module.Callback):
    def __init__(self):
        self.train_batches = 0
        self.test_batches = 0
        self.predict_batches = 0
        
    def on_train_batch_end(self, batch, logs=None):
        self.train_batches += 1
        print('MyCallbackWithBatchHooks',' on_train_batch_end')
        
    def on_test_batch_end(self, batch, logs=None):
        self.test_batches += 1
        print('MyCallbackWithBatchHooks',' on_test_batch_end')
        
    def on_predict_batch_end(self, batch, logs=None):
        self.predict_batches += 1
        print('MyCallbackWithBatchHooks',' on_predict_batch_end')
    
class MyCallbackWithBatchHooks2(callbacks_module.Callback):
    def __init__(self):
        self.train_batches = 0
        self.test_batches = 0
        self.predict_batches = 0
        
    def on_train_batch_end(self, batch, logs=None):
        self.train_batches += 1
        print('MyCallbackWithBatchHooks2',' on_train_batch_end')
        
    def on_test_batch_end(self, batch, logs=None):
        self.test_batches += 1
        print('MyCallbackWithBatchHooks2',' on_test_batch_end')
        
    def on_predict_batch_end(self, batch, logs=None):
        self.predict_batches += 1
        print('MyCallbackWithBatchHooks2',' on_predict_batch_end')
    
my_cb = MyCallbackWithBatchHooks()
my_cb2 = MyCallbackWithBatchHooks2()
cb_list = callbacks_module.CallbackList([my_cb2,my_cb], verbose=0)
cb_list.on_train_batch_end(0)
cb_list.on_test_batch_end(0)

print(my_cb2.train_batches, my_cb.train_batches)
print(my_cb2.test_batches, my_cb.test_batches)
print(my_cb2.predict_batches , my_cb.predict_batches)

依照範例,先宣告一個繼承 keras.callbacks.Callback 的MyCallbackWithBatchHooks實體,送入callback的Container類別,即 keras.callbacks.CallbackList。 Container實體的建立過程,先至keras.callbacks.CallbackList.__init__依照傳入的參數初始化相關屬性。

Container收集好指定傳入的callback物件後,會依照自己提供的function規格來逐一偵測執行所有的callback物件相對應的函式,如果callback物件沒有定義,則會往父類別keras.callbacks.Callback找,而父類別通常只是建立規格,沒有任何需要執行的程式內容。

換句話說,Callback Container提供了幾個介面讓callback物件可以執行自己內部的對應函式,這些介面可以針對不同的時機點呼叫 。

以下列出這些介面:
on_batch_begin
on_batch_end
on_epoch_begin
on_epoch_end
on_train_batch_begin
on_train_batch_end
on_test_batch_begin
on_test_batch_end
on_predict_batch_begin
on_predict_batch_end
on_train_begin
on_train_end
on_test_begin
on_test_end
on_predict_begin
on_predict_end

根據這些提供的介面,容器 "批次" 執行所有已註冊進來的callback物件。
再看此範例,當容器執行 cb_list.on_train_batch_end,就會去看所註冊的callback,my_cb2與my_cb,此二個的 on_train_batch_end 內容都被執行。

要特別注意的是,註冊callback的順序。本例特別將 my_cb2 物件擺在 my_cb 前面,而執行的順序也如註冊順序一樣。

以上是觀察callback Container大致上的運作。


上一篇
[從Keras框架與數學概念了解機器學習] - 16. 自定義 activation 與內建 activation
下一篇
[從Keras框架與數學概念了解機器學習] - 18. loss function 的使用與自定義方式
系列文
從Keras框架與數學概念了解機器學習30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言