forked from PRML/PRMLT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGaussWishart.m
108 lines (90 loc) · 2.76 KB
/
GaussWishart.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
% GaussWishart  Conjugate Gaussian-(inverse-)Wishart prior over the mean and
% covariance of a multivariate Gaussian, used by the Dirichlet process code.
%
%   Sigma ~ InvWishart(S, nu),   mu | Sigma ~ N(m, Sigma/kappa)
%
% The scale matrix S is not stored directly. U_ holds the upper Cholesky
% factor of the raw second-moment matrix
%   U_'*U_ = S + kappa_*m_*m_'
% so that absorbing/removing an observation x is a rank-one Cholesky
% update/downdate by x. scaleChol() recovers the Cholesky factor of S.
classdef GaussWishart
    properties
        kappa_  % pseudo-count scaling the covariance of the mean
        m_      % d x 1 location of the mean
        nu_     % degrees of freedom of the inverse-Wishart part
        U_      % upper Cholesky factor of S + kappa_*m_*m_'
    end
    methods
        function obj = GaussWishart(kappa,m,nu,S)
            % Construct from hyperparameters.
            %   kappa : scalar pseudo-count for the mean
            %   m     : d x 1 mean location (column vector, so m*m' is d x d)
            %   nu    : degrees of freedom
            %   S     : d x d scale matrix; S + kappa*m*m' must be positive
            %           definite or chol raises an error
            if nargin == 0
                % MATLAB invokes the constructor with no arguments when it
                % needs default elements (e.g. when growing an object array).
                return;
            end
            obj.kappa_ = kappa;
            obj.m_ = m;
            obj.nu_ = nu;
            obj.U_ = chol(S+kappa*(m*m'));
        end
        function obj = clone(obj)
            % Return a copy. GaussWishart is a value class (it does not
            % derive from handle), so returning the input is already a copy.
        end
        function d = dim(obj)
            % Dimension d of the modelled Gaussian (length of m_).
            d = numel(obj.m_);
        end
        function R = scaleChol(obj)
            % Upper Cholesky factor R of the current scale matrix:
            %   R'*R = S = U_'*U_ - kappa_*m_*m_'
            % Obtained by a rank-one downdate of U_ by sqrt(kappa_)*m_.
            R = cholupdate(obj.U_, sqrt(obj.kappa_)*obj.m_, '-');
        end
        function obj = addData(obj, X)
            % Absorb a batch of observations into the posterior.
            %   X : d x n matrix, one observation per column
            kappa0 = obj.kappa_;
            m0 = obj.m_;
            nu0 = obj.nu_;
            U0 = obj.U_;
            n = size(X,2);
            kappa = kappa0+n;
            m = (kappa0*m0+sum(X,2))/kappa;
            nu = nu0+n;
            % The raw moment simply accumulates X*X'; the kappa*m*m' term
            % is carried implicitly by the representation.
            U = chol(U0'*U0+X*X');
            obj.kappa_ = kappa;
            obj.m_ = m;
            obj.nu_ = nu;
            obj.U_ = U;
        end
        function obj = addSample(obj, x)
            % Absorb a single d x 1 observation x (rank-one Cholesky update).
            kappa = obj.kappa_+1;
            % Running-mean update: (kappa_old*m + x)/kappa_new
            m = obj.m_+(x-obj.m_)/kappa;
            obj.kappa_ = kappa;
            obj.m_ = m;
            obj.nu_ = obj.nu_+1;
            obj.U_ = cholupdate(obj.U_,x,'+');
        end
        function obj = delSample(obj, x)
            % Remove a previously absorbed d x 1 observation x.
            % Inverse of addSample: the rank-one downdate errors if the
            % downdated matrix is not positive definite, and kappa_ must
            % stay nonzero after the decrement.
            kappa = obj.kappa_-1;
            % Inverse running-mean update: (kappa_old*m - x)/kappa_new
            m = obj.m_-(x-obj.m_)/kappa;
            obj.kappa_ = kappa;
            obj.m_ = m;
            obj.nu_ = obj.nu_-1;
            obj.U_ = cholupdate(obj.U_,x,'-');
        end
        function y = logPredPdf(obj,X)
            % Log posterior-predictive density of each column of X.
            % The predictive is a multivariate Student-t with
            %   dof v = nu - d + 1, location m, scale (1+1/kappa)/v * S.
            %   X : d x n matrix of points;  y : 1 x n row of log densities
            kappa = obj.kappa_;
            m = obj.m_;
            nu = obj.nu_;
            d = size(X,1);
            v = (nu-d+1);
            % Cholesky factor of the Student-t scale matrix
            U = sqrt((1+1/kappa)/v)*obj.scaleChol();
            X = bsxfun(@minus,X,m);
            Q = U'\X;
            q = dot(Q,Q,1); % quadratic term (squared Mahalanobis distance)
            o = -log(1+q/v)*((v+d)/2);
            % Normalizer; 2*sum(log(diag(U))) is log det of the scale matrix
            c = gammaln((v+d)/2)-gammaln(v/2)-(d*log(v*pi)+2*sum(log(diag(U))))/2;
            y = c+o;
        end
        function [mu, Sigma] = sample(obj)
            % Sample a Gaussian distribution (mu, Sigma) from this prior:
            %   Sigma ~ InvWishart(S, nu),  mu ~ N(m, Sigma/kappa)
            % Uses iwishrnd (Statistics and Machine Learning Toolbox) and
            % gaussRnd, defined elsewhere in the project.
            % NOTE(review): assumes gaussRnd(mean, covariance) - confirm.
            % The inverse-Wishart scale is S, not the raw moment U_'*U_
            % (which also contains kappa*m*m' and would inflate Sigma
            % whenever m is nonzero).
            R = obj.scaleChol();
            Sigma = iwishrnd(R'*R,obj.nu_);
            mu = gaussRnd(obj.m_,Sigma/obj.kappa_);
        end
    end
end