昨天完成了一個基本的訓練,今天當然就要來拿來預測資料啦~~
(沒有文章庫存了,好緊張喔,每天都要努力產出文章,還要弄研究所推甄的東西
為了方便未來的使用,我們會將這個模型的相關操作封裝在一個類別中,命名為MNISTModel
。這樣一來,我們可以輕鬆初始化模型、進行訓練、保存以及載入模型,避免重複編寫相同的程式碼。
首先,我們將開始定義這個類別,並且在初始化時設置模型的相關屬性。
class MNISTModel:
def __init__(self, model_path='mnist_model.keras'):
# 初始化模型存儲路徑
self.model_path = model_path
self.model = None
# 加載或訓練模型
self._load_or_train_model()
_load_or_train_model
,它負責檢查模型是否已經存在,並決定是加載還是訓練新的模型。在這個方法裡,我們首先會準備 MNIST 資料集,對數據進行標準化處理,然後再進行模型的加載或訓練。以下是完整的程式碼及其步驟說明:
def _load_or_train_model(self):
# 加載並準備數據
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 將圖像數據標準化至 0 和 1 之間
x_train = x_train / 255.0
x_test = x_test / 255.0
# 將標籤轉換為 one-hot 編碼
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
# 檢查模型是否已存在於指定路徑
if os.path.exists(self.model_path):
# 如果模型檔案存在,則加載它
self.model = load_model(self.model_path)
print("Loaded trained model from disk.")
else:
# 如果模型不存在,則創建並訓練一個新模型
self.model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(10, activation='softmax')
])
# 編譯模型,設置優化器、損失函數和評估指標
self.model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 設置模型檢查點,僅保存最佳模型
checkpoint = ModelCheckpoint(self.model_path, save_best_only=True, monitor='val_loss', mode='min')
# 訓練模型並保存
self.model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2, callbacks=[checkpoint])
print("Training completed and model saved.")
在這個步驟中,很大一部分的程式碼基本上都是跟前幾天的訓練是一樣的,唯一不同的地方就是我在訓練完成後將訓練完成的模型保存。而如果有發現已經訓練好的模型,我們則會直接將模型載入使用,不是重新訓練。
def preprocess_image(self, img_path):
img = Image.open(img_path).convert('L') # 轉換為灰階圖像
img = ImageOps.invert(img) # 反轉顏色:黑底白字
img = img.resize((28, 28)) # 調整大小為 28x28
img = np.array(img) / 255.0 # 歸一化
img = img.reshape(1, 28, 28) # 調整形狀以適應模型輸入
return img
def predict(self, img_path):
img = self.preprocess_image(img_path)
prediction = self.model.predict(img)
predicted_digit = np.argmax(prediction)
print(f'The predicted digit is: {predicted_digit}')
# 可視化
plt.imshow(img.reshape(28, 28), cmap='gray')
plt.title(f'Predicted Digit: {predicted_digit}')
plt.show()
if __name__ == "__main__":
mnist_model = MNISTModel()
# 請替換為你自己的圖像路徑
img_path = 'test.png'
mnist_model.predict(img_path)
我給了他一個圖片
最後運行出來的結果會顯示他預測我們給的圖片為3,這代表我們的訓練跟預測都有做到很好的效果。(但是其實他是要用手寫字來預測比較好啦,但我是直接網路上隨便找一個數字圖片)
透過今天的這個程式,無論是模型訓練或預測的過程,都可以輕鬆地實現並且直觀展示結果。只需要替換圖像路徑,就可以開始預測自己的手寫數字圖像了!
希望今天這篇文章有幫助到你!