iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 28
0

今天就來讓我們實作看看GMM吧,我們以圖片為例子,將顏色做分群,並利用Kmeans所得到的結果當作GMM的初始值。在GMM的部分,我們讓它重複執行一百次,並且記錄下每一次的log likelihood,讓大家看一下我們的likelihood是真的有在上升,下圖是k = 5 時log likelihood的變化曲線。

import numpy as np
from numpy.linalg import inv, det, pinv
import matplotlib.pyplot as plt
import cv2

img_path = 'your_input_img.jpg'
img = cv2.imread(img_path)
if img is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface later as a confusing AttributeError.
    raise IOError('could not read image: %s' % img_path)

# Downscaling factor in (0, 1]. Clustering a large image is slow, so shrink it
# to try things out quickly; 1 keeps the original size.
shrunk = 1
if not 0 < shrunk <= 1:
    raise ValueError('shrunk must be in (0, 1], got %r' % (shrunk,))
# cv2.resize takes (width, height); never let a dimension collapse to 0.
new_size = (max(1, int(img.shape[1] * shrunk)), max(1, int(img.shape[0] * shrunk)))
img = cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)
# Scale the 0-255 pixel values to [0, 1]; k-means below initialises its
# centres uniformly in that range.
img = img / 255.

class kmeans(object):
    """Hard k-means clustering of the pixel colours of an image.

    ``fit`` partitions the pixels into ``k`` colour clusters; ``get`` then
    returns the cluster centres, the fraction of pixels in each cluster and
    each cluster's (biased, divide-by-count) covariance.  ``GMM.fit`` uses
    these as its initial parameters.
    """

    def __init__(self, k):
        self.k = k

    @staticmethod
    def _assign(pixels, means):
        """Return the index of the nearest centre for every pixel row.

        Distances are squared Euclidean.  Ties go to the lowest index, the
        same as a scan that only replaces the best centre on a strict ``<``.
        """
        dist = np.empty((pixels.shape[0], means.shape[0]))
        for j in range(means.shape[0]):
            diff = pixels - means[j]
            dist[:, j] = np.sum(diff * diff, axis=1)
        return np.argmin(dist, axis=1)

    def _centroids(self, pixels, tags):
        """Return ``(centres, counts)`` for the partition given by ``tags``.

        An empty cluster's centre is set to zero (black), as before; this may
        let it capture dark pixels on the next pass instead of staying empty.
        """
        counts = np.bincount(tags, minlength=self.k).astype(float)
        centres = np.zeros((self.k, pixels.shape[1]))
        for j in range(self.k):
            if counts[j] > 0:
                centres[j] = pixels[tags == j].mean(axis=0)
        return centres, counts

    def fit(self, img, save, tol=0.003, max_iter=1000):
        """Cluster the pixels of ``img`` into ``self.k`` colour groups.

        Parameters
        ----------
        img : ndarray, shape (rows, cols, channels)
            Image whose values are expected in [0, 1] (centres are
            initialised uniformly at random in that range).
        save : int
            When 1, write ``kmean_<k>.jpg`` with every pixel replaced by its
            cluster centre, scaled back to 0-255.
        tol : float
            Stop once every coordinate of every centre moves less than this.
        max_iter : int
            Upper bound on assignment/update rounds (guards against cycling).

        Raises
        ------
        ValueError
            If ``max_iter`` is smaller than 1.
        """
        if max_iter < 1:
            raise ValueError('max_iter must be at least 1')
        n_channels = img.shape[-1]
        pixels = np.asarray(img, dtype=float).reshape(-1, n_channels)
        means = np.random.rand(self.k, n_channels)
        for _ in range(max_iter):
            tags = self._assign(pixels, means)
            new_means, counts = self._centroids(pixels, tags)
            converged = np.all(np.abs(new_means - means) < tol)
            means = new_means
            if converged:
                break

        # Statistics of the final partition.  They are computed after the loop
        # so they are always set, even when the first pass already converged.
        var = np.zeros((self.k, n_channels, n_channels))
        for j in range(self.k):
            if counts[j] > 0:
                diff = pixels[tags == j] - means[j]
                var[j] = diff.T.dot(diff) / counts[j]
        self.means = means
        self.var = var
        self.pi = counts / float(pixels.shape[0])

        if save == 1:
            # Paint every pixel with its cluster centre on the 0-255 scale.
            out = means[tags].reshape(img.shape) * 255.
            cv2.imwrite('kmean_%d.jpg' % self.k,
                        np.clip(np.rint(out), 0, 255).astype(np.uint8))

    def get(self):
        """Return ``(means, pi, var)`` computed by the last call to ``fit``."""
        return self.means, self.pi, self.var

