上一篇中我們把資料寫進了x跟y兩個變數中,我們要把這些資料用圖片表示出來。
接下來要畫出我們的預測線,我會寫一個自定義函數來印出預測線,在用到這個函數時要寫入兩個參數,分別是w和b,他們代表著我們要印出的預測線的斜率和起點高度,如下:
def plot_pred(w, b):
y_pred = x*w + b
plt.plot(x, y_pred, color="blue", label="predict_line")
plt.scatter(x, y, marker="x", color="red", label="real_data")
plt.title("Height - Weight")
plt.xlabel("Height(m)")
plt.ylabel("Weight(kg)")
plt.xlim([1, 3])
plt.ylim([40, 100])
plt.legend()
plt.show()
plot_pred(0, 50)#參數可以自己調整看看
如果想在圖中調整參數也可以用下面這個試試看,會出現兩個可調整的數字,更方便觀察。
from ipywidgets import interact
interact(plot_pred, w=(-100, 100, 1), b=(0, 100, 1))#w跟b的數值也可以隨意調整
學會印出圖片和預測線後我們來講一下成本函數,成本函數是用來判斷我們所畫出的預測線的好壞,舉個例子:
上面這兩張圖很明顯看得出誰的預測線比較準對吧,那是因為上面那張圖的成本函數比下面的圖少很多,成本函數是怎麼計算的呢?
線性回歸的成本函數計算可以想成每筆資料到預測線的距離加總(a1-a2)²+(b1-b2)²+(c1-c2)²,上面的圖的成本函數就是(2-2)²+(4-4)²+(6-6)²=0,下面的成本函數為(1-2)²+(2-4)²+(3-6)²=14
回到我們的範例中,模型的成本函數可以用下列程式碼進行計算。
w = 0
b = 0#w跟b可自行調整
y_pred = w*x + b
cost = (y - y_pred)**2
cost.sum() / len(x)