-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathkmeans2.m
149 lines (130 loc) · 6.34 KB
/
kmeans2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
% Fast version of kmeans clustering.
%
% Cluster the N x p matrix X into k clusters using the kmeans algorithm. It
% returns the cluster memberships for each data point in the N x 1 vector
% IDX and the K x p matrix of cluster means in C.
%
% This function is in some ways less general than Matlab's kmeans.m (for
% example, it only uses Euclidean distance), but it has some options that
% the Matlab version does not (for example, it has a notion of outliers and
% min-cluster size). It is also many times faster than Matlab's kmeans.
% General kmeans help can be found in the help for the Matlab implementation
% of kmeans. Note that although the names and conventions for this
% algorithm are taken from Matlab's implementation, there are slight
% alterations (for example, IDX==-1 is used to indicate outliers).
%
% IDX is an n-by-1 vector used to indicate cluster membership. Let X be a
% set of n points. Then the ID of X - or IDX - is a column vector of length
% n, where each element is an integer indicating the cluster membership of
% the corresponding element in X. IDX(i)=c indicates that the ith point in
% X belongs to cluster c. Cluster labels range from 1 to k, and thus
% k=max(IDX) is typically the number of clusters IDX divides X into. The
% cluster label "-1" is reserved for outliers. IDX(i)==-1 indicates that
% the given point does not belong to any of the discovered clusters. Note
% that matlab's version of kmeans does not have outliers.
%
% USAGE
% [ IDX, C, sumd ] = kmeans2( X, k, [prm] )
%
% INPUTS
% X - [n x p] matrix of n p-dim vectors.
% k - maximum number of clusters (actual number may be smaller)
% prm - parameters struct (all are optional)
% .k - [] alternate way of specifying k (if not given above)
% .nTrial - [1] number of random restarts
% .maxIter - [100] max number of iterations
% .display - [0] Whether or not to display algorithm status
% .rndSeed - [] random seed for kmeans; useful for replicability
% .outFrac - [0] max frac points that can be treated as outliers
% .minCl - [1] min cluster size (smaller clusters get eliminated)
% .metric - [] metric for pdist2
%
% OUTPUTS
% IDX - [n x 1] cluster membership (see above)
% C - [k x p] matrix of centroid locations C(j,:) = mean(X(IDX==j,:))
% sumd - [1 x k] sumd(j) is sum of distances from X(IDX==j,:) to C(j,:)
% sum(sumd) is a typical measure of the quality of a clustering
%
% EXAMPLE
%
% See also DEMOCLUSTER
% Piotr's Image&Video Toolbox Version 2.0
% Written and maintained by Piotr Dollar pdollar-at-cs.ucsd.edu
% Please email me if you find bugs, or have suggestions or questions!
function [ IDX, C, sumd ] = kmeans2( X, k, prm )
% Entry point: parses the options, runs kmeans2main nTrial times, keeps the
% trial with the smallest sum(sumd), and finally relabels the clusters so
% that larger clusters receive smaller labels.

%%% get input args
dfs = {'nTrial',1, 'maxIter',100, 'display',0, 'rndSeed',[],...
  'outFrac',0, 'minCl',1, 'metric',[] };
% when k is not passed directly it must come from prm.k
% ('REQ' presumably marks the field as required in getPrmDflt - confirm)
if(isempty(k)); dfs={dfs{:} 'k', 'REQ'}; end
if nargin<3 || isempty(prm); prm=struct(); end
prm = getPrmDflt( prm, dfs );
nTrial  =prm.nTrial;   maxIter =prm.maxIter;  display =prm.display;
rndSeed =prm.rndSeed;  outFrac =prm.outFrac;  minCl   =prm.minCl;
metric  =prm.metric;
if(isempty(k)); k=prm.k; end

% error checking
% k==1 is accepted by this check, so the message must say "at least 1"
if(k<1); error('k must be at least 1'); end
% without at least one trial there is no result to return
if(nTrial<1); error('nTrial must be at least 1'); end
if(ndims(X)~=2 || any(size(X)==0)); error('Illegal X'); end
if(outFrac<0 || outFrac>=1)
  error('fraction of outliers must be between 0 and 1'); end
% number of points that may be labelled as outliers in each iteration
nOutl = floor( size(X,1)*outFrac );

% initialize random seed if specified
if( ~isempty(rndSeed)); rand('state',rndSeed); end

% run kmeans2main nTrial times
msg = ['Running kmeans2 with k=' num2str(k)];
if( nTrial>1); msg=[msg ', ' num2str(nTrial) ' times.']; end
if( display); disp(msg); end
bstSumd = inf; bstIDX = []; bstC = [];
for i=1:nTrial
  tic
  msg = ['kmeans iteration ' num2str(i) ' of ' num2str(nTrial) ', step: '];
  if( display); disp(msg); end
  [IDX,C,sumd,nIter]=kmeans2main(X,k,nOutl,minCl,maxIter,display,metric);
  % keep the trial with the smallest total within-cluster distance
  if( sum(sumd)<sum(bstSumd)); bstIDX=IDX; bstC=C; bstSumd=sumd; end
  msg = ['\nCompleted iter ' num2str(i) ' of ' num2str(nTrial) '; ' ...
    'num steps= ' num2str(nIter) '; sumd=' num2str(sum(sumd)) '\n'];
  if( display && nTrial>1 ); fprintf(msg); toc, end
