iT邦幫忙

2023 iThome 鐵人賽

DAY 21
0
自我挑戰組

python-資料分析與機器學習系列 第 21

DAY21-循環神經網路(RNN)(下)

  • 分享至 

  • xImage
  •  

資料預處理

from keras.datasets import mnist
from keras.src.utils import np_utils

# 載入數據集,並將其分為訓練集和測試集
(train_feature, train_label), (test_feature, test_label) = mnist.load_data()
#image轉換
train_feature_vector = train_feature.reshape(len(train_feature),28,28).astype('float32')
test_feature_vector = test_feature.reshape(len(test_feature),28,28).astype('float32')
#image標準化
train_feature_nor = train_feature_vector/255
test_feature_nor = test_feature_vector/255
#轉為One-Hot-Encoding編碼
train_label_onehot = np_utils.to_categorical(train_label)
test_label_onehot = np_utils.to_categorical(test_label)

建立循環神經網路模型

from keras.models import Sequential
from keras.layers import SimpleRNN,Dropout,Dense 

model = Sequential()
#建立SimpleRNN層
model.add(SimpleRNN(input_shape=(28,28),units=256,unroll=True))
#建立輸出層
model.add(Dropout(0,1))
#建立拋棄層
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))

模型訓練

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

train_his = model.fit(x=train_feature_nor,y=train_label_onehot,validation_split=0.3,epochs=10,batch_size=250,verbose=2)

評估模型

scores = model.evaluate(test_feature_nor,test_label_onehot)
print('\n準確率=',scores[1])

圖片預測

import numpy as np
prediction=model.predict(test_feature_nor)
predicted_labels = np.argmax(prediction, axis=1)
print(predicted_labels)

但因為SimpleRNN記憶效果不好,因此可以利用長短期記憶神經循環網路。
長短期記憶(LSTM)

from keras.layers.recurrent import LSTM
model.add(LSTM(input_shape=(28,28),units=256,unroll=False))

---20231006---


上一篇
DAY20-循環神經網路(RNN)(上)
下一篇
DAY22-自然語言處理(jieba模組)
系列文
python-資料分析與機器學習30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言