前一篇介紹完一個好用的雲端線上環境,今天就要開始AI機器學習的實作啦!
今天要實作的是前面有講過的卷積神經網路(Convolutional neural network),就是用我們的python環境和tensorflow套件下去寫。
以下是一段我用Chatgpt輔助去寫出來的簡單CNN模型
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_data=(test_images, test_labels))
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Test accuracy: {test_acc}')
這是一支訓練辨識手寫數字模型的CNN神經網路,在模型訓練完成後,最後會在終端將模型分辨準確度print出來。
接下來將會分段講解程式碼,因為我也是初學,若是哪裡錯誤或是不好敬請見諒。
//layers這個module裡有很多高級神經網絡層的類別,使得建構神經網絡模型更為方便。常見的層如 Dense(全連接層)、Conv2D(2D 卷積層)、MaxPooling2D(2D 最大池化層)這些CNN網路會用到的層都包含在這個module中。
//models這個module就是整個模型的skeleton。一個模型是由多個層(layers)組成的,定義了整個神經網絡的結構和行為。引入models讓我們能夠使用高級的 API 來建構和管理整個神經網絡模型。像是”model.add” :在模型中添加各種不同類型的層、"model.fit" : 將模型訓練在數據上,同時指定訓練的參數,如訓練次數、批次大小等。
// 在keras中,datasets裡有一些常見的數據集,方便user下載和使用。這些數據集通常用於機器學習的模型訓練和測試
// mnist 是一個很經典的手寫數字圖像的數據集,包括了 60,000 張 28x28 像素的訓練圖片和 10,000 張測試圖片,每張圖片都標有對應的數字標籤。
// 在 keras 中,utils 提供了一些輔助工具函數,用於處理數據、模型的保存和加載等功能。
// to_categorical 函數是用於將分類標籤轉換為 one-hot 編碼的函數。在機器學習的分類問題中,通常使用 one-hot 編碼表示目標類別。
接下來的篇幅會繼續講解剩餘的CNN程式碼~