iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 6
1

文章說明

文章分段

  1. 文章說明
  2. 簡介datasets
  3. 描述模型訓練的流程
  4. 程式架構:上程式的虛擬碼,預告要寫的檔案、功能有哪些
  5. 主程式的實際撰寫,註明程式開始不能單跑的地方
  6. 總結

 

前情提要

.
└── project1
     ├── core/
     ├── datasets/
     ├── logs/
     ├── models/
     ├── evaluate.py
     ├── predict.py
     ├── README.md
     └── train.py

前面都在概括性地描述訓練流程,和project架構,ep.4有講到訓練主程式train.py的虛擬碼,今天則要繼續將其code在其中填滿。

讓我們繼續開始吧。

 

主程式的實際撰寫

def main():

    ## prepare arguments
    data_params = {
        'batch_size': FLAGS.batch_size,
        'dataset_size': len(list(tf.python_io.tf_record_iterator(FLAGS.train_tfrecords_path))),
        ...
    }
    
    model_params = {
        'optimizer': FLAGS.optimizer,
        'loss_fn': FLAGS.loss_fn,
        'learning_rate': FLAGS.learning_rate,
        ...
    }

我們可以透過python的dictionary來整理hyper parameter,主要可將data與model有關的參數分開放置,接著再將各參數,傳至datasets/core/的程式之中。

 

這系列處理資料流的部分,是透過tensorflow的tf.data API,下一篇會詳細介紹這個API,如何用它來customize dataset,這裡會先把使用情境帶過。

from datasets.data_generater import DataGenerater
def main():
    
    ## prepare data according to argument
    dgen = DataGenerater()
    train_datas, train_labels = dgen.get_train_samples(tfrecords_path=FLAGS.train_tfrecords_path, params=train_data_params)

處理data的程式會如同deeplab一樣,被寫為一個class,所有對dataset的操作都集中在這個class,像是解析tfrecord、shuffle、batch size設定、data augmentation等。

 

from core import md
def main():
    ## prepare model architecture according to arguments
    model = md.some_method_to_build_model(model_params)
    print(model.summary())

接著是從core/取得寫好的模型架構,tf.keras可以用model.summary()印出模型架構和參數數量。

 

from tensorflow.keras import optimizers
def main():
    ## configurate model with arguments
    optimizer = optimizers.Adam(lr=model_params['learning_rate'])
    loss = 'binary_crossentropy'
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

接著定義model運算的部分,大致有optimizer、loss、metrics要定義。

 

def main():
    ## prepare arguments of saving model information
    ckpt_file_pattern = "weights-improvement-bce{val_loss:.2f}-{epoch:02d}.hdf5"
        checkpoint_filepath = os.path.join(TRAINED_BASE_DIR, 
                                           FLAGS.model_name, 
                                           'checkpoint', 
                                           ckpt_file_pattern)
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='auto',
        verbose=1,
        save_best_only=True)

callbacks中有個ModelCheckpoint的method,是用來設定如何儲存model的callbacks。

可以設定儲存的檔案名稱格式,還有是要儲存weight、還是weigth+architecture,keras的儲存方式比較偏向於monitor的數值最好,就儲存哪個;亦可選擇每隔幾個epoch就存一次weight,再自行決定要使用哪個時期的訓練結果。

 

def main():
    ## training model in keras way
    history = model.fit(
        x=train_datas,
        y=train_labels,
        steps_per_epoch=data_params['dataset_size'] // FLAGS.batch_size,
        validation_steps=val_data_params['dataset_size'] // FLAGS.batch_size,
        validation_data=(val_datas, val_labels),
        callbacks=[model_checkpoint_callback],
        epochs=FLAGS.epochs)

最後訓練就是用tf.keras的model本身帶有的method fit()進行訓練。比較要注意的是,如果是使用tf.data的話,validataion data必須另外設定,因為tf.data不支援fit()的validataion_split的argument。

總結

主程式的模樣就大概是以上介紹的形式,第一篇概論結束,接下來本來想講tf.estimator的架構以進行比較,但思來想去,還是想先介紹tf.data的部分,將一整個訓練流程建立起來,tf.keras與tf.estimator只是在管理模型與模型運算部分的方式有主要的不同,tf.data在處理資料的部分可以說是一模一樣。

感謝你的收看,明天再戰吧!


上一篇
一、用skeleton code解釋tensorflow model程式執行方式(tf.keras) ep.4:用虛擬碼開始解釋囉!
下一篇
二、教你怎麼看source code,找到核心程式碼(e.g. DeepLab)
系列文
從零.4開始我的深度學習之旅:從 用tf.data處理資料 到 用tf.estimator或tf.keras 訓練模型30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言