還記得之前講過的準確率和混淆矩陣嗎? 在簡單的講一下,混淆矩陣可以判斷一個模型的各項準確率,我們會用到的判斷公式有四個PPV、ACC、TPR、TNR,今天就是要來教各位印出混淆矩陣和判斷模型的好壞。
首先要做的事情是將兩個圖都先印出來我們會用到Matplotlib中的許多函數,一開始要先將準確度的圖片印出來,再來用title設定標題,xlabel和ylabel分別可以印出x軸和y軸的標籤,最後再用show方法展示出來。在印出折線圖之前別忘記import喔。
plt.plot(range(10), train_history.history["accuracy"])
plt.title("iteration - acc")
plt.xlabel("iteration")
plt.ylabel("acc")
plt.show()
上面這張圖就是我們的準確率的圖,接下來要印混淆矩陣的圖了。
import sklearn.metrics as skm
import seaborn as sns
cm = skm.confusion_matrix(y_true=test_label, y_pred=prediction)
plt.figure(figsize=(10, 6))
labels=np.arange(10)
sns.heatmap(
cm, xticklabels=labels, yticklabels=labels,
annot=True, linewidths=0.1, fmt='d', cmap='YlGnBu')
plt.title('Confusion Matrix', fontsize=15)
plt.ylabel('Actual label')
plt.xlabel('Predict label')
一開始要先載入兩個套件skm那個是為了要計算混淆矩陣,詳細寫法在第三行的cm那個,sns是用來繪製混淆矩陣的熱度圖(heatmap),要設定的參數有混淆混淆矩陣、x軸y軸的刻度標籤、顯示數值於格子中、邊線寬度、數值為整數、顏色映射,最後在設定標題就可以了。
這就是我們的混淆矩陣,可以透過主對角線的數值觀察到經過了多次的訓練過後,它的準確率已經相當高了。
別忘了我們的目標是計算它的準確率,我這邊用ACC來舉例,剩下的三個(PPV、TPR、TNR)大家可以自己算算看。
ACC的公式是
我們將主對角線的數值加起來可以得到9805,而總共的數據是10000個,所以我們最後可以得到我們的準確率為98.05%。