forked from PRML/PRMLT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkmeans.m
31 lines (30 loc) · 1.02 KB
/
kmeans.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
function [label, mu, energy] = kmeans(X, m)
% Perform kmeans clustering.
% Input:
%   X: d x n data matrix (one sample per column)
%   m: initialization parameter, interpreted by init() as one of
%        scalar k     - number of clusters; k distinct random columns of X
%                       are used as the initial centers
%        1 x n vector - initial sample labels
%        d x k matrix - initial cluster centers
% Output:
%   label: 1 x n sample labels
%   mu: d x k center of clusters
%   energy: optimization target value, i.e. the sum over samples of the
%           squared distance to the assigned center
% Written by Mo Chen ([email protected]).
label = init(X, m);
n = numel(label);
idx = 1:n;
last = zeros(1,n);
% Iterate until the assignment no longer changes between two passes.
while any(label ~= last)
    % Relabel to consecutive integers 1..k so that clusters that lost all
    % of their samples are dropped (remove empty clusters).
    [~,~,last(:)] = unique(label);
    % Build the n x k averaging matrix directly: entry (i, last(i)) holds
    % 1/(size of cluster last(i)), so X*W is the per-cluster mean. This is
    % done inline instead of calling an external normalize() helper, whose
    % name collides with MATLAB's built-in normalize (z-score).
    cnt = accumarray(last(:), 1);
    W = sparse(idx, last, reshape(1./cnt(last), 1, n));
    mu = X*W;                                         % compute cluster centers
    % ||x-mu||^2 = ||x||^2 + 2*(||mu||^2/2 - mu'*x); the ||x||^2 term does
    % not depend on the cluster, so minimizing the bracket is sufficient.
    [val,label] = min(dot(mu,mu,1)'/2-mu'*X,[],1);    % assign sample labels
end
% Add back the ||x||^2 terms omitted during the assignment step.
energy = dot(X(:),X(:),1)+2*sum(val);
function label = init(X, m)
% Produce the initial 1 x n label vector for kmeans.
% Input:
%   X: d x n data matrix
%   m: initialization parameter, checked in this order
%        scalar k     - pick k distinct random columns of X as centers and
%                       assign each sample to its nearest center
%        1 x n vector - used as the labels unchanged
%        d x k matrix - treated as centers; each sample is assigned to its
%                       nearest center
%      Because of the checking order, a 1 x n input is always taken as
%      labels, even when d == 1 and it could also be read as n centers.
% Output:
%   label: 1 x n sample labels
% Raises an error when m matches none of the forms above.
[d,n] = size(X);
if numel(m) == 1                       % random initialization
    mu = X(:,randperm(n,m));
    [~,label] = min(dot(mu,mu,1)'/2-mu'*X,[],1);
elseif isequal(size(m),[1,n])          % init with labels
    label = m;
elseif size(m,1) == d                  % init with seeds (centers)
    [~,label] = min(dot(m,m,1)'/2-m'*X,[],1);
else
    % Without this branch label would be left unassigned and MATLAB would
    % report a generic "output argument not assigned" error.
    error('kmeans:invalidInit', ...
        'ERROR: init is not valid. Expected a scalar k, a 1 x %d label vector, or a %d x k center matrix.', n, d);
end