class GMM(object):
    """Gaussian mixture model over the pixel colours of an image, fitted by EM.

    The mixture is initialised from a ``kmeans`` clustering of the same image
    and then refined with a fixed number of EM updates.  The log likelihood is
    printed at every step and plotted to ``log_likelihood_<k>.png`` so that
    its monotone increase can be checked.
    """

    def __init__(self, k):
        self.k = k

    @staticmethod
    def _log_normal(x, mean, cov, reg=1e-4):
        """Return log N(x_n | mean, cov) for every row x_n of ``x`` (N, D).

        A covariance that is singular or numerically not positive definite
        (e.g. a cluster whose pixels are all identical) is regularised by
        adding ``reg`` times the identity before evaluating the density.
        """
        d = x.shape[1]
        sign, logdet = np.linalg.slogdet(cov)
        if sign <= 0:
            cov = cov + reg * np.identity(d)
            sign, logdet = np.linalg.slogdet(cov)
        diff = x - mean
        # Solve instead of forming inv(cov): cheaper and numerically safer.
        maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
        # The normaliser uses the data dimension d, not the component count.
        return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

    def _weighted_log_probs(self, x, pi, means, var):
        """Return the (N, K) matrix log(pi_k) + log N(x_n | mu_k, Sigma_k)."""
        with np.errstate(divide='ignore'):
            # pi_k == 0 gives -inf, i.e. the component contributes nothing.
            log_pi = np.log(np.asarray(pi, dtype=float))
        return np.column_stack([log_pi[j] + self._log_normal(x, means[j], var[j])
                                for j in range(self.k)])

    @staticmethod
    def _log_sum_exp(a):
        """Row-wise log(sum(exp(a))) without underflow."""
        m = np.max(a, axis=1, keepdims=True)
        return m[:, 0] + np.log(np.sum(np.exp(a - m), axis=1))

    def _m_step(self, x, resp, means, var):
        """Re-estimate ``(means, var, pi)`` from responsibilities ``resp`` (N, K).

        A component with zero total responsibility keeps its previous mean and
        covariance (its weight becomes 0) instead of producing 0/0 = NaN.
        """
        nk = resp.sum(axis=0)
        new_means = np.array(means, dtype=float)
        new_var = np.array(var, dtype=float)
        for j in range(self.k):
            if nk[j] > 0:
                new_means[j] = resp[:, j].dot(x) / nk[j]
                diff = x - new_means[j]
                new_var[j] = (resp[:, j, None] * diff).T.dot(diff) / nk[j]
        new_pi = nk / float(x.shape[0])
        return new_means, new_var, new_pi

    def fit(self, img, save, n_iter=100):
        """Fit the mixture to the pixels of ``img``.

        Parameters
        ----------
        img : ndarray, shape (rows, cols, channels)
            Image with values expected in [0, 1].
        save : int
            Passed to ``kmeans.fit``; when 1, also write ``gmm_<k>.jpg`` with
            every pixel painted by the mean of its most probable component.
        n_iter : int
            Number of EM updates.  The log likelihood is recorded before each
            update and once more for the final parameters.

        Always writes ``log_likelihood_<k>.png`` and sets ``self.mean``,
        ``self.var`` and ``self.pi``.

        Raises
        ------
        ValueError
            If ``n_iter`` is negative.
        """
        if n_iter < 0:
            raise ValueError('n_iter must be non-negative')
        n_channels = img.shape[-1]
        pixels = np.asarray(img, dtype=float).reshape(-1, n_channels)

        # Initialise from a hard k-means clustering of the same pixels.
        k_means = kmeans(self.k)
        k_means.fit(img, save)
        means, pi, var = k_means.get()
        means = np.array(means, dtype=float)
        var = np.array(var, dtype=float)
        pi = np.asarray(pi, dtype=float)

        epochs = []
        log_likes = []
        for epoch in range(n_iter + 1):
            # E-step quantities for the current parameters.
            log_p = self._weighted_log_probs(pixels, pi, means, var)
            log_norm = self._log_sum_exp(log_p)
            lk = np.sum(log_norm)
            print('%s %s' % (epoch, lk))
            epochs.append(epoch)
            log_likes.append(lk)
            if epoch == n_iter:
                break  # final parameters evaluated; no further update
            resp = np.exp(log_p - log_norm[:, None])
            means, var, pi = self._m_step(pixels, resp, means, var)

        self.mean = means
        self.var = var
        self.pi = pi
        if save == 1:
            # MAP assignment uses the posterior, i.e. includes pi_k; log_p
            # already holds it for the final parameters.
            labels = np.argmax(log_p, axis=1)
            out = means[labels].reshape(img.shape) * 255.
            cv2.imwrite('gmm_%d.jpg' % self.k,
                        np.clip(np.rint(out), 0, 255).astype(np.uint8))
        plt.figure()
        plt.plot(epochs, log_likes, linestyle='--')
        plt.xlabel('epoch')
        plt.ylabel('log likelihood')
        plt.savefig('log_likelihood_%d.png' % self.k)
        plt.close()  # avoid accumulating open figures across calls

# Fit one mixture per cluster count; save=1 also writes the colour-quantised
# images for inspection.
for K in (2, 3, 5):
    gmm = GMM(K)
    gmm.fit(img, 1)


上一篇
EM - Gaussian Mixture Model
下一篇
EM - EM in general
系列文
機器學習你也可以 - 文組帶你手把手實做機器學習聖經30

尚未有邦友留言

立即登入留言