
EM Algorithm Implementation for GMMs

A MATLAB implementation of the EM algorithm for Gaussian mixture models (GMMs).

Notation

N: number of samples
M: number of Gaussian components

Gaussian Model

Model Representation

$$N(\mathbf{x};\mathbf{\mu},\mathbf{\Sigma})=\frac{1}{(2\pi)^{\frac{d}{2}}|\mathbf{\Sigma}|^{\frac{1}{2}}}\exp\left(-\frac{1}{2}(\mathbf{x}-\mathbf{\mu})^T\mathbf{\Sigma}^{-1}(\mathbf{x}-\mathbf{\mu})\right)$$
where $d$ is the dimensionality of $\mathbf{x}$.
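The EM code further down calls a helper gauss(x, mu, sigma) that evaluates this density for every row of x; that helper is not listed in the original post, so the following is a minimal sketch consistent with how it is called there:

function p = gauss(x, mu, sigma)
% Evaluate N(x; mu, sigma) for each sample.
% x -> [#samples * #feature dims], mu -> row vector [1 * d], sigma -> [d * d]
% p -> column vector of density values [#samples * 1]
[n, d] = size(x);
xc = x - ones(n,1)*mu;                 % center the samples
mahal = sum((xc / sigma) .* xc, 2);    % (x-mu)' inv(sigma) (x-mu), row-wise
p = exp(-0.5*mahal) / ((2*pi)^(d/2) * sqrt(det(sigma)));
end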

Maximum Likelihood Estimation

$$L(\theta)=\sum_{n=1}^N\log N(\mathbf{x}_n;\mathbf{\mu},\mathbf{\Sigma})=-\frac{1}{2}\sum_{n=1}^N\left(d\log(2\pi)+\log|\mathbf{\Sigma}|+(\mathbf{x}_n-\mathbf{\mu})^T\mathbf{\Sigma}^{-1}(\mathbf{x}_n-\mathbf{\mu})\right)$$

Setting the derivatives of $L(\theta)$ to zero gives the estimates:
$$\mathbf{\hat\mu}=\frac{1}{N}\sum_{n=1}^N\mathbf{x}_n$$
$$\mathbf{\hat\Sigma}=\frac{1}{N}\sum_{n=1}^N(\mathbf{x}_n-\mathbf{\hat\mu})(\mathbf{x}_n-\mathbf{\hat\mu})^T$$
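In MATLAB these two estimates are a couple of lines; a minimal sketch, assuming x holds one sample per row:

[n, d] = size(x);
mu_hat = mean(x, 1);            % 1*d sample mean
xc = x - ones(n,1)*mu_hat;      % centered samples
sigma_hat = (xc' * xc) / n;     % d*d ML covariance (divides by n, not n-1)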

Gaussian Mixture Model

Model Representation

$$p(\mathbf{x})=\sum_{m=1}^Mc_mN(\mathbf{x};\mathbf{\mu}_m,\mathbf{\Sigma}_m)$$
$$\text{s.t.}\quad\sum_{m=1}^Mc_m=1$$
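Evaluating the mixture density is then just a weighted sum over the M components; a short sketch reusing the gauss helper above, with mu, sigma, and c laid out as in the code below:

[n, d] = size(x);
p_mix = zeros(n, 1);            % mixture density at every sample
for j = 1:m
    p_mix = p_mix + c(j) * gauss(x, mu(:,j)', sigma(:,:,j));
end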

Maximum Likelihood Estimation

$$L(\theta)=\sum_{n=1}^N\log\left(\sum_{m=1}^Mc_mN(\mathbf{x}_n;\mathbf{\mu}_m,\mathbf{\Sigma}_m)\right)$$
Because the log of a sum is hard to maximize directly, EM works with the posterior responsibility of component $m$ for sample $\mathbf{x}_n$:
$$\gamma_m(n)=P(m\mid\mathbf{x}_n,\hat\theta)=\frac{p(\mathbf{x}_n\mid m,\hat\theta)P(m\mid\hat\theta)}{\sum\limits_{k=1}^Mp(\mathbf{x}_n\mid k,\hat\theta)P(k\mid\hat\theta)}$$
The expected complete-data log-likelihood (with $K$ collecting terms independent of $\theta$) is then
$$Q(\theta;\hat\theta)=K+\sum_{n=1}^N\sum_{m=1}^M\gamma_m(n)\log c_m-\frac{1}{2}\sum_{n=1}^N\sum_{m=1}^M\gamma_m(n)\left(\log|\mathbf{\Sigma}_m|+(\mathbf{x}_n-\mathbf{\mu}_m)^T\mathbf{\Sigma}_m^{-1}(\mathbf{x}_n-\mathbf{\mu}_m)\right)$$

EM Algorithm Steps

  1. Initialization ($k=0$). Fit a single Gaussian to all samples to obtain global parameters $\mathbf{\mu}_0$, $\mathbf{\Sigma}_0$. Initialize each of the M Gaussians in the mixture to $\mathbf{\mu}_0$, $\mathbf{\Sigma}_0$ plus a small random shift (so that the M Gaussians differ; otherwise the mixture degenerates to a single Gaussian).
  2. Increment the iteration counter, $k:=k+1$. For every sample $\mathbf{x}_n$, use $\theta^{(k-1)}$ to compute the posterior responsibility $\gamma_m^{(k)}(n)$ (E step):
    $$\gamma_m^{(k)}(n)=\frac{c_m^{(k-1)}N(\mathbf{x}_n;\mathbf{\mu}_m^{(k-1)},\mathbf{\Sigma}_m^{(k-1)})}{\sum_{j=1}^Mc_j^{(k-1)}N(\mathbf{x}_n;\mathbf{\mu}_j^{(k-1)},\mathbf{\Sigma}_j^{(k-1)})}$$
  3. Update the model parameters $\theta^{(k)}=\{c_m^{(k)},\mathbf{\mu}_m^{(k)},\mathbf{\Sigma}_m^{(k)}\}$ (M step):
    $$\gamma_m=\sum_{n=1}^N\gamma_m(n)$$
    $$\mathbf{\hat\mu}_m=\frac{1}{\gamma_m}\sum_{n=1}^N\gamma_m(n)\mathbf{x}_n$$
    $$\mathbf{\hat\Sigma}_m=\frac{1}{\gamma_m}\sum_{n=1}^N\gamma_m(n)(\mathbf{x}_n-\mathbf{\hat\mu}_m)(\mathbf{x}_n-\mathbf{\hat\mu}_m)^T$$
    $$\hat c_m=\frac{\gamma_m}{\sum\limits_{m=1}^M\gamma_m}$$
  4. Repeat steps 2 and 3 until the model converges.

Implementation

Parameter Initialization

function [mu, sigma] = initialize(x, m, k)
% USAGE: [mu, sigma] = initialize(x, m[, k=0.1])
% Initialize GMM parameters.
% Input:
%   x     -> features [#samples * #feature dims]
%   m     -> #Gauss_in_GMM
%   k     -> random shift range, 0.1 by default
% Output:
%   mu    -> means [#feature dims * #Gauss_in_GMM]
%   sigma -> covariances [#feature dims * #feature dims * #Gauss_in_GMM]
if ~exist('k', 'var') || isempty(k)
    k = 0.1;
end
% initialize mu, sigma
[n, d] = size(x);
mu = zeros(d, m);
sigma = zeros(d, d, m);
% suppose there is a single Gauss and compute its mu and sigma
init_mu = mean(x, 1);
init_sigma = 1/n * (x - ones(n,1)*init_mu)' * (x - ones(n,1)*init_mu);
% add a random shift to every Gauss' parameters so the m Gausses differ
for i = 1:m
    mu(:,i) = init_mu' .* (1 + k*(2*rand(d,1)-1));   % independent per-dimension shift
    sigma(:,:,i) = init_sigma * (1 + k*(2*rand-1));  % uniform scaling keeps sigma PSD
end
end
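For example, starting points for a 3-component mixture with the default 10% shift range (the output variable names here are illustrative):

[mu0, sigma0] = initialize(x, 3);   % k defaults to 0.1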

EM Algorithm

function [mu, sigma, c] = gmm_em(x, m, min_inc)
% USAGE: [mu, sigma, c] = gmm_em(x, m[, min_inc=0.01])
% Train GMM with EM algorithm.
% Input:
%   x       -> features [#samples * #feature dims]
%   m       -> #Gauss_in_GMM
%   min_inc -> minimal increment per iteration, 0.01 by default
% Output:
%   mu    -> means of GMM [#feature dims * #Gauss_in_GMM]
%   sigma -> covariances of GMM [#feature dims * #feature dims * #Gauss_in_GMM]
%   c     -> mixture weights of GMM [1 * #Gauss_in_GMM]
if ~exist('min_inc', 'var') || isempty(min_inc)
    min_inc = 0.01;
end
% initialize the m Gauss models' parameters
[n, d] = size(x);
[mu, sigma] = initialize(x, m);
c = 1/m * ones(1, m);
p = zeros(n, m);
Q = -inf;
iter = 0;
while true
    iter = iter + 1;
    % E step: evaluate every Gauss at every sample ...
    for i = 1:m
        p(:,i) = gauss(x, mu(:,i)', sigma(:,:,i));
    end
    % ... and normalize to posterior responsibilities
    gamma_i = (p .* (ones(n,1)*c)) ./ (p*c' * ones(1,m)); % n*m
    % M step: update GMM parameters
    gamma = sum(gamma_i, 1);                    % 1*m effective counts
    mu = ones(d,1)*(1./gamma) .* (x'*gamma_i);  % d*m weighted means
    for i = 1:m
        xc = x - ones(n,1)*mu(:,i)';            % centered samples
        sigma(:,:,i) = 1/gamma(i) * ((gamma_i(:,i)*ones(1,d) .* xc)' * xc);
    end
    c = gamma / sum(gamma);
    % compute Q (up to the constant K)
    Q_last = Q;
    q = 0;
    for i = 1:n
        for j = 1:m
            xc = x(i,:) - mu(:,j)';
            q = q + gamma_i(i,j) * (log(det(sigma(:,:,j))) + (xc / sigma(:,:,j)) * xc');
        end
    end
    Q = sum(sum(gamma_i .* log(ones(n,1)*c))) - 1/2 * q;  % sum of gamma*log(c) minus quadratic term
    fprintf('iteration: %d\t', iter);
    fprintf('Q-value: %f\n', Q);
    % stop when the increment of Q is small enough
    if Q - Q_last < min_inc
        break;
    end
end
end
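A quick end-to-end check on synthetic data; the two components, sample sizes, and seed below are illustrative choices, not from the original:

rng(0);                                      % reproducible run
x1 = randn(200, 2);                          % component 1: mean [0 0], unit covariance
x2 = 0.5*randn(300, 2) + ones(300,1)*[4 4];  % component 2: mean [4 4], smaller covariance
x  = [x1; x2];
[mu, sigma, c] = gmm_em(x, 2);
% c should come out near [0.4 0.6], up to component ordering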