iT邦幫忙

2024 iThome 鐵人賽

DAY 21
1

前言

昨天完成了一個基本的訓練,今天當然就要來拿來預測資料啦~~
(沒有文章庫存了,好緊張喔,每天都要努力產出文章,還要弄研究所推甄的東西/images/emoticon/emoticon02.gif

程式開發及解釋

為了方便未來的使用,我們會將這個模型的相關操作封裝在一個類別中,命名為MNISTModel。這樣一來,我們可以輕鬆初始化模型、進行訓練、保存以及載入模型,避免重複編寫相同的程式碼。

首先,我們將開始定義這個類別,並且在初始化時設置模型的相關屬性。

  1. 建立 MNISTModel 類別
class MNISTModel:
    def __init__(self, model_path='mnist_model.keras'):
        # 初始化模型存儲路徑
        self.model_path = model_path
        self.model = None
        
        # 加載或訓練模型
        self._load_or_train_model()
  1. 加載或訓練模型
    我們的 MNISTModel 類別包含了一個核心方法 _load_or_train_model,它負責檢查模型是否已經存在,並決定是加載還是訓練新的模型。

在這個方法裡,我們首先會準備 MNIST 資料集,對數據進行標準化處理,然後再進行模型的加載或訓練。以下是完整的程式碼及其步驟說明:


    def _load_or_train_model(self):
        # 加載並準備數據
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        # 將圖像數據標準化至 0 和 1 之間
        x_train = x_train / 255.0
        x_test = x_test / 255.0

        # 將標籤轉換為 one-hot 編碼
        y_train = to_categorical(y_train)
        y_test = to_categorical(y_test)

        # 檢查模型是否已存在於指定路徑
        if os.path.exists(self.model_path):
            # 如果模型檔案存在,則加載它
            self.model = load_model(self.model_path)
            print("Loaded trained model from disk.")
        else:
            # 如果模型不存在,則創建並訓練一個新模型
            self.model = Sequential([
                Flatten(input_shape=(28, 28)),
                Dense(128, activation='relu'),
                Dense(64, activation='relu'),
                Dense(10, activation='softmax')
            ])

            # 編譯模型,設置優化器、損失函數和評估指標
            self.model.compile(optimizer='adam',
                              loss='categorical_crossentropy',
                              metrics=['accuracy'])

            # 設置模型檢查點,僅保存最佳模型
            checkpoint = ModelCheckpoint(self.model_path, save_best_only=True, monitor='val_loss', mode='min')

            # 訓練模型並保存
            self.model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2, callbacks=[checkpoint])
            print("Training completed and model saved.")

在這個步驟中,很大一部分的程式碼基本上都是跟前幾天的訓練是一樣的,唯一不同的地方就是我在訓練完成後將訓練完成的模型保存。而如果有發現已經訓練好的模型,我們則會直接將模型載入使用,不是重新訓練。

  1. 圖像預處理
    在圖像預測中,對圖像進行適當的預處理是很重要的步驟,因為模型期望輸入資料的格式與其訓練過程中使用的資料格式一致。以下是我們的 preprocess_image 函數的介紹:
def preprocess_image(self, img_path):
    img = Image.open(img_path).convert('L')  # 轉換為灰階圖像
    img = ImageOps.invert(img)  # 反轉顏色:黑底白字
    img = img.resize((28, 28))  # 調整大小為 28x28
    img = np.array(img) / 255.0  # 歸一化
    img = img.reshape(1, 28, 28)  # 調整形狀以適應模型輸入
    return img
  • 在這個函數中,圖像的預處理過程包括幾個重要的步驟:
    • 灰階轉換:使用 convert('L') 將圖像轉換為單通道的灰度圖像,因為 MNIST 手寫數字是黑白的。
    • 顏色反轉:使用 ImageOps.invert 將圖像顏色反轉,使手寫數字呈現出黑底白字的格式,這與 MNIST 資料集格式一致。
    • 圖像縮放:通過 resize((28, 28)) 將圖像大小調整為 28x28 像素,這是 MNIST 資料集中每張圖像的標準尺寸。
    • 歸一化:圖像的像素值被除以 255.0,使其歸一化到 [0, 1] 區間,這有助於加速模型的學習過程。
    • 調整形狀:最終將圖像形狀調整為 (1, 28, 28),以便符合模型的批次輸入格式。
      這些步驟確保圖像與模型的預期輸入形式一致,讓模型能夠進行正確的預測。
  1. 圖像預測
    當圖像預處理完成後,我們便可以使用模型來進行數字的預測。predict 函數負責處理圖像並輸出最終的預測結果。這裡是函數的具體實現:
    def predict(self, img_path):
        img = self.preprocess_image(img_path)
        prediction = self.model.predict(img)
        predicted_digit = np.argmax(prediction)
        print(f'The predicted digit is: {predicted_digit}')

        # 可視化
        plt.imshow(img.reshape(28, 28), cmap='gray')
        plt.title(f'Predicted Digit: {predicted_digit}')
        plt.show()
  • 在這個函數中,我們執行了以下操作:
    • 圖像預處理:首先,我們使用 preprocess_image 函數將輸入的圖像轉換為模型可接受的格式。
    • 模型預測:利用 self.model.predict(img) 來對預處理後的圖像進行預測。這將返回一個包含每個數字 (0-9) 的預測機率陣列。
    • 提取預測結果:使用 np.argmax(prediction) 來找到機率最大的索引,這個索引對應的數字即為模型預測的結果。
    • 結果輸出:我們使用 print 函數來輸出預測的數字。
    • 圖像可視化:最後,我們利用 matplotlib 顯示原始圖像,並在標題中展示預測結果,讓用戶直觀地看到模型預測的準確性。
  1. 主程式運行
    最後,我們來看看整個程式的執行部分。
if __name__ == "__main__":
    mnist_model = MNISTModel()
    # 請替換為你自己的圖像路徑
    img_path = 'test.png'
    mnist_model.predict(img_path)

我給了他一個圖片
image

最後運行出來的結果會顯示他預測我們給的圖片為3,這代表我們的訓練跟預測都有做到很好的效果。(但是其實他是要用手寫字來預測比較好啦,但我是直接網路上隨便找一個數字圖片)
image

結語

透過今天的這個程式,無論是模型訓練或預測的過程,都可以輕鬆地實現並且直觀展示結果。只需要替換圖像路徑,就可以開始預測自己的手寫數字圖像了!

希望今天這篇文章有幫助到你!


上一篇
[Day 20] 深度學習的Hello World!訓練模型並探討過度擬合
下一篇
[Day 22] 初見生成對抗網路
系列文
深度學習的學習之旅:從理論到實作30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

1 則留言

0

我要留言

立即登入留言