iT邦幫忙

2024 iThome 鐵人賽

DAY 22
0
自我挑戰組

30天初探tensorflow之旅系列 第 22

Day 22 如何儲存和讀取模型

  • 分享至 

  • xImage
  •  

TensorFlow 有兩種存取模型的方式,
1.SavedModel 格式:
它是 TensorFlow 的標準格式,包含模型的架構、權重、訓練配置等。
優點是可以跨平台和語言使用,支持自定義層和操作,方便進行版本控制。
2. HDF5 格式:
它是一種通用的文件格式,主要用於儲存數據。
優點是文件格式簡單,適合小型模型,易於與其他應用程序集成。

它們的差別是 SaveModel 通用性較高,儲存和載入的過程也比較複雜,較 HDF5 適合生產環境,HDF5 結構就相較簡單,它適合的是快速原型開發或測試簡單模型。

那我們先簡單建立和訓練模型來實際做做看,這裡使用 tf.keras.Sequential 來建立模型,中間使用了 3 層卷積層搭配池化層:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
train = tfds.load('mnist', split='train', as_supervised=True)
test = tfds.load('mnist', split='test', as_supervised=True)
def format_image(image, label):
  image = tf.cast(image, dtype=tf.float32)
  image = image / 255.0
  return  image, label
BATCH_SIZE = 32
BUFFER_SIZE = 10000
train_batches = train.cache().shuffle(BUFFER_SIZE).map(format_image).batch(BATCH_SIZE).prefetch(1)
test_batches = test.cache().map(format_image).batch(BATCH_SIZE).prefetch(1)
model = tf.keras.Sequential([
  layers.Conv2D(16, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
  layers.MaxPool2D(pool_size=(2, 2)),
  layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
  layers.MaxPool2D(pool_size=(2, 2)),
  layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
  layers.MaxPool2D(pool_size=(2, 2)),
  layers.Flatten(),
  layers.Dense(512, activation='relu'),
  layers.Dense(10, activation='softmax')
])
model.compile(
  optimizer='adam', 
  loss='sparse_categorical_crossentropy', 
  metrics=['accuracy']
)
EPOCH = 5
history = model.fit(
  train_batches, 
  validation_data=test_batches, 
  epochs=EPOCH
)

1.將模型參數存成 TensorFlow SavedModel 格式,裡面除了包含訓練好的參數之外,TensorFlow 的計算圖與程式邏輯都包含在這個檔案格式裡。
儲存方式:

export_path = 'saved_model'
tf.saved_model.save(model, export_path)

特別註明若以 SavedModel 格式儲存模型時,會產生三種主要檔案和資料夾:
(1)saved_model.pb
它是模型的描述文件,包含模型的計算圖結構和元數據,它描述了模型的架構、輸入和輸出張量等信息。
(2)variables資料夾
這個資料夾會包含模型的權重和變數,通常有兩個檔案,分別是用來記錄模型變數的索引 variables.index ,
以及存儲模型權重的數據檔案 variables.data-00000-of-00001 。
(3)assets資料夾
這個資料夾通常用來存儲任何額外的資源或文件,模型在推理時可能需要使用到的資料,但不是所有模型都會需
要它。
讀取方式:

reload_sm = tf.saved_model.load(export_path)

2.利用 Keras 將模型存成 HDF5 格式,附檔名會是 '.h5',這是 Keras 模型默認的格式。
儲存方式:

model.save('./keras_model.h5')

讀取方式:

reload_model = tf.keras.models.load_model('./keras_model.h5')

雖然它們儲存和讀取的程式碼都很簡單明瞭,但還是能看出 HDF5 格式簡單的多。


上一篇
Day 21 反向傳播演算法
下一篇
Day 23 認識Inception Model
系列文
30天初探tensorflow之旅30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言