iT邦幫忙

1

(已解決!)Perceptron python 實做問題

import numpy as np
import random

data = np.loadtxt('hw1_6_train.txt')
print(data)                
print(type(data))          
print(np.shape(data))      
w = [0,0,0,0]
error = 1
iterator = 1
def sign(z): 
    if z > 0:
        return 1
    else:
        return -1
while error != 0:
    error = 0
    for i in range(400):
        data_x = data[i][0:4]
        #print(data_x)
        data_y = data[i][4]
        #print(data_y)
        if sign(np.dot(w,data_x)) != data_y:
            print("iterator: "+str(iterator))
            iterator += 1
            error += 1
            w += data_y * data_x
            #print(w)

這份程式碼如上,目的是為了把400筆4維的數據做2元分類,但我在跑400筆資料時,perceptron總是停不下來,而正確解答應該40、50多次就解決了?!實在是非常絕望,希望各位大大能幫忙看看問題出在哪,我也參考很多人寫的,我還是不知道自己問題出在哪,拜託大家幫忙。

data在此,
https://www.csie.ntu.edu.tw/~htlin/course/mlfound19fall/hw1/hw1_6_train.dat

順帶一題這也是林軒田老師機器學習基石作業一的題目.

1 個回答

0
ccutmis
iT邦高手 7 級 ‧ 2019-10-06 23:40:03
最佳解答

改成這樣試試看...你原本的while 變成無窮迴圈了

while error!=0:
    for i in range(0,len(data)):
        data_x = data[i][0:4]
        #print(data_x)
        data_y = data[i][4]
        #print(data_y)
        if sign(np.dot(w,data_x)) != data_y:
            print("iterator: "+str(iterator))
            iterator += 1
            error = 1
            w += data_y * data_x
            print(w)
        else:
            error = 0

p.s: 我這樣寫只是確定它停的下來,不確定結果是不是你要的喔...

感謝你的回答,幫我找到bug。
但是我發現真的問題是出在data,
在每一筆data裡都要加上x0 = 1 這項,不然跑到天荒地老都不會結束,即使程式都對。
為了寫出這個演算法真的是足足弄了2、3天...
真的深刻體悟"ALL FROM DATA"
確保DATA沒問題真的很重要.

ccutmis iT邦高手 7 級 ‧ 2019-10-07 07:53:26 檢舉

不客氣 恭喜你找到bug /images/emoticon/emoticon42.gif
"ALL FROM DATA" 讓我想到一句古老的名言:
GIGO "GARBAGE IN GARBAGE OUT".

That's true!

我要發表回答

立即登入回答