end
% If no trial produced a total below Inf (e.g. the distances were NaN or
% Inf) nothing was recorded above; fall back to the last trial instead of
% failing on an undefined variable.
if( isempty(bstIDX) ); bstIDX=IDX; bstC=C; bstSumd=sumd; end
IDX = bstIDX; C = bstC; sumd = bstSumd; k = max(IDX);
msg = ['Final num clusters = ' num2str( k ) '; sumd=' num2str(sum(sumd))];
if(display); if(nTrial==1); fprintf('\n'); end; disp(msg); end

% sort IDX so that the biggest clusters have the lowest indices
cnts = zeros(1,k); for i=1:k; cnts(i) = sum( IDX==i ); end
% only the permutation is needed; the sorted values in ids are unused
[ids,order] = sort( -cnts ); C = C(order,:); sumd = sumd(order);
% relabel from a copy so already-renamed labels are not renamed twice
IDX2 = IDX; for i=1:k; IDX2(IDX==order(i))=i; end; IDX = IDX2;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ IDX, C, sumd, nIter ] = kmeans2main( X, k, nOutl, ...
  minCl, maxIter, display, metric )
% Single run of kmeans starting from k randomly chosen rows of X.
%
% INPUTS
%  X       - [N x p] data, one point per row
%  k       - requested number of clusters (clipped to N)
%  nOutl   - at most this many of the most distant points get label -1
%  minCl   - clusters with fewer than minCl points are dissolved
%  maxIter - maximum number of assign/update iterations
%  display - if true, print a running iteration counter
%  metric  - passed unchanged as the third argument of pdist2
%
% OUTPUTS
%  IDX     - [N x 1] labels in 1..k, or -1 for points left unassigned
%  C       - [k x p] cluster means (k may be smaller than requested)
%  sumd    - [1 x k] per-cluster sum of the pdist2 values from the last
%            assignment step
%  nIter   - number of iterations performed
%
% NOTE: maxIter must be at least 1; otherwise the loop never runs and
% mind, which sumd is computed from, is never assigned.

% initialize cluster centers to be k random X points
[N,p] = size(X); k = min(k,N);
% IDX and oldIDX start out different so that the loop is entered
IDX = ones(N,1); oldIDX = zeros(N,1);
index = randsample(N,k); C = X(index,:);

% Width of the progress counter: wide enough for every count below maxIter.
% The final count (== maxIter) may be one digit wider, but nothing is
% printed after it. The previous ceil(log10(maxIter-1)) gave -Inf / 0 for
% maxIter = 1 / 2 and was one digit short for maxIter = 11, 101, ...
nIter = 0; ndisdigits = max( 1, ceil( log10( max(maxIter,1) ) ) );
if( display ); fprintf( ['\b' repmat( '0',[1,ndisdigits] )] ); end

% MAIN LOOP: loop until the cluster assignments do not change
while( sum(abs(oldIDX - IDX)) ~= 0 && nIter < maxIter)
  % assign each point to closest cluster center
  oldIDX = IDX; D = pdist2( X, C, metric ); [mind,IDX] = min(D,[],2);

  % do not use most distant nOutl elements in computation of centers
  mindsort = sort(mind); thr = mindsort(end-nOutl); IDX(mind > thr) = -1;

  % discard small clusters [add to outliers, will get included next iter]
  i=1;
  while(i<=k)
    if(sum(IDX==i)<minCl)
      IDX(IDX==i)=-1;
      % move the last cluster into the freed label so labels stay 1..k;
      % i is not advanced, so the moved cluster is checked as well
      if(i<k); IDX(IDX==k)=i; end
      k=k-1;
    else
      i=i+1;
    end
  end

  % every cluster was dissolved: restart with one single-point cluster
  % (randint2 presumably draws a random index in [1,N] - confirm)
  if( k==0 ); IDX( randint2( 1,1, [1,N] ) ) = 1; k=1; end
  for i=1:k
    if((sum(IDX==i))==0); error('internal error - empty cluster'); end
  end

  % Recalculate means based on new assignment (faster than looping over k)
  C = zeros(k,p); cnts = zeros(k,1);
  for i=find(IDX>0)'
    ci = IDX(i); cnts(ci)=cnts(ci)+1;
    C(ci,:) = C(ci,:)+X(i,:);
  end
  C = C ./ cnts(:,ones(1,p));

  nIter = nIter+1;
  if( display )
    % erase the previous counter and print the new one in its place
    fprintf( [repmat('\b',[1 ndisdigits]) int2str2(nIter,ndisdigits)] );
  end
end

% record within-cluster sums of point-to-centroid distances
% (mind holds the values from the last assignment step, i.e. measured
% against the centers as they were before the final mean update)
sumd = zeros(1,k); for i=1:k; sumd(i) = sum( mind(IDX==i) ); end