Skip to content

Commit

Permalink
tweak kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Mar 11, 2017
1 parent 3979a23 commit 3ca6be8
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 27 deletions.
16 changes: 7 additions & 9 deletions chapter06/knKmeans.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function [label, energy, model] = knKmeans(X, init, kn)
% Perform kernel k-means clustering.
function [label, model, energy] = knKmeans(X, init, kn)
% Perform kernel kmeans clustering.
% Input:
% K: n x n kernel matrix
% init: either number of clusters (k) or initial label (1xn)
Expand All @@ -21,15 +21,13 @@
kn = @knGauss;
end
K = kn(X,X);
last = 0;
last = zeros(1,n);
while any(label ~= last)
[u,~,label(:)] = unique(label); % remove empty clusters
k = numel(u);
E = sparse(label,1:n,1,k,n,n);
E = spdiags(1./sum(E,2),0,k,k)*E;
[~,~,last(:)] = unique(label); % remove empty clusters
E = sparse(last,1:n,1);
E = E./sum(E,2);
T = E*K;
last = label;
[val, label] = max(bsxfun(@minus,T,diag(T*E')/2),[],1);
[val, label] = max(T-diag(T*E')/2,[],1);
end
energy = trace(K)-2*sum(val);
if nargout == 3
Expand Down
20 changes: 9 additions & 11 deletions chapter09/kmeans.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function [label, energy, model] = kmeans(X, init)
% Perform k-means clustering.
function [label, m, energy] = kmeans(X, init)
% Perform kmeans clustering.
% Input:
% X: d x n data matrix
% init: k number of clusters or label (1 x n vector)
Expand All @@ -9,20 +9,18 @@
% model: trained model structure
% Written by Mo Chen ([email protected]).
n = size(X,2);
idx = 1:n;
last = zeros(1,n);
if numel(init)==1
k = init;
label = ceil(k*rand(1,n));
elseif numel(init)==n
label = init;
end
last = 0;
while any(label ~= last)
[u,~,label(:)] = unique(label); % remove empty clusters
k = numel(u);
E = sparse(1:n,label,1,n,k,n); % transform label into indicator matrix
m = X*(E*spdiags(1./sum(E,1)',0,k,k)); % compute centers
last = label;
[val,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1); % assign labels
[~,~,last(:)] = unique(label); % remove empty clusters
E = sparse(idx,last,1); % transform label into indicator matrix
m = X*(E./sum(E,1)); % compute centers
[val,label] = min(dot(m,m,1)'/2-m'*X,[],1); % assign labels
end
energy = dot(X(:),X(:))-2*sum(val);
model.means = m;
energy = dot(X(:),X(:),1)+2*sum(val);
8 changes: 4 additions & 4 deletions chapter09/kmeansPred.m
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
function [label, energy] = kmeansPred(model, Xt)
function [label, energy] = kmeansPred(m, X)
% Prediction for kmeans clusterng
% Input:
% model: trained model structure
% Xt: d x n testing data
% model: dx k cluster center matrix
% X: d x n testing data
% Output:
% label: 1 x n cluster label
% energy: optimization target value
% Written by Mo Chen ([email protected]).
[val,label] = min(sqdist(model.means, Xt));
[val,label] = min(dot(X,X,1)+dot(m,m,1)'-2*m'*X,[],1); % assign labels
energy = sum(val);
4 changes: 2 additions & 2 deletions demo/ch06/knLin_demo.m
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
n = 500;
[X,y] = kmeansRnd(d,k,n);
init = ceil(k*rand(1,n));
[y_kn,en_kn,model_kn] = knKmeans(X,init,@knLin);
[y_lin,en_lin,model_lin] = kmeans(X,init);
[y_kn,model_kn,en_kn] = knKmeans(X,init,@knLin);
[y_lin,model_lin,en_lin] = kmeans(X,init);

idx = 1:2:n;
Xt = X(:,idx);
Expand Down
2 changes: 1 addition & 1 deletion demo/ch09/kmeans_demo.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
k = 3;
n = 5000;
[X,label] = kmeansRnd(d,k,n);
y = kmeans(X,k);
y = litekmeans(X,k);
plotClass(X,label);
figure;
plotClass(X,y);

0 comments on commit 3ca6be8

Please sign in to comment.