iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 26
0
自我挑戰組

Tensorflow學習日記系列 第 26

Tensorflow學習日記Day26 Cifar-10 CNN 訓練

  • 分享至 

  • xImage
  •  

# Configure training: categorical cross-entropy (expects one-hot targets),
# the Adam optimizer, and accuracy as the reported metric.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for 10 epochs with mini-batches of 128 samples. 20% of the training
# data is held out for validation (the log shows 40000 train / 10000 validation).
train_history = model.fit(
    x_img_train_normalize,
    y_label_train_OneHot,
    batch_size=128,
    epochs=10,
    verbose=1,
    validation_split=0.2,
)
output:
Train on 40000 samples, validate on 10000 samples
Epoch 1/10
40000/40000 [==============================] - 76s 2ms/step - loss: 1.5693 - accuracy: 0.4233 - val_loss: 1.3917 - val_accuracy: 0.5448
Epoch 2/10
40000/40000 [==============================] - 75s 2ms/step - loss: 1.2174 - accuracy: 0.5626 - val_loss: 1.2155 - val_accuracy: 0.6318
Epoch 3/10
40000/40000 [==============================] - 75s 2ms/step - loss: 1.0685 - accuracy: 0.6198 - val_loss: 1.1246 - val_accuracy: 0.6408
Epoch 4/10
40000/40000 [==============================] - 75s 2ms/step - loss: 0.9758 - accuracy: 0.6533 - val_loss: 1.0192 - val_accuracy: 0.6761
Epoch 5/10
40000/40000 [==============================] - 75s 2ms/step - loss: 0.8936 - accuracy: 0.6833 - val_loss: 1.0440 - val_accuracy: 0.6749
Epoch 6/10
40000/40000 [==============================] - 78s 2ms/step - loss: 0.8409 - accuracy: 0.7018 - val_loss: 0.9492 - val_accuracy: 0.6993
Epoch 7/10
40000/40000 [==============================] - 76s 2ms/step - loss: 0.7865 - accuracy: 0.7226 - val_loss: 0.9009 - val_accuracy: 0.7178
Epoch 8/10
40000/40000 [==============================] - 76s 2ms/step - loss: 0.7341 - accuracy: 0.7422 - val_loss: 0.8534 - val_accuracy: 0.7247
Epoch 9/10
40000/40000 [==============================] - 76s 2ms/step - loss: 0.6916 - accuracy: 0.7536 - val_loss: 0.7970 - val_accuracy: 0.7428
Epoch 10/10
40000/40000 [==============================] - 76s 2ms/step - loss: 0.6488 - accuracy: 0.7716 - val_loss: 0.7818 - val_accuracy: 0.7524

import matplotlib.pyplot as plt


def show_train_history(train_history, train, validation):
    """Plot a training metric and its validation counterpart against epochs.

    Args:
        train_history: the History object returned by ``model.fit``.
        train: key of the training metric in ``train_history.history``
            (e.g. ``'accuracy'`` or ``'loss'``); also used as the y-axis label.
        validation: key of the matching validation metric
            (e.g. ``'val_accuracy'`` or ``'val_loss'``).
    """
    history = train_history.history
    plt.plot(history[train])
    plt.plot(history[validation])
    plt.title('Train History')
    plt.ylabel(train)
    plt.xlabel('Epoch')
    # Legend entries follow the plotting order above; placed top-left.
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()

output:
https://scontent.ftpe8-2.fna.fbcdn.net/v/t1.15752-9/72092513_472027756723729_3758090817722384384_n.png?_nc_cat=103&_nc_oc=AQmwEXrTWyX8s67QzgDgSJdqf5H7M9Krg3CM3lWBKCiL1XaUYv-vMvUouK_BOVkyg0c&_nc_ht=scontent.ftpe8-2.fna&oh=534e1046188f6db50d0ce691c997d511&oe=5E259F47

