Skip to content

Commit

Permalink
rvmBinEm is done
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Jan 12, 2016
1 parent 8151e05 commit d404911
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 46 deletions.
71 changes: 41 additions & 30 deletions chapter09/demo.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,50 @@
% [x,t] = linRnd(d,n);
% [model,llh] = linRegEm(x,t);
% plot(llh);
%%
d = 512; % signal length
k = 20; % number of spikes
n = 100; % number of measurements
%
% random +/- 1 signal
x = zeros(d,1);
q = randperm(d);
x(q(1:k)) = sign(randn(k,1));

% projection matrix
A = unitize(randn(d,n),1);
% noisy observations
sigma = 0.005;
e = sigma*randn(1,n);
y = x'*A + e;
[model,llh] = rvmRegEm(A,y);
plot(llh);


% [model,llh] = rvmRegEbFp(A,y);
%% demo: sparse signal recovery
% d = 512; % signal length
% k = 20; % number of spikes
% n = 100; % number of measurements
% %
% % random +/- 1 signal
% x = zeros(d,1);
% q = randperm(d);
% x(q(1:k)) = sign(randn(k,1));
%
% % projection matrix
% A = unitize(randn(d,n),1);
% % noisy observations
% sigma = 0.005;
% e = sigma*randn(1,n);
% y = x'*A + e;
% [model,llh] = rvmRegEm(A,y);
% plot(llh);
m = zeros(d,1);
m(model.index) = model.w;
%
%
% % [model,llh] = rvmRegEbFp(A,y);
% % plot(llh);
% m = zeros(d,1);
% m(model.index) = model.w;
%
% h = max(abs(x))+0.2;
% x_range = [1,d];
% y_range = [-h,+h];
% figure;
% subplot(2,1,1);plot(x); axis([x_range,y_range]); title('Original Signal');
% subplot(2,1,2);plot(m); axis([x_range,y_range]); title('Recovery Signal');
%% classification
clear; close all
k = 2;
d = 2;
n = 1000;
[X,t] = kmeansRnd(d,k,n);
[x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));

h = max(abs(x))+0.2;
x_range = [1,d];
y_range = [-h,+h];
[model, llh] = rvmBinEm(X,t-1);
plot(llh);
y = rvmBinPred(model,X)+1;
figure;
subplot(2,1,1);plot(x); axis([x_range,y_range]); title('Original Signal');
subplot(2,1,2);plot(m); axis([x_range,y_range]); title('Recovery Signal');


binPlot(model,X,y);
%% demo: kmeans
% close all; clear;
% d = 2;
Expand Down
71 changes: 55 additions & 16 deletions chapter09/rvmBinEm.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function [model, llh] = rvmBinEm(X, t, alpha)
% Relevance Vector Machine (ARD sparse prior) for binary classification
% training by empirical bayesian (type II ML) using standard EM update
% training by empirical Bayes (type II ML) using fixed-point update (MacKay update)
% Written by Mo Chen ([email protected]).
if nargin < 3
alpha = 1;
Expand All @@ -9,29 +9,68 @@
X = [X;ones(1,n)];
d = size(X,1);
alpha = alpha*ones(d,1);
weight = zeros(d,1);
m = zeros(d,1);

tol = 1e-4;
maxiter = 100;
llh = -inf(1,maxiter);
infinity = 1e+10;
index = 1:d;
for iter = 2:maxiter
used = alpha < infinity;
a = alpha(used);
w = weight(used);
[w,energy,U] = optLogitNewton(X(used,:),t,a,w);
w2 = w.^2;
llh(iter) = energy(end)+0.5*(sum(log(a))-2*sum(log(diag(U)))-dot(a,w2)-n*log(2*pi)); % 7.114
if abs(llh(iter)-llh(iter-1)) < tol*llh(iter-1); break; end
% remove zeros
nz = 1./alpha > tol; % nonzeros
index = index(nz);
alpha = alpha(nz);
X = X(nz,:);
m = m(nz);

[m,e,U] = logitBin(X,t,alpha,m); % 7.110 ~ 7.113

m2 = m.^2;
llh(iter) = e(end)+0.5*(sum(log(alpha))-2*sum(log(diag(U)))-dot(alpha,m2)-n*log(2*pi)); % 7.114 & 7.118
if abs(llh(iter)-llh(iter-1)) < tol*abs(llh(iter-1)); break; end

V = inv(U);
dgS = dot(V,V,2);
alpha(used) = 1./(w2+dgS); % 9.67
weight(used) = w;
alpha = 1./(m2+dgS); % 9.67
end
llh = llh(2:iter);

model.used = used;
model.w = w; % nonzero elements of weight
model.a = a; % nonzero elements of alpha
model.weght = weight;
model.index = index;
model.w = m;
model.alpha = alpha;

function [w, llh, U] = logitBin(X, t, lambda, w)
% Penalized logistic regression fitted by Newton's method with step halving.
% Input:
%   X: d x n data matrix (one sample per column)
%   t: binary targets with n entries; entries equal to 0 get sign -1,
%      all others get sign +1
%   lambda: quadratic penalty on the weights (the caller in this file
%      passes a d x 1 vector, one precision per weight)
%   w: d x 1 initial weights
% Output:
%   w: weights after the last accepted step
%   llh: penalized log-likelihood recorded after each Newton iteration
%   U: Cholesky factor of the penalized Hessian built in the last
%      iteration (i.e. evaluated at the weights before the final step)
[d,n] = size(X);
tol = 1e-4;
maxiter = 100;
llh = -inf(1,maxiter);
% linear indices of the main diagonal of a d x d matrix (kept as a column)
diagIdx = (1:d+1:d*d)';
% signed labels: -1 where t is 0, +1 elsewhere
s = ones(1,n);
s(t==0) = -1;
% penalized log-likelihood as a function of activations z and weights v
objective = @(z,v) -sum(log1pexp(-s.*z))-0.5*sum(lambda.*v.^2);   % 4.89
eta = w'*X;
for iter = 2:maxiter
    prob = sigmoid(eta);                                % 4.87
    curv = prob.*(1-prob);                              % 4.98
    Xs = bsxfun(@times, X, sqrt(curv));
    H = Xs*Xs';                                         % 4.97
    H(diagIdx) = H(diagIdx)+lambda;
    U = chol(H);
    grad = X*(prob-t)'+lambda.*w;                       % 4.96
    step = -U\(U'\grad);
    wPrev = w;                                          % 4.92
    % try the full Newton step first; halve it while the objective decreases
    while true
        w = wPrev+step;
        eta = w'*X;
        llh(iter) = objective(eta,w);
        gain = llh(iter)-llh(iter-1);
        if ~(gain < 0); break; end
        step = step/2;
    end
    % stop once the improvement drops below the tolerance
    if gain < tol; break; end
end
llh = llh(2:iter);
7 changes: 7 additions & 0 deletions common/log1pexp.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function y = log1pexp(x)
% Numerically stable elementwise evaluation of y = log(1+exp(x)).
% Entries below the cutoff are computed as log1p(exp(x)); entries at or
% above it are returned unchanged, which avoids overflow of exp(x) for
% large x (there log(1+exp(x)) is approximately x).
% Reference: Martin Maechler, "Accurately Computing log(1 - exp(-|a|))".
cutoff = 33.3;
below = x<cutoff;
y = x;
y(below) = log1p(exp(x(below)));

0 comments on commit d404911

Please sign in to comment.