iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 29
0

機器學習的測試

今天我們要來介紹機器學習該如何測試。

該測試甚麼?

我們知道對機器學習來說,資料的正確性以及模型是否正確很重要,那麼我們要怎麼知道資料是否正確呢?

從形狀來下手

最簡單的就是從形狀下手了,假設我們輸入一張28X28的圖片,那麼我們的輸入大小就是28X28,如果檢查輸入大小不為28X28,就是有錯。

Code

model

假設我們的輸入大小為28,28的圖片,輸出為10個類別。

def model():
  input_shape=(28,28,1)

  input = Input(input_shape, name='input')
  layer=Flatten()(input)
  output = Dense(nb_classes, name="Dense_10nb", activation='softmax')(layer)

  model = Model(inputs=[input], outputs=[output])

  model.compile(loss='sparse_categorical_crossentropy',optimizer=keras.optimizers.Adam(lr=0.0001,decay=1e-6),metrics = ['accuracy'])
  return model

測試

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import numpy as np


class ModelShapeTest(tf.test.TestCase):
    def setUp(self):
      super(ModelShapeTest, self).setUp()
      self.model = model()
    
    #測試輸入大小
    def test_model_input_shape(self):
      input_shape=self.model.input.shape
      self.assertAllEqual((None,28,28,1), input_shape)
    #測試輸出大小  
    def test_model_input_shape(self):
      output_shape=self.model.output.shape
      self.assertAllEqual((None,10), output_shape)

那麼從這個測試,我們可以知道我們的model輸入以及輸出是對的。

除了形狀外,還要確保其他事項是對的

例如我們在之前寫的Voting,到底是否正確呢?我們也可以來寫個測試。

Code

def voting(predictlist,numbers,y_test):
  # predictlist是放所有model的預測結果
  # numbers則為投票次數
  predictlist = np.reshape(predictlist, (-1, len(y_test)))
  voting_acc = 0
  
  for testnumber in range(len(y_test)):
    decide=np.zeros(10)
    for number in range(numbers):
      decide[int(predictlist[number][testnumber])]+=1
    #查看投票結果與實際答案是否相同
    if np.argmax(decide) == y_test[testnumber]:
      voting_acc+=1

  print("voting_acc:", voting_acc/len(y_test))
  return voting_acc/len(y_test)

測試

class VotingTest(tf.test.TestCase):  
    def setUp(self):
      super(VotingTest, self).setUp()

    def test_voting_max(self):
      test_array=[5,6,6,5,5]
      y_test=[5]
      acc=voting(test_array,5,y_test)
      self.assertAllEqual(1.0, acc)

    def test_voting_same(self):
      test_array=[1,1,1,1,1]
      y_test=[1]
      acc=voting(test_array,5,y_test)
      self.assertAllEqual(1.0, acc)

假設5個人投票,3人投給5號,2人頭給6號,那麼結果為5號。

結論

今天我們介紹了如何測試,大部分的model都很難做測試,但是我們至少可以確保輸入或輸出,以及我們自己寫的其他相關function(Voting或是標準化輸入等)是沒有錯誤的。

參考資料

Unit Testing in Tensorflow 2.0


上一篇
Day 28 機器學習的技術債
下一篇
Day 30 完賽心得
系列文
Machine Learning與軟工是否搞錯了什麼?30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言