# Accuracy per epoch: training vs. validation.
show_train_history(train_history, train='accuracy', validation='val_accuracy')
output:
https://scontent.ftpe8-2.fna.fbcdn.net/v/t1.15752-9/72322809_2422861027984221_4003838933614985216_n.png?_nc_cat=100&_nc_oc=AQmSqWr1BuqPmyuyuQRdS4Xj4bUZ3hpavgqDxmRep4bq7jaKblp7QWsmAJcGwp3y6mY&_nc_ht=scontent.ftpe8-2.fna&oh=50239a39e4110aa609c38375e29dc73f&oe=5E1907E1

# Loss per epoch: training vs. validation.
show_train_history(train_history, train='loss', validation='val_loss')
output:
https://scontent.ftpe8-3.fna.fbcdn.net/v/t1.15752-9/72037943_432189784079315_8352695656724299776_n.png?_nc_cat=107&_nc_oc=AQnTq14POQxbpOH5LdRLclGtq_mKKja157GFlG-VLjJZac2z-3dVko_pwOPtpWLEvWo&_nc_ht=scontent.ftpe8-3.fna&oh=94e580bae3297da218b409d5f7610379&oe=5E23B9E3

# Accuracy on the test set. With metrics=['accuracy'], evaluate returns
# [loss, accuracy], so index 1 is the accuracy.
scores = model.evaluate(
    x=x_img_test_normalize,
    y=y_label_test_OneHot,
    verbose=0,
)
test_accuracy = scores[1]
test_accuracy
output:
0.7461000084877014

# Predict a class index for every test image.
# `Sequential.predict_classes` was removed in TensorFlow 2.6. For a
# multi-class output (10 class probabilities per image here), it returned the
# argmax of `model.predict`, so this is equivalent and works on all versions.
import numpy as np

prediction = np.argmax(model.predict(x_img_test_normalize), axis=-1)
prediction[:10]
output:
array([3, 8, 8, 0, 6, 6, 5, 6, 3, 1])

# Build a function to display images with their labels and predictions.
# Map CIFAR-10 class indices to readable class names.
label_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
              5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
import matplotlib.pyplot as plt


