今天就來讓我們實作看看GMM吧,我們以圖片為例子,將顏色做分群,並利用K-means所得到的結果當作GMM的初始值。在GMM的部份,我們讓它重複執行一百次,並且記錄下每一次的log likelihood,讓大家看一下我們的likelihood是真的有在上升;下圖是k = 5時,log likelihood的變化曲線。
import numpy as np
from numpy.linalg import inv, det, pinv
import matplotlib.pyplot as plt
import cv2
# Load the demo image (OpenCV reads it as BGR, uint8).
img = cv2.imread('your_input_img.jpg')
# Shrink factor in (0, 1]; if the image is large the per-pixel EM loops
# get very slow, so scale it down just to see the effect.
shrunk = 1
img = cv2.resize(img, (int(img.shape[1]*shrunk),int(img.shape[0]*shrunk)), interpolation=cv2.INTER_CUBIC)
# Normalise pixel values to [0, 1] floats before clustering.
img = img/255.
class kmeans:
    """Plain k-means over the RGB pixels of an image.

    After ``fit`` the attributes ``means`` (k, 3), ``pi`` (k,) and
    ``var`` (k, 3, 3) hold each cluster's colour mean, mixing weight and
    covariance — exactly the statistics needed to seed a GMM.
    """

    def __init__(self, k):
        # k: number of colour clusters.
        self.k = k

    def fit(self, img, save):
        """Cluster ``img`` (H, W, 3 float array in [0, 1]) into k colours.

        save: when 1, write the colour-quantised image to ``kmean_<k>.jpg``.
        """
        channel = 3
        h, w = img.shape[0], img.shape[1]
        pixels = img.reshape(-1, channel)           # (H*W, 3)
        means = np.random.rand(self.k, channel)     # random colour init
        while True:
            # Squared distance of every pixel to every centre: (H*W, k).
            dist = ((pixels[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)
            labels = dist.argmin(axis=1)            # nearest centre per pixel
            new_means = np.zeros((self.k, channel))
            cnt = np.zeros(self.k)
            for k in range(self.k):
                members = pixels[labels == k]
                cnt[k] = len(members)
                if cnt[k] > 0:
                    new_means[k] = members.mean(axis=0)
                # empty cluster: centre stays at the zero vector, as before
            converged = (np.absolute(new_means - means) < 0.003).all()
            means = new_means
            if converged:
                break
        # Covariance of each cluster around the *final* means.  The
        # original computed it before the convergence test, so it used the
        # previous iteration's statistics and raised NameError if the loop
        # happened to converge on the very first pass.
        var = np.zeros((self.k, channel, channel))
        for k in range(self.k):
            members = pixels[labels == k]
            if len(members) > 0:
                d = members - means[k]
                var[k] = d.T.dot(d) / float(len(members))
        self.means = means
        self.var = var
        self.pi = cnt / float(h * w)
        if save == 1:
            # Paint each pixel with its cluster's mean colour.
            out = (means[labels] * 255.).reshape(h, w, channel)
            cv2.imwrite('kmean_%d.jpg' % self.k, out)

    def get(self):
        """Return (means, pi, var) for GMM initialisation."""
        return self.means, self.pi, self.var
class GMM:
    """Gaussian mixture model over image pixels, trained with EM.

    The mixture is initialised from a k-means run on the same image and
    refined for a fixed 101 EM iterations while the log likelihood is
    recorded, so the (monotonically increasing) learning curve can be
    saved as ``log_likelihood_<k>.png``.
    """

    def __init__(self, k):
        # k: number of Gaussian components (colour clusters).
        self.k = k

    def _normal_d(self, x, mean, cov):
        """Multivariate normal density at column vector ``x``.

        x, mean: (d, 1) arrays; cov: (d, d).  A singular covariance is
        regularised with a small ridge before inversion.  Returns a float.
        """
        d = x.shape[0]
        if det(cov) == 0:
            cov = cov + 0.0001 * np.identity(d)
        diff = x - mean
        # The normaliser uses the data dimension d.  The original used
        # self.k (the number of mixture components): responsibilities were
        # unaffected (the factor cancels), but reported log-likelihood
        # values were wrong.
        return float(((2 * np.pi) ** (-d / 2.)) * (det(cov) ** -0.5)
                     * np.exp(-0.5 * diff.T.dot(inv(cov)).dot(diff)))

    def _log_likelihood(self, img, pi, mean, var):
        """Total log p(img) under the current mixture parameters."""
        channel = mean.shape[1]
        log_like = 0.
        for r in range(img.shape[0]):
            for c in range(img.shape[1]):
                xn = img[r][c].reshape(channel, 1)
                p = 0.
                for k in range(self.k):
                    p += pi[k] * self._normal_d(xn, mean[k].reshape(channel, 1), var[k])
                log_like += np.log(p)
        return float(log_like)

    def fit(self, img, save):
        """Fit the mixture to ``img`` (H, W, 3 float array in [0, 1]).

        save: when 1, write the segmented image to ``gmm_<k>.jpg`` (each
        pixel painted with the colour of its most likely component); the
        learning curve is always saved to ``log_likelihood_<k>.png``.
        """
        channel = 3
        # --- initialisation from k-means --------------------------------
        k_means = kmeans(self.k)
        k_means.fit(img, save)
        means, pi, var = k_means.get()
        # --- EM training ------------------------------------------------
        rnk = np.zeros((img.shape[0], img.shape[1], self.k))
        ex = []  # epoch numbers (x axis of the learning curve)
        lx = []  # log likelihood per epoch (y axis)
        for epoch in range(101):
            lk = self._log_likelihood(img, pi, means, var)
            print(epoch, lk)
            ex.append(epoch)
            lx.append(lk)
            # E step: responsibilities rnk; accumulate the sufficient
            # statistics nk and the unnormalised new means on the fly.
            nk = np.zeros(self.k)
            new_means = np.zeros((self.k, channel))
            for r in range(img.shape[0]):
                for c in range(img.shape[1]):
                    xn = img[r][c].reshape(channel, 1)
                    mix_gau = np.zeros(self.k)
                    for k in range(self.k):
                        mix_gau[k] = pi[k] * self._normal_d(
                            xn, means[k].reshape(channel, 1), var[k])
                    rnk[r][c] = mix_gau / np.sum(mix_gau)
                    nk += rnk[r][c]
                    new_means += rnk[r][c].reshape(self.k, 1) * img[r][c].reshape(1, channel)
            # M step: means.  Guard against an empty component instead of
            # dividing by zero as the original did.
            for k in range(self.k):
                if nk[k] > 0:
                    means[k] = new_means[k] / nk[k]
            # M step: covariances and mixing weights.
            new_var = np.zeros((self.k, channel, channel))
            for r in range(img.shape[0]):
                for c in range(img.shape[1]):
                    xn = img[r][c].reshape(channel, 1)
                    for k in range(self.k):
                        d = xn - means[k].reshape(channel, 1)
                        new_var[k] += rnk[r][c][k] * d.dot(d.T)
            for k in range(self.k):
                if nk[k] > 0:
                    var[k] = new_var[k] / nk[k]
                pi[k] = nk[k] / float(img.shape[0] * img.shape[1])
        self.mean = means
        self.var = var
        self.pi = pi
        # --- outputs ----------------------------------------------------
        if save == 1:
            seg = np.copy(img)
            for r in range(img.shape[0]):
                for c in range(img.shape[1]):
                    xn = img[r][c].reshape(channel, 1)
                    p = np.zeros(self.k)
                    for k in range(self.k):
                        p[k] = self._normal_d(xn, self.mean[k].reshape(channel, 1), self.var[k])
                    seg[r][c] = self.mean[np.argmax(p)] * 255.
            cv2.imwrite('gmm_%d.jpg' % self.k, seg)
        plt.figure()
        plt.plot(ex, lx, linestyle='--')
        plt.savefig('log_likelihood_%d.png' % self.k)
# Run the full pipeline (k-means init, GMM EM, saved segmentation images
# and log-likelihood curves) for several component counts.
for K in [2,3,5]:
    gmm = GMM(K)
    gmm.fit(img, 1)