前言:
藉由訓練過程中的檢查點紀錄,可以知道此模型的訓練次數,不過若不是特別需要,平常可以註解掉,讓訓練效率更快
程式碼:
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
ckpt = tf.train.Checkpoint(optimizer=Adam(lr=1e-5), model=net_final)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=1)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=835)
if os.path.exists(DATASET_PATH):
if os.path.exists(DATASET_PATH + WEIGHTS_FINAL):
print(WEIGHTS_FINAL + "模型存在,將繼續訓練模型")
# net_final.save(WEIGHTS_FINAL)
new_net_final = load_model(WEIGHTS_FINAL)
while True :
new_net_final.fit(train_batches,
steps_per_epoch = train_batches.samples // BATCH_SIZE,
validation_data = valid_batches,
validation_steps = valid_batches.samples // BATCH_SIZE,
epochs = NUM_EPOCHS,
callbacks = [cp_callback])
# 儲存訓練好的模型
print("儲存訓練模型")
new_net_final.save(WEIGHTS_FINAL)
else:
print(WEIGHTS_FINAL + '模型不存在,將新建訓練模型')
# 訓練模型
while True :
net_final.fit(train_batches,
steps_per_epoch = train_batches.samples // BATCH_SIZE,
validation_data = valid_batches,
validation_steps = valid_batches.samples // BATCH_SIZE,
epochs = NUM_EPOCHS,
callbacks = [cp_callback])
# 儲存訓練好的模型
print("儲存訓練模型")
net_final.save(WEIGHTS_FINAL)
else:
print(WEIGHTS_FINAL + '路徑不存在,請確認路徑')