iT邦幫忙

2023 iThome 鐵人賽

DAY 25
0
AI & Data

機器學習不難嘛系列 第 25

Day25-MNIST準確率和混淆矩陣

  • 分享至 

  • xImage
  •  

還記得之前講過的準確率和混淆矩陣嗎? 在簡單的講一下,混淆矩陣可以判斷一個模型的各項準確率,我們會用到的判斷公式有四個PPV、ACC、TPR、TNR,今天就是要來教各位印出混淆矩陣和判斷模型的好壞。

首先要做的事情是將兩個圖都先印出來我們會用到Matplotlib中的許多函數,一開始要先將準確度的圖片印出來,再來用title設定標題,xlabel和ylabel分別可以印出x軸和y軸的標籤,最後再用show方法展示出來。在印出折線圖之前別忘記import喔。

plt.plot(range(10), train_history.history["accuracy"])
plt.title("iteration - acc")
plt.xlabel("iteration")
plt.ylabel("acc")
plt.show()

https://ithelp.ithome.com.tw/upload/images/20231008/20162311PwcgNIrQ5m.png

上面這張圖就是我們的準確率的圖,接下來要印混淆矩陣的圖了。

import sklearn.metrics as skm
import seaborn as sns
cm = skm.confusion_matrix(y_true=test_label, y_pred=prediction)

plt.figure(figsize=(10, 6))
labels=np.arange(10)
sns.heatmap(
    cm, xticklabels=labels, yticklabels=labels,
    annot=True, linewidths=0.1, fmt='d', cmap='YlGnBu')
plt.title('Confusion Matrix', fontsize=15)
plt.ylabel('Actual label')
plt.xlabel('Predict label')

一開始要先載入兩個套件skm那個是為了要計算混淆矩陣,詳細寫法在第三行的cm那個,sns是用來繪製混淆矩陣的熱度圖(heatmap),要設定的參數有混淆混淆矩陣、x軸y軸的刻度標籤、顯示數值於格子中、邊線寬度、數值為整數、顏色映射,最後在設定標題就可以了。

https://ithelp.ithome.com.tw/upload/images/20231008/20162311u2aMT9IbzA.png

這就是我們的混淆矩陣,可以透過主對角線的數值觀察到經過了多次的訓練過後,它的準確率已經相當高了。

別忘了我們的目標是計算它的準確率,我這邊用ACC來舉例,剩下的三個(PPV、TPR、TNR)大家可以自己算算看。

ACC的公式是

https://ithelp.ithome.com.tw/upload/images/20231008/20162311Eo2d6Z5vb0.png

我們將主對角線的數值加起來可以得到9805,而總共的數據是10000個,所以我們最後可以得到我們的準確率為98.05%。


上一篇
Day24-MNIST模型預測
下一篇
Day26-線性回歸 介紹
系列文
機器學習不難嘛30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言