def plot_images_labels_prediction(images, labels, prediction, idx, num=10):
    """Show ``num`` consecutive samples in a 5x5 grid, captioned with labels.

    Each caption is ``"<sample index>,<true label>"``, followed by
    ``"=><predicted label>"`` when predictions are supplied.

    Args:
        images: sequence of images indexed by sample.
        labels: ground-truth labels; each entry is indexed with ``[0]``
            (e.g. shape ``(N, 1)`` like ``y_label_test``).
        prediction: predicted class indices aligned with ``images``; pass an
            empty sequence to show the true labels only.
        idx: index of the first sample to display.
        num: number of samples to display; capped at 25 (the grid size).
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    num = min(num, 25)  # only 5x5 subplot slots are available
    for i in range(num):
        ax = plt.subplot(5, 5, 1 + i)
        # cmap only affects single-channel images; RGB images ignore it.
        ax.imshow(images[idx], cmap='binary')
        # Index labels/predictions with idx (the sample actually drawn), not
        # with the subplot position i; otherwise captions are wrong when idx != 0.
        title = str(idx) + ',' + label_dict[labels[idx][0]]
        if len(prediction) > 0:
            title += '=>' + label_dict[prediction[idx]]
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx += 1
    plt.show()

# Show the first ten test images with their true and predicted labels.
plot_images_labels_prediction(x_img_test, y_label_test, prediction, idx=0, num=10)
https://scontent.ftpe8-4.fna.fbcdn.net/v/t1.15752-9/72277846_2533306040279561_953224345889538048_n.png?_nc_cat=104&_nc_oc=AQlRwPdTQhnJ7BKJlOUebIkfDVL3UBQpeZoTd9DzqoUkQMrtBLmyjw9mKl6nVihbvdI&_nc_ht=scontent.ftpe8-4.fna&oh=6dd772ca52cce6922f10df1097803afb&oe=5E38835C

# Per-class probabilities for every normalized test image.
Predicted_Probability = model.predict(x=x_img_test_normalize)
In [54]:

def show_Predicted_Probability(y, prediction, x_img, Predicted_Probability, i):
    """Inspect one sample: print its true vs. predicted label, show the image,
    and print the predicted probability of each of the 10 classes.

    Args:
        y: ground-truth labels; ``y[i][0]`` is the class index.
        prediction: predicted class indices.
        x_img: images to display; sample ``i`` is reshaped to ``(32, 32, 3)``.
        Predicted_Probability: per-class probabilities (e.g. from
            ``model.predict``); row ``i`` holds at least 10 entries.
        i: index of the sample to inspect.
    """
    print('label:', label_dict[y[i][0]], 'predict', label_dict[prediction[i]])
    plt.figure(figsize=(2, 2))
    # Display the image from the x_img argument, not the global x_img_test,
    # so the picture matches the data the caller passed in.
    plt.imshow(np.reshape(x_img[i], (32, 32, 3)))
    plt.show()
    for j in range(10):
        print(label_dict[j] + 'Probability:%1.9f' % (Predicted_Probability[i][j]))

# Inspect test sample 0.
show_Predicted_Probability(y_label_test, prediction, x_img_test, Predicted_Probability, i=0)
label: cat predict cat
https://scontent.ftpe8-4.fna.fbcdn.net/v/t1.15752-9/72277846_2533306040279561_953224345889538048_n.png?_nc_cat=104&_nc_oc=AQlRwPdTQhnJ7BKJlOUebIkfDVL3UBQpeZoTd9DzqoUkQMrtBLmyjw9mKl6nVihbvdI&_nc_ht=scontent.ftpe8-4.fna&oh=6dd772ca52cce6922f10df1097803afb&oe=5E38835C
airplaneProbability:0.004340991
automobileProbability:0.000622015
birdProbability:0.004717546
catProbability:0.527973473
deerProbability:0.003106668
dogProbability:0.436110795
frogProbability:0.013606725
horseProbability:0.005316502
shipProbability:0.003260329
truckProbability:0.000944890
In [56]:

# Inspect test sample 3.
show_Predicted_Probability(y_label_test, prediction, x_img_test, Predicted_Probability, i=3)
label: airplane predict airplane
https://scontent.ftpe8-4.fna.fbcdn.net/v/t1.15752-9/73028826_1593552037467490_2938984410562691072_n.png?_nc_cat=110&_nc_oc=AQl_HSHOvN5csqSj36YkPTZYWv_eS_JU3Z9phgCldwoPWs1QlyBiHN3mZkygNhZf2MQ&_nc_ht=scontent.ftpe8-4.fna&oh=64a7864a260c5096420b8cc063567f23&oe=5E387D2C
airplaneProbability:0.465026021
automobileProbability:0.016218476
birdProbability:0.149465665
catProbability:0.027856700
deerProbability:0.050020032
dogProbability:0.004426337
frogProbability:0.004115783
horseProbability:0.003603992
shipProbability:0.270612150
truckProbability:0.008654800

np.shape(prediction)  # one predicted class index per test image
output:
(10000,)

np.shape(y_label_test)  # labels carry an extra trailing axis of size 1
output:
(10000, 1)

y_label_test.ravel()  # drop the trailing axis so labels align with `prediction`
output:
array([3, 8, 8, ..., 5, 1, 7])

# Confusion matrix: rows are true labels, columns are predicted classes.
import pandas as pd
print(label_dict)
pd.crosstab(
    index=y_label_test.reshape(-1),  # flatten (N, 1) labels to (N,)
    columns=prediction,
    rownames=['label'],
    colnames=['predict'],
)
{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
output:
predict 0 1 2 3 4 5 6 7 8 9
label
0 819 15 38 27 28 2 7 15 31 18
1 24 866 2 16 7 5 7 5 17 51
2 63 1 588 97 148 35 43 16 5 4
3 20 4 48 634 81 132 40 30 5 6
4 12 1 39 61 809 14 10 50 4 0
5 8 0 37 236 71 586 13 45 4 0
6 7 1 30 98 83 19 753 2 4 3
7 11 0 37 47 82 31 4 786 0 2
8 68 21 21 19 18 6 2 4 828 13
9 43 67 6 36 8 5 6 15 22 792


上一篇
tensorflow學習日記Day24 Cifar-10模型建立
下一篇
tensorflow學習日記Day27 Cifar-10 CNN 建立三次卷積網路
系列文
Tensorflow學習日記30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言