iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 20
0
AI & Machine Learning

機器學習你也可以 - 文組帶你手把手實做機器學習聖經系列 第 20

kernel method - 動手做看看 Gaussian Process篇

  • 分享至 

  • xImage
  •  

講了兩天 Gaussian Process，是時候動手實做了。資料集沿用之前用過的 2_data.mat（前 60 筆作為訓練資料、後 40 筆作為測試資料），最後應該會得到如 output.png 所示的 fitting 結果。這裡假設 noise 的精度（precision）β 為 1。

import numpy as np
from numpy.linalg import inv
import matplotlib.pyplot as plt
import scipy.io as io

# Load the dataset once and split it: the first 60 samples are used for
# training, samples 60..99 for testing.
data = io.loadmat('2_data.mat')
X = data['x']
T = data['t']
x_train = X[0:60]
x_test = X[60:100]
t_train = T[0:60]
t_test = T[60:100]

class kernel:
    """Covariance function with an exponential-quadratic, a constant and a linear term.

        k(xn, xm) = t0 * exp(-t1 / 2 * ||xn - xm||^2) + t2 + t3 * xn^T xm

    The four hyperparameters are stored as floats in ``t0`` .. ``t3``.
    """

    def __init__(self, t0, t1, t2, t3):
        self.t0 = float(t0)
        self.t1 = float(t1)
        self.t2 = float(t2)
        self.t3 = float(t3)

    def show(self):
        """Print the four hyperparameters on one line, followed by a blank line."""
        # Format into a single string so the trailing blank line is actually
        # printed under Python 3 (``print (...), '\n'`` discards it there).
        print('{} {} {} {}\n'.format(self.t0, self.t1, self.t2, self.t3))

    def k(self, xn, xm):
        """Return the kernel value for the input vectors ``xn`` and ``xm``.

        Presumably called with 1-D numpy arrays (rows of an (N, 1) input
        matrix in this script) -- TODO confirm for other input shapes.
        """
        diff = xn - xm
        squared_dist = diff.dot(diff)
        return self.t0 * np.exp(-0.5 * self.t1 * squared_dist) + self.t2 + self.t3 * xn.T.dot(xm)

