Skip to content

Commit

Permalink
chapter04 is done for now
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Dec 8, 2015
1 parent 001c62d commit 60f4700
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 66 deletions.
5 changes: 2 additions & 3 deletions chapter04/TODO.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
binPlot
multiPlot
demo
multiPlot: plot multclass decison boundary

23 changes: 18 additions & 5 deletions chapter04/binPlot.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,24 @@ function binPlot(model, X, t)
% Plot binary classification result for 2d data
% X: 2xn data matrix
% t: 1xn label
assert(size(X,1) == 2);
w = model.w;
w0 = model.w0;
figure;
spread(X,t);
y = w'*X+w0;
xi = min(X,[],2);
xa = max(X,[],2);
[x1,x2] = meshgrid(linspace(xi(1),xa(1)), linspace(xi(2),xa(2)));

color = 'brgmcyk';
m = length(color);
figure(gcf);
axis equal
clf;
hold on;
contour(X(1,:),X(2,:),y,1);
hold off;
view(2);
for i = 1:max(t)
idc = t==i;
scatter(X(1,idc),X(2,idc),36,color(mod(i-1,m)+1));
end
y = w0+w(1)*x1+w(2)*x2;
contour(x1,x2,y,[-0 0]);
hold off;
42 changes: 10 additions & 32 deletions chapter04/demo.m
Original file line number Diff line number Diff line change
@@ -1,39 +1,17 @@

%
clear; close all;
k = 2;
n = 1000;
[X,t] = kmeansRnd(2,k,n);

[x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));
[model, llh] = logitReg(X,t-1,0);
[y,p] = logitPred(model,X);

w = model.w;
w0 = model.w0;
plot(llh);
figure;
spread(X,t);

y = w(1)*x1+w(2)*x2+w0;

hold on;
contour(x1,x2,y,1);
hold off;
binPlot(model,X,t)
pause
%%
% clear; close all;
% k = 3;
% n = 200;
% [X,t] = rndKCluster(2,k,n);
%
% [x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));
% [model, llh] = mnReg(X,t, 1e-4,2);
% plot(llh);
% figure;
% spread(X,t);
%
% W = model.W;
% % y = w(1)*x1+w(2)*x2+w(3);
%
% hold on;
% contour(x1,x2,t,1);
% hold off;
clear
k = 3;
n = 1000;
[X,t] = kmeansRnd(2,k,n);
[model, llh] = mnReg(X,t);
y = mnPred(model,X);
spread(X,y)
12 changes: 0 additions & 12 deletions chapter04/multiPlot.m

This file was deleted.

2 changes: 1 addition & 1 deletion chapter06/knCenterize.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
% Centerize the data in the kernel space
% kn: kernel function
% X: dxn data matrix of which the center is computed
% Xt(option): dxn test data to be centerized by the center of X
% Xt(optional): dxn test data to be centerized by the center of X
% Written by Mo Chen ([email protected]).
K = kn(X,X);
mK = mean(K);
Expand Down
25 changes: 25 additions & 0 deletions chapter06/knKmeans.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function [label, energy, model] = knKmeans(X, k, kn)
% Perform kernel k-means clustering.
% K: nxn kernel matrix
% k: number of cluster
% Reference: Kernel Methods for Pattern Analysis
% by John Shawe-Taylor, Nello Cristianini
% Written by Mo Chen ([email protected]).
K = kn(X,X);
n = size(X,2);
label = ceil(k*rand(1,n));
last = 0;
while any(label ~= last)
E = sparse(label,1:n,1,k,n,n);
E = bsxfun(@times,E,1./sum(E,2));
T = E*K;
Z = repmat(diag(T*E'),1,n)-2*T;
last = label;
[val, label] = min(Z,[],1);
end
energy = sum(val)+trace(K);
if nargout == 3
model.X = X;
model.kn = kn;
model.label = label;
end
7 changes: 7 additions & 0 deletions chapter06/knKmeansPred.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function [ output_args ] = knKmeansPred( input_args )
%KNKMEANSPRED Summary of this function goes here
% Detailed explanation goes here


end

13 changes: 0 additions & 13 deletions chapter06/knPred.m

This file was deleted.

26 changes: 26 additions & 0 deletions chapter06/knRegPred.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
function [y, sigma, p] = knRegPred(model, x, t)
% Prediction for kernel regression model
% Written by Mo Chen ([email protected]).
kn = model.kn;
a = model.a;
X = model.X;
tbar = model.tbar;
y = a'*knCenterize(kn,X,x)+tbar;
if nargin == 3
sigma = sqrt(1/beta+dot(X,X,1)); % 3.59
p = exp(((t-y).^2/sigma2+log(2*pi*sigma2))/(-2));
end

% if nargout > 1
% beta = model.beta;
% if isfield(model,'V') % V*V'=inv(S) 3.54
% U = model.V'*bsxfun(@minus,X,model.xbar);
% sigma = sqrt(1/beta+dot(U,U,1)); % 3.59
% else
% sigma = sqrt(1/beta);
% end
% if nargin == 3 && nargout == 3
% p = exp(logGauss(t,y,sigma));
% % p = exp(-0.5*(((t-y)./sigma).^2+log(2*pi))-log(sigma));
% end
% end

0 comments on commit 60f4700

Please sign in to comment.