Skip to content

Commit 2fe96a8

Browse files
author
cypof
committed
KMeans2 tests and fixes
1 parent 59a9dd1 commit 2fe96a8

20 files changed

+359
-253
lines changed

experiments/src/main/java/hex/MnistDist16x.java

+22-19
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,21 @@
44

55
import water.*;
66
import water.deploy.Cloud;
7-
import water.fvec.NFSFileVec;
8-
import water.fvec.ParseDataset2;
7+
import water.fvec.*;
98
import water.util.Log;
9+
import water.util.Utils;
1010

1111
public class MnistDist16x {
1212
public static void main(String[] args) throws Exception {
1313
Cloud cloud = new Cloud();
1414
for( int i = LOW; i < LOW + LEN; i++ )
15-
cloud._publicIPs.add("192.168.1." + (161 + i));
16-
cloud._clientRSyncIncludes.add("../libs/jdk");
17-
cloud._clientRSyncIncludes.add("smalldata");
18-
cloud._clientRSyncIncludes.add("experiments/target");
19-
cloud._fannedRSyncIncludes.add("jdk");
20-
cloud._fannedRSyncIncludes.add("smalldata");
15+
cloud.publicIPs.add("192.168.1." + (161 + i));
16+
cloud.clientRSyncIncludes.add("smalldata");
17+
cloud.clientRSyncIncludes.add("experiments/target");
18+
cloud.fannedRSyncIncludes.add("smalldata");
19+
cloud.jdk = "../libs/jdk";
2120
String java = "-ea -Xmx120G -Dh2o.debug";
22-
String node = "-mainClass " + MnistDist16x.UserCode.class.getName() + " -beta";
21+
String node = "-mainClass " + UserCode.class.getName() + " -beta";
2322
cloud.start(java.split(" "), node.split(" "));
2423
}
2524

@@ -29,22 +28,25 @@ public static class UserCode {
2928
public static void userMain(String[] args) throws Exception {
3029
H2O.main(args);
3130

32-
Log.info("blah: " + System.getProperty("java.home"));
31+
Log.info("java: " + System.getProperty("java.home"));
3332

3433
TestUtil.stall_till_cloudsize(LEN);
3534
//Sample08_DeepNeuralNet_EC2.run();
3635
//Sample07_NeuralNet_Mnist8m.run();
37-
//Sample07_NeuralNet_Mnist.run();
36+
//new Sample07_NeuralNetLowLevel().run();
3837

39-
File f = new File("smalldata/mnist/train.csv.gz");
40-
Key dest = Key.make("train.hex");
41-
Key fkey = NFSFileVec.make(f);
42-
ParseDataset2.parse(dest, new Key[] { fkey });
38+
Key fkey = NFSFileVec.make(new File("/home/0xdiag/home-0xdiag-datasets/mnist/mnist8m.csv"));
39+
Key mnist8m = Key.make("mnist8m.csv");
40+
Frame frame = ParseDataset2.parse(mnist8m, new Key[] { fkey });
4341

44-
f = new File("smalldata/mnist/test.csv.gz");
45-
dest = Key.make("test.hex");
46-
fkey = NFSFileVec.make(f);
47-
ParseDataset2.parse(dest, new Key[] { fkey });
42+
Vec response = frame.vecs()[0];
43+
Vec[] vecs = Utils.remove(frame.vecs(), 0);
44+
Key train = Key.make("train.hex");
45+
UKV.put(train, new Frame(frame.names(), Utils.append(vecs, response)));
46+
47+
Key dest = Key.make("test.hex");
48+
Key ftest = NFSFileVec.make(new File("smalldata/mnist/test.csv.gz"));
49+
ParseDataset2.parse(dest, new Key[] { ftest });
4850

4951
// Basic visualization of images and weights
5052
// JFrame frame = new JFrame("H2O");
@@ -54,6 +56,7 @@ public static void userMain(String[] args) throws Exception {
5456
// frame.pack();
5557
// frame.setLocationRelativeTo(null);
5658
// frame.setVisible(true);
59+
Log.info("Ready");
5760
}
5861
}
5962
}

src/main/java/hex/GridSearch.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public class GridSearch extends Job {
3838

3939
@Override public Response redirect() {
4040
String n = GridSearchProgress.class.getSimpleName();
41-
return new Response(Response.Status.redirect, this, -1, -1, n, "job", job_key, "dst_key", destination_key);
41+
return new Response(Response.Status.redirect, this, -1, -1, n, "job_key", job_key, "destination_key", destination_key);
4242
}
4343

4444
public static class GridSearchProgress extends Progress2 {
@@ -105,7 +105,7 @@ public static class GridSearchProgress extends Progress2 {
105105
sb.append("<td>").append(speed).append("</td>");
106106

107107
String link = info._job.destination_key.toString();
108-
if( info._job.start_time != 0 ) {
108+
if( info._job.start_time != 0 && DKV.get(info._job.destination_key) != null ) {
109109
if( info._model instanceof GBMModel )
110110
link = GBMModelView.link(link, info._job.destination_key);
111111
else if( info._model instanceof NeuralNetModel )

src/main/java/hex/KMeans.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ private void run(KMeansModel res, ValueArray va, int k, Initialization init, int
102102
sampler._probability = k * 3; // Over-sampling
103103
sampler._seed = res._randSeed;
104104
sampler.invoke(va._key);
105-
clusters = DRemoteTask.merge(clusters, sampler._clust2);
105+
clusters = Utils.append(clusters, sampler._clust2);
106106

107107
if( cancelled() ) {
108108
remove();
@@ -210,7 +210,7 @@ public static class Sampler extends MRTask {
210210

211211
@Override public void reduce(DRemoteTask rt) {
212212
Sampler task = (Sampler) rt;
213-
_clust2 = _clust2 == null ? task._clust2 : merge(_clust2, task._clust2);
213+
_clust2 = _clust2 == null ? task._clust2 : Utils.append(_clust2, task._clust2);
214214
}
215215
}
216216

src/main/java/hex/KMeans2.java

+58-39
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import water.*;
99
import water.Job.ColumnsJob;
1010
import water.api.*;
11-
import water.fvec.*;
11+
import water.fvec.Chunk;
12+
import water.fvec.Vec;
1213
import water.util.Utils;
1314

1415
/**
@@ -49,14 +50,13 @@ public KMeans2() {
4950
for( int i = 0; i < cols.length; i++ )
5051
names[i] = source._names[cols[i]];
5152
Vec[] vecs = selectVecs(source);
52-
Frame frame = new Frame(names, vecs);
5353
// Fill-in response based on K
54-
Vec response = frame.anyVec().makeZero();
55-
response._domain = new String[k];
56-
for( int i = 0; i < response._domain.length; i++ )
57-
response._domain[i] = "Cluster " + i;
58-
frame.add("response", response);
59-
KMeans2Model model = new KMeans2Model(destination_key, sourceKey, frame);
54+
String[] domain = new String[k];
55+
for( int i = 0; i < domain.length; i++ )
56+
domain[i] = "Cluster " + i;
57+
String[] namesResp = Utils.append(names, "response");
58+
String[][] domaiResp = (String[][]) Utils.append(source.domains(), (Object) domain);
59+
KMeans2Model model = new KMeans2Model(destination_key, sourceKey, namesResp, domaiResp);
6060

6161
double[] subs = null, muls = null;
6262
if( normalize ) {
@@ -101,7 +101,7 @@ public KMeans2() {
101101
sampler._subs = subs;
102102
sampler._muls = muls;
103103
sampler.doAll(vecs);
104-
clusters = DRemoteTask.merge(clusters, sampler._sampled);
104+
clusters = Utils.append(clusters, sampler._sampled);
105105

106106
if( cancelled() )
107107
return;
@@ -120,15 +120,11 @@ public KMeans2() {
120120
task._subs = subs;
121121
task._muls = muls;
122122
task.doAll(vecs);
123-
for( int cluster = 0; cluster < clusters.length; cluster++ ) {
124-
if( task._counts[cluster] > 0 ) {
125-
for( int vec = 0; vec < vecs.length; vec++ ) {
126-
double value = task._sums[cluster][vec] / task._counts[cluster];
127-
clusters[cluster][vec] = value;
128-
}
129-
}
130-
}
131-
model.clusters = normalize ? denormalize(clusters, vecs) : clusters;
123+
model.clusters = normalize ? denormalize(task._means, vecs) : task._means;
124+
for( int clu = 0; clu < task._sigms.length; clu++ )
125+
for( int col = 0; col < task._sigms[clu].length; col++ )
126+
task._sigms[clu][col] = task._sigms[clu][col] / (task._rows[clu] - 1);
127+
model.variances = task._sigms;
132128
model.error = task._sqr;
133129
model.iterations++;
134130
UKV.put(destination_key, model);
@@ -141,7 +137,9 @@ public KMeans2() {
141137

142138
@Override protected Response redirect() {
143139
String n = KMeans2Progress.class.getSimpleName();
144-
return new Response(Response.Status.redirect, this, -1, -1, n, "job", job_key, "dst_key", destination_key);
140+
return new Response(Response.Status.redirect, this, -1, -1, n, //
141+
"job_key", job_key, //
142+
"destination_key", destination_key);
145143
}
146144

147145
public static class KMeans2Progress extends Progress2 {
@@ -204,7 +202,7 @@ public static class KMeans2Model extends Model implements Progress {
204202
@API(help = "Sum of min square distances")
205203
public double error;
206204

207-
@API(help = "Whether data should be normalized")
205+
@API(help = "Whether data was normalized")
208206
public boolean normalized;
209207

210208
@API(help = "Maximum number of iterations before stopping")
@@ -213,11 +211,14 @@ public static class KMeans2Model extends Model implements Progress {
213211
@API(help = "Iterations the algorithm ran")
214212
public int iterations;
215213

214+
@API(help = "In-cluster variances")
215+
public double[][] variances;
216+
216217
private transient double[] _subs, _muls; // Normalization
217218
private transient double[][] _normClust;
218219

219-
public KMeans2Model(Key selfKey, Key dataKey, Frame fr) {
220-
super(selfKey, dataKey, fr);
220+
public KMeans2Model(Key selfKey, Key dataKey, String names[], String domains[][]) {
221+
super(selfKey, dataKey, names, domains);
221222
}
222223

223224
@Override public float progress() {
@@ -301,47 +302,65 @@ public static class Sampler extends MRTask2<Sampler> {
301302
}
302303

303304
@Override public void reduce(Sampler other) {
304-
_sampled = DRemoteTask.merge(_sampled, other._sampled);
305+
_sampled = Utils.append(_sampled, other._sampled);
305306
}
306307
}
307308

308309
public static class Lloyds extends MRTask2<Lloyds> {
309310
// IN
310311
double[][] _clusters;
311-
double[] _subs, _muls; // Normalization
312+
double[] _subs, _muls; // Normalization
312313

313314
// OUT
314-
double[][] _sums; // Sum of (normalized) features in each cluster
315-
int[] _counts; // Count of rows in cluster
316-
double _sqr; // Total sqr distance
315+
double[][] _means, _sigms; // Means and sigma for each cluster
316+
long[] _rows; // Rows per cluster
317+
double _sqr; // Total sqr distance
317318

318319
@Override public void map(Chunk[] cs) {
319-
double[] values = new double[_clusters[0].length];
320-
_sums = new double[_clusters.length][values.length];
321-
_counts = new int[_clusters.length];
322-
ClusterDist cd = new ClusterDist();
320+
_means = new double[_clusters.length][_clusters[0].length];
321+
_sigms = new double[_clusters.length][_clusters[0].length];
322+
_rows = new long[_clusters.length];
323323

324324
// Find closest cluster for each row
325+
double[] values = new double[_clusters[0].length];
326+
ClusterDist cd = new ClusterDist();
327+
int[] clusters = new int[cs[0]._len];
325328
for( int row = 0; row < cs[0]._len; row++ ) {
326329
data(values, cs, row, _subs, _muls);
327330
closest(_clusters, values, cd);
328-
int cluster = cd._cluster;
331+
int clu = clusters[row] = cd._cluster;
329332
_sqr += cd._dist;
330-
if( cluster == -1 )
333+
if( clu == -1 )
331334
continue; // Ignore broken row
332335

333336
// Add values and increment counter for chosen cluster
334-
Utils.add(_sums[cluster], values);
335-
_counts[cluster]++;
337+
for( int col = 0; col < values.length; col++ )
338+
_means[clu][col] += values[col];
339+
_rows[clu]++;
340+
}
341+
for( int clu = 0; clu < _means.length; clu++ )
342+
for( int col = 0; col < _means[clu].length; col++ )
343+
_means[clu][col] /= _rows[clu];
344+
// Second pass for in-cluster variances
345+
for( int row = 0; row < cs[0]._len; row++ ) {
346+
int clu = clusters[row];
347+
if( clu == -1 )
348+
continue;
349+
data(values, cs, row, _subs, _muls);
350+
for( int col = 0; col < values.length; col++ ) {
351+
double delta = values[col] - _means[clu][col];
352+
_sigms[clu][col] += delta * delta;
353+
}
336354
}
337355
_clusters = null;
338356
_subs = _muls = null;
339357
}
340358

341-
@Override public void reduce(Lloyds other) {
342-
Utils.add(_sums, other._sums);
343-
Utils.add(_counts, other._counts);
344-
_sqr += other._sqr;
359+
@Override public void reduce(Lloyds mr) {
360+
for( int clu = 0; clu < _means.length; clu++ )
361+
Layer.Stats.reduce(_means[clu], _sigms[clu], _rows[clu], mr._means[clu], mr._sigms[clu], mr._rows[clu]);
362+
Utils.add(_rows, mr._rows);
363+
_sqr += mr._sqr;
345364
}
346365
}
347366

src/main/java/hex/KMeansGrid.java

-83
This file was deleted.

0 commit comments

Comments
 (0)