class gaussian_process:
    """Gaussian-process regression with a fixed noise precision beta = 1.

    The covariance of the training targets is C = K + I / beta, where K is the
    Gram matrix of ``kernel`` evaluated on the training inputs. ``ard`` tunes
    the four kernel hyperparameters by gradient ascent on the log marginal
    likelihood.

    Attributes:
        theta: (4, 1) float array of hyperparameters [t0, t1, t2, t3].
        k: ``kernel`` instance built from the hyperparameters.
        x, t: training inputs and targets (set by ``fit``).
        C: covariance matrix of the training targets (set by ``fit``).
    """

    def __init__(self, t):
        """Create a GP whose kernel uses the four hyperparameters in ``t``."""
        # Force float dtype: ard() updates theta in place, which raises for an
        # integer array (e.g. when t = [3, 6, 4, 5]).
        self.theta = np.asarray(t, dtype=float).reshape(4, 1)
        self.k = kernel(t[0], t[1], t[2], t[3])

    def _gram(self, x):
        """Return C = K + I / beta (beta = 1) for the inputs ``x``."""
        size = len(x)
        C = np.zeros((size, size))
        for n in range(size):
            for m in range(size):
                C[n][m] = self.k.k(x[n], x[m]) + float(n == m)  # since B = 1
        return C

    def fit(self, x, t):
        """Store the training data and build its covariance matrix ``C``."""
        self.t = t
        self.x = x
        self.C = self._gram(x)

    def predict(self, x):  # return mean/variance
        """Return the predictive (mean, variance) at the single input ``x``.

        Both values are numpy arrays. They are (1, 1) when the training
        targets are an (N, 1) column vector.
        """
        ka = np.zeros((len(self.x), 1))
        for n in range(len(self.x)):
            ka[n] = self.k.k(self.x[n], x)
        # The inverse is shared by the mean and the variance; compute it once.
        C_inv = inv(self.C)
        mean = ka.T.dot(C_inv).dot(self.t)
        # The 1.0 term is the noise variance 1 / beta with beta = 1.
        var = (self.k.k(x, x) + 1.0) - ka.T.dot(C_inv).dot(ka)
        return mean, var

    def rms(self, datas, target):
        """Return the root-mean-square error of the predictive mean on ``datas``."""
        e = 0.
        for data, expected in zip(datas, target):
            mean, _ = self.predict(data)
            e += (mean - expected) ** 2
        e /= len(datas)
        return np.sqrt(e)

    def _c_diff(self, term):
        """Return dC / dtheta_term on the training inputs (term in 0..3)."""
        size = len(self.x)
        dc = np.zeros((size, size))
        t0 = self.theta[0, 0]
        t1 = self.theta[1, 0]
        for n in range(size):
            for m in range(size):
                diff = self.x[n] - self.x[m]
                # Squared Euclidean distance, computed exactly as in kernel.k.
                squared_dist = diff.dot(diff)
                if term == 0:  # theta0
                    dc[n][m] = np.exp(-0.5 * t1 * squared_dist)
                elif term == 1:  # theta1: d/dt1 exp(-t1/2 d^2) = -d^2/2 * exp(...)
                    dc[n][m] = t0 * np.exp(-0.5 * t1 * squared_dist) * (-0.5 * squared_dist)
                elif term == 2:  # theta2
                    dc[n][m] = 1.
                else:  # theta3
                    dc[n][m] = self.x[n].T.dot(self.x[m])
        return dc

    def ard(self, lr, tol=6., max_iter=None):  # lr for learning rate
        """Tune ``theta`` by gradient ascent on the log marginal likelihood.

        Each step uses
            d ln p(t | theta) / d theta_i
                = -1/2 Tr(C^-1 dC/dtheta_i) + 1/2 t^T C^-1 dC/dtheta_i C^-1 t,
        then rebuilds ``k`` and ``C`` from the updated ``theta``. ``fit`` must
        have been called first.

        Args:
            lr: learning rate (step size) of the gradient ascent.
            tol: stop after the step in which every gradient component has a
                magnitude below ``tol`` (default 6.0).
            max_iter: optional cap on the number of steps. ``None`` (default)
                iterates until the ``tol`` criterion is met, which may never
                happen for an unsuitable ``lr``.
        """
        epoch = 0
        while max_iter is None or epoch < max_iter:
            # C^-1 and C^-1 t do not depend on i; compute them once per step.
            C_inv = inv(self.C)
            alpha = C_inv.dot(self.t)
            update = np.zeros((4, 1))
            for i in range(4):
                dC = self._c_diff(i)
                # t^T C^-1 dC C^-1 t == alpha^T dC alpha because C is symmetric.
                update[i] = -0.5 * np.trace(C_inv.dot(dC)) + 0.5 * np.sum(alpha * dC.dot(alpha))
            converged = bool(np.all(np.absolute(update) < tol))
            self.theta += lr * update
            self.k = kernel(self.theta[0][0], self.theta[1][0], self.theta[2][0], self.theta[3][0])
            self.C = self._gram(self.x)
            epoch += 1
            if converged:
                break

theta = [3., 6., 4., 5.]
gp = gaussian_process(theta)
gp.fit(x_train, t_train)
# Tune the hyperparameters by gradient ascent on the log marginal likelihood.
gp.ard(0.001)

# Evaluate the predictive distribution on an evenly spaced grid over [0, 2].
n_points = 50
line = np.linspace(0., 2., n_points).reshape(n_points, 1)
mx = []
vx = []
for sample in line:
    mean, var = gp.predict(sample)
    mx.append(mean)
    vx.append(var)
mx = np.asarray(mx).reshape(n_points, 1)
vx = np.asarray(vx).reshape(n_points, 1)
# predict() returns a variance; its square root puts the band in the same
# units as the targets.
sx = np.sqrt(vx)
plt.plot(x_train, t_train, 'bo')
plt.plot(line, mx, linestyle='-', color='red')
# Shade one predictive standard deviation around the mean.
plt.fill_between(line.reshape(n_points), (mx - sx).reshape(n_points), (mx + sx).reshape(n_points), color='pink')
# Call syntax: the bare `print a, b` statement is a SyntaxError under Python 3.
print(gp.rms(x_train, t_train), gp.rms(x_test, t_test))
plt.savefig('output.png')

上一篇
kernel method - Gaussian Process 參數選擇
下一篇
SVM - 前言
系列文
機器學習你也可以 - 文組帶你手把手實做機器學習聖經30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言