Skip to content

Commit 9e59c07

Browse files
committed
simplified logGauss, removed other cases
1 parent f795d91 commit 9e59c07

File tree

1 file changed

+15
-36
lines changed

1 file changed

+15
-36
lines changed

chapter02/logGauss.m

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,21 @@
1-
function y = logGauss(X, mu, sigma)
1+
function y = logGauss(X, mu, Sigma)
22
% Compute log pdf of a Gaussian distribution.
33
% Input:
44
% X: d x n data matrix
5-
% mu: mean of Gaussian
6-
% sigma: variance of Gaussian
5+
% mu: d x 1 mean vector of Gaussian
6+
% Sigma: d x d covariance matrix of Gaussian
77
% Output:
8-
% y: probability density in logrithm scale y=log p(x)
8+
% y: 1 x n probability density in logrithm scale y=log p(x)
99
% Written by Mo Chen ([email protected]).
10-
11-
[d,n] = size(X);
12-
k = size(mu,2);
13-
if n == k && size(sigma,1) == 1
14-
X = bsxfun(@times,X-mu,1./sigma);
15-
q = dot(X,X,1); % M distance
16-
c = d*log(2*pi)+2*log(sigma); % normalization constant
17-
y = -0.5*(c+q);
18-
elseif size(sigma,1)==d && size(sigma,2)==d && k==1 % one mu and one dxd sigma
19-
X = bsxfun(@minus,X,mu);
20-
[R,p]= chol(sigma);
21-
if p ~= 0
22-
error('ERROR: sigma is not PD.');
23-
end
24-
Q = R'\X;
25-
q = dot(Q,Q,1); % quadratic term (M distance)
26-
c = d*log(2*pi)+2*sum(log(diag(R))); % normalization constant
27-
y = -0.5*(c+q);
28-
elseif size(sigma,1)==d && size(sigma,2)==k % k mu and k diagonal sigma
29-
lambda = 1./sigma;
30-
ml = mu.*lambda;
31-
q = bsxfun(@plus,X'.^2*lambda-2*X'*ml,dot(mu,ml,1)); % M distance
32-
c = d*log(2*pi)+2*sum(log(sigma),1); % normalization constant
33-
y = -0.5*bsxfun(@plus,q,c);
34-
elseif size(sigma,1)==1 && (size(sigma,2)==k || size(sigma,2)==1) % k mu and (k or one) scalar sigma
35-
X2 = repmat(dot(X,X,1)',1,k);
36-
D = bsxfun(@plus,X2-2*X'*mu,dot(mu,mu,1));
37-
q = bsxfun(@times,D,1./sigma); % M distance
38-
c = d*(log(2*pi)+2*log(sigma)); % normalization constant
39-
y = -0.5*bsxfun(@plus,q,c);
40-
else
41-
error('Parameters are mismatched.');
10+
[d,k] = size(mu);
11+
assert(all(size(Sigma)==d) && k==1) % one mu and one dxd Sigma
12+
X = bsxfun(@minus,X,mu);
13+
[R,p]= chol(Sigma);
14+
if p ~= 0
15+
error('ERROR: Sigma is not PD.');
4216
end
17+
Q = R'\X;
18+
q = dot(Q,Q,1); % quadratic term (M distance)
19+
c = d*log(2*pi)+2*sum(log(diag(R))); % normalization constant
20+
y = -0.5*(c+q);
21+

0 commit comments

Comments
 (0)