7
7
import water .api .DocGen ;
8
8
import water .api .Progress2 ;
9
9
import water .api .Request ;
10
- import water .api .Request .API ;
11
- import water .api .Request .Default ;
12
10
import water .fvec .Chunk ;
13
11
import water .fvec .Frame ;
14
12
import water .fvec .NewChunk ;
@@ -60,12 +58,13 @@ public KMeans2() {
60
58
if ( sourceArg != null )
61
59
sourceKey = Key .make (sourceArg );
62
60
63
- // Drop ignored cols and, if users asks for it, cols with too many NAs
61
+ // Drop ignored cols and, if user asks for it, cols with too many NAs
64
62
Frame fr = DataInfo .prepareFrame (source , ignored_cols , false , false , drop_na_cols );
65
63
String [] names = fr .names ();
66
64
Vec [] vecs = fr .vecs ();
67
65
if (vecs == null || vecs .length == 0 )
68
66
throw new IllegalArgumentException ("No columns selected. Check that selected columns have not been dropped due to too many NAs." );
67
+ DataInfo dinfo = new DataInfo (fr , 0 , false , normalize , false );
69
68
70
69
// Fill-in response based on K99
71
70
String [] domain = new String [k ];
@@ -84,7 +83,7 @@ public KMeans2() {
84
83
means [i ] = (float ) vecs [i ].mean ();
85
84
if ( mults != null ) {
86
85
double sigma = vecs [i ].sigma ();
87
- mults [i ] = normalize (sigma ) ? 1 / sigma : 1 ;
86
+ mults [i ] = normalize (sigma ) ? 1.0 / sigma : 1.0 ;
88
87
}
89
88
}
90
89
@@ -94,7 +93,8 @@ public KMeans2() {
94
93
if ( initialization == Initialization .None ) {
95
94
// Initialize all clusters to random rows
96
95
clusters = new double [k ][vecs .length ];
97
- for (double [] cluster : clusters ) randomRow (vecs , rand , cluster , means , mults );
96
+ for (double [] cluster : clusters )
97
+ randomRow (vecs , rand , cluster , means , mults );
98
98
} else {
99
99
// Initialize first cluster to random row
100
100
clusters = new double [1 ][];
@@ -136,7 +136,10 @@ public KMeans2() {
136
136
task ._clusters = clusters ;
137
137
task ._means = means ;
138
138
task ._mults = mults ;
139
+ task ._ncats = dinfo ._cats ;
140
+ task ._nnums = dinfo ._nums ;
139
141
task .doAll (vecs );
142
+
140
143
model .centers = clusters = normalize ? denormalize (task ._cMeans , vecs ) : task ._cMeans ;
141
144
model .between_cluster_variances = task ._betwnSqrs ;
142
145
double [] variances = new double [task ._cSqrs .length ];
@@ -343,6 +346,7 @@ public static class KMeans2Model extends Model implements Progress {
343
346
// Normalization caches
344
347
private transient double [][] _normClust ;
345
348
private transient double [] _means , _mults ;
349
+ private transient int _ncats , _nnums ;
346
350
347
351
public KMeans2Model (KMeans2 params , Key selfKey , Key dataKey , String names [], String domains [][]) {
348
352
super (selfKey , dataKey , names , domains );
@@ -380,7 +384,8 @@ public KMeans2Model(KMeans2 params, Key selfKey, Key dataKey, String names[], St
380
384
}
381
385
data (tmp , chunks , rowInChunk , _means , _mults );
382
386
Arrays .fill (preds , 0 );
383
- int cluster = closest (cs , tmp , new ClusterDist ())._cluster ;
387
+ // int cluster = closest(cs, tmp, new ClusterDist())._cluster;
388
+ int cluster = closest (cs , tmp , _ncats , new ClusterDist ())._cluster ;
384
389
preds [0 ] = cluster ; // prediction in preds[0]
385
390
preds [1 +cluster ] = 1 ; // class distribution
386
391
return preds ;
@@ -401,13 +406,15 @@ public class Clusters extends MRTask2<Clusters> {
401
406
// IN
402
407
double [][] _clusters ; // Cluster centers
403
408
double [] _means , _mults ; // Normalization
409
+ int _ncats , _nnums ;
404
410
405
411
@ Override public void map (Chunk [] cs , NewChunk ncs ) {
406
412
double [] values = new double [_clusters [0 ].length ];
407
413
ClusterDist cd = new ClusterDist ();
408
414
for (int row = 0 ; row < cs [0 ]._len ; row ++) {
409
415
data (values , cs , row , _means , _mults );
410
- closest (_clusters , values , cd );
416
+ // closest(_clusters, values, cd);
417
+ closest (_clusters , values , _ncats , cd );
411
418
int clu = cd ._cluster ;
412
419
// ncs[0].addNum(clu);
413
420
ncs .addEnum (clu );
@@ -478,6 +485,7 @@ public static class Lloyds extends MRTask2<Lloyds> {
478
485
// IN
479
486
double [][] _clusters ;
480
487
double [] _means , _mults ; // Normalization
488
+ int _ncats , _nnums ;
481
489
482
490
// OUT
483
491
double [][] _cMeans , _cSqrs ; // Means and sum of squares for each cluster
@@ -499,7 +507,8 @@ public static class Lloyds extends MRTask2<Lloyds> {
499
507
int [] clusters = new int [cs [0 ]._len ];
500
508
for ( int row = 0 ; row < cs [0 ]._len ; row ++ ) {
501
509
data (values , cs , row , _means , _mults );
502
- closest (_clusters , values , cd );
510
+ // closest(_clusters, values, cd);
511
+ closest (_clusters , values , _ncats , cd );
503
512
int clu = clusters [row ] = cd ._cluster ;
504
513
_sqr += cd ._dist ;
505
514
if ( clu == -1 )
@@ -556,10 +565,6 @@ private static final class ClusterDist {
556
565
double _dist ;
557
566
}
558
567
559
- private static ClusterDist closest (double [][] clusters , double [] point , ClusterDist cd ) {
560
- return closest (clusters , point , cd , clusters .length );
561
- }
562
-
563
568
private static double minSqr (double [][] clusters , double [] point , ClusterDist cd ) {
564
569
return closest (clusters , point , cd , clusters .length )._dist ;
565
570
}
@@ -568,14 +573,43 @@ private static double minSqr(double[][] clusters, double[] point, ClusterDist cd
568
573
return closest (clusters , point , cd , count )._dist ;
569
574
}
570
575
571
- /** Return both nearest of N cluster/centroids, and the square-distance. */
576
+ private static ClusterDist closest (double [][] clusters , double [] point , ClusterDist cd ) {
577
+ return closest (clusters , point , cd , clusters .length );
578
+ }
579
+
580
+ private static ClusterDist closest (double [][] clusters , double [] point , int ncats , ClusterDist cd ) {
581
+ return closest (clusters , point , ncats , cd , clusters .length );
582
+ }
583
+
572
584
private static ClusterDist closest (double [][] clusters , double [] point , ClusterDist cd , int count ) {
585
+ return closest (clusters , point , 0 , cd , count );
586
+ }
587
+
588
+ private static ClusterDist closest (double [][] clusters , double [] point , int ncats , ClusterDist cd , int count ) {
589
+ return closest (clusters , point , ncats , cd , count , 1 );
590
+ }
591
+
592
+ /** Return both nearest of N cluster/centroids, and the square-distance. */
593
+ private static ClusterDist closest (double [][] clusters , double [] point , int ncats , ClusterDist cd , int count , double dist ) {
573
594
int min = -1 ;
574
595
double minSqr = Double .MAX_VALUE ;
575
596
for ( int cluster = 0 ; cluster < count ; cluster ++ ) {
576
597
double sqr = 0 ; // Sum of dimensional distances
577
598
int pts = point .length ; // Count of valid points
578
- for ( int column = 0 ; column < clusters [cluster ].length ; column ++ ) {
599
+
600
+ // Expand categoricals into binary indicator cols
601
+ for (int column = 0 ; column < ncats ; column ++) {
602
+ double d = point [column ];
603
+ if (Double .isNaN (d ))
604
+ pts --;
605
+ else {
606
+ // TODO: What is the distance between unequal categoricals?
607
+ if (d != clusters [cluster ][column ])
608
+ sqr += 2 * dist * dist ;
609
+ }
610
+ }
611
+
612
+ for ( int column = ncats ; column < clusters [cluster ].length ; column ++ ) {
579
613
double d = point [column ];
580
614
if ( Double .isNaN (d ) ) { // Bad data?
581
615
pts --; // Do not count
@@ -686,14 +720,16 @@ private static double[][] denormalize(double[][] clusters, Vec[] vecs) {
686
720
private static void data (double [] values , Vec [] vecs , long row , double [] means , double [] mults ) {
687
721
for ( int i = 0 ; i < values .length ; i ++ ) {
688
722
double d = vecs [i ].at (row );
689
- values [i ] = data (d , i , means , mults );
723
+ // values[i] = data(d, i, means, mults);
724
+ values [i ] = data (d , i , means , mults , vecs [i ].cardinality ());
690
725
}
691
726
}
692
727
693
728
private static void data (double [] values , Chunk [] chks , int row , double [] means , double [] mults ) {
694
729
for ( int i = 0 ; i < values .length ; i ++ ) {
695
730
double d = chks [i ].at0 (row );
696
- values [i ] = data (d , i , means , mults );
731
+ // values[i] = data(d, i, means, mults);
732
+ values [i ] = data (d , i , means , mults , chks [i ]._vec .cardinality ());
697
733
}
698
734
}
699
735
@@ -709,4 +745,20 @@ private static double data(double d, int i, double[] means, double[] mults) {
709
745
}
710
746
return d ;
711
747
}
748
+
749
+ private static double data (double d , int i , double [] means , double [] mults , int cardinality ) {
750
+ if (cardinality == -1 ) {
751
+ if ( Double .isNaN (d ) )
752
+ d = means [i ];
753
+ if ( mults != null ) {
754
+ d -= means [i ];
755
+ d *= mults [i ];
756
+ }
757
+ } else {
758
+ // TODO: If NaN, then replace with majority class?
759
+ if (Double .isNaN (d ))
760
+ d = Math .min (Math .round (means [i ]), cardinality -1 );
761
+ }
762
+ return d ;
763
+ }
712
764
}
0 commit comments