Skip to content

Commit

Permalink
Job tracking & cancel, first shot
Browse files Browse the repository at this point in the history
  • Loading branch information
cypof committed Feb 15, 2013
1 parent a89123a commit 8138483
Show file tree
Hide file tree
Showing 35 changed files with 443 additions and 183 deletions.
6 changes: 3 additions & 3 deletions .classpath
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src/main/java"/>
<classpathentry kind="src" path="src/test/java"/>
<classpathentry kind="src" path="src/main/java"/>
<classpathentry kind="src" path="src/test/java"/>
<classpathentry kind="lib" path="lib/apache/commons-configuration-1.6.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-lang-2.4.jar" sourcepath="lib/apache/commons-lang-2.4-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-logging-1.1.1.jar"/>
Expand Down Expand Up @@ -29,5 +29,5 @@
<classpathentry kind="lib" path="lib/jets3t/commons-httpclient-3.1.jar"/>
<classpathentry kind="lib" path="lib/jets3t/jets3t-0.6.1.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-codec-1.3.jar"/>
<classpathentry kind="output" path="build/classes"/>
<classpathentry kind="output" path="target/classes"/>
</classpath>
3 changes: 0 additions & 3 deletions .settings/org.eclipse.core.resources.prefs

This file was deleted.

4 changes: 3 additions & 1 deletion src/main/java/H2OInit/Boot.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import water.util.Utils;


/** Initializer class for H2O.
*
Expand Down Expand Up @@ -49,7 +51,7 @@ private byte[] getMD5(InputStream is) throws IOException {
} catch( NoSuchAlgorithmException e ) {
throw new RuntimeException(e);
} finally {
try { is.close(); } catch( IOException e ) { }
Utils.close(is);
}
}

Expand Down
105 changes: 67 additions & 38 deletions src/main/java/hex/KMeans.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package hex;

import java.util.*;
import com.google.gson.*;

import water.*;
import water.Jobs.Job;
import water.Jobs.Progress;

import com.google.gson.*;

/**
* Scalable K-Means++ (KMeans||)<br>
Expand All @@ -15,88 +19,102 @@ public abstract class KMeans {

public static class KMeansModel extends Model {
public static final String KEY_PREFIX = "__KMeansModel_";
public double[][] _clusters; // The cluster centers, normalized according to _va
public int _iteration;
public int _k;
public double[][] _clusters; // The cluster centers, normalized according to _va
public int _k;

// Empty constructor for deserialization
public KMeansModel() {}
public KMeansModel() {
}

KMeansModel(Key selfKey, int cols[], Key dataKey, int k) {
// Unlike other models, k-means is a discovery-only procedure and does
// not require a response-column to train. This also means the clusters
// not require a response-column to train. This also means the clusters
// are not classes (although, if a class/response is associated with each
// row we could count the number of each class in each cluster).
super(selfKey,cols,dataKey);
super(selfKey, cols, dataKey);
_k = k;
}

// Accept only columns with a defined mean. Used during the Model.<init> call.
@Override public boolean columnFilter(ValueArray.Column C) {
// Accept only columns with a defined mean. Used during the Model.<init> call.
@Override
public boolean columnFilter(ValueArray.Column C) {
return !Double.isNaN(C._mean);
}

public JsonObject toJson() {
JsonObject res = new JsonObject();
res.addProperty("iterations", _iteration);
JsonArray ary = new JsonArray();
for( double[] dd : clusters() ) {
JsonArray ary2 = new JsonArray();
for( double d : dd )
ary2.add(new JsonPrimitive(d));
ary.add(ary2);
}
res.add("clusters",ary);
res.add("clusters", ary);
return res;
}

// Return the clusters, denormalized
public double[][] clusters() {
double dd[][] = _clusters.clone();
for( double ds[] : dd )
for( int i=0; i<ds.length; i++ ) {
for( int i = 0; i < ds.length; i++ ) {
ValueArray.Column C = _va._cols[i];
double d = ds[i];
if( C._sigma != 0.0 && !Double.isNaN(C._sigma) ) d *= C._sigma;
if( C._sigma != 0.0 && !Double.isNaN(C._sigma) )
d *= C._sigma;
d += C._mean;
ds[i] = d;
}
return dd;
}

/** Single row scoring, on properly ordered data. Will return NaN if any
* data element contains a NaN. Returns the cluster-number, which is
* mostly an internal value. */
protected double score0( double[] data ) {
for( int i=0; i<data.length; i++ ) { // Normalize the data before scoring
/**
* Single row scoring, on properly ordered data. Will return NaN if any data element contains a NaN. Returns the
* cluster-number, which is mostly an internal value.
*/
protected double score0(double[] data) {
for( int i = 0; i < data.length; i++ ) { // Normalize the data before scoring
ValueArray.Column C = _va._cols[i];
double d = data[i] - C._mean;
if( C._sigma != 0.0 && !Double.isNaN(C._sigma) ) d /= C._sigma;
if( C._sigma != 0.0 && !Double.isNaN(C._sigma) )
d /= C._sigma;
data[i] = d;
}
return closest(_clusters,data);
return closest(_clusters, data);
}

/** Single row scoring, on a compatible ValueArray (when pushed through the mapping) */
protected double score0( ValueArray data, int row, int[] mapping ) {
protected double score0(ValueArray data, int row, int[] mapping) {
throw H2O.unimpl();
}

/** Bulk scoring API, on a compatible ValueArray (when pushed through the mapping) */
protected double score0( ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping ) {
protected double score0(ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping) {
throw H2O.unimpl();
}
}

// Return a normalized value. If missing, return the mean (which we know
// Return a normalized value. If missing, return the mean (which we know
// exists because we filtered out columns with no mean).
private static double datad( ValueArray va, AutoBuffer bits, int row, ValueArray.Column C) {
if( va.isNA(bits,row,C) ) return C._mean;
private static double datad(ValueArray va, AutoBuffer bits, int row, ValueArray.Column C) {
if( va.isNA(bits, row, C) )
return C._mean;
double d = va.datad(bits, row, C) - C._mean;
return (C._sigma == 0.0 || Double.isNaN(C._sigma)) ? d : d/C._sigma;
return (C._sigma == 0.0 || Double.isNaN(C._sigma)) ? d : d / C._sigma;
}

static public void run(Key dest, ValueArray va, int k, double epsilon, int... cols) {
final KMeansModel res = new KMeansModel(dest,cols,va._key,k);
public static void run(Key dest, ValueArray va, int k, double epsilon, int... cols) {
Job job = startJob(dest, va, k, epsilon, cols);
run(job, va, k, epsilon, cols);
}

public static Job startJob(Key dest, ValueArray va, int k, double epsilon, int... cols) {
return Jobs.start("KMeans K: " + k + ", Cols: " + cols.length, dest);
}

public static void run(Job job, ValueArray va, int k, double epsilon, int... cols) {
KMeansModel res = new KMeansModel(job._dest, cols, va._key, k);
// Updated column mapping selection after removing various junk columns
cols = res.columnMapping(va.colNames());

Expand All @@ -105,9 +123,11 @@ static public void run(Key dest, ValueArray va, int k, double epsilon, int... co
clusters[0] = new double[cols.length];
AutoBuffer bits = va.getChunk(0);
for( int c = 0; c < cols.length; c++ )
clusters[0][c] = datad(va,bits, 0, va._cols[cols[c]]);
clusters[0][c] = datad(va, bits, 0, va._cols[cols[c]]);

for( int i = 0; i < 5; i++ ) {
int iteration = 0;
float expected = 20;
while( iteration < 5 ) {
// Sum squares distances to clusters
Sqr sqr = new Sqr();
sqr._arykey = va._key;
Expand All @@ -125,12 +145,14 @@ static public void run(Key dest, ValueArray va, int k, double epsilon, int... co
sampler.invoke(va._key);
clusters = DRemoteTask.merge(clusters, sampler._newClusters);

res._iteration++;
UKV.put(dest, res);
if( Jobs.cancelled(job._key) ) {
Jobs.remove(job._key);
return;
}
UKV.put(job._progress, new Progress(++iteration / expected));
}

clusters = recluster(clusters, k);
res._clusters = clusters; // sharing is caring....

// Iterate until no cluster mean moves more than epsilon
boolean moved = true;
Expand All @@ -150,9 +172,16 @@ static public void run(Key dest, ValueArray va, int k, double epsilon, int... co
clusters[cluster][column] = value;
}
}
res._iteration++;

float progress = Math.min(++iteration / expected, 1f);
UKV.put(job._progress, new Progress(progress));
res._clusters = clusters;
UKV.put(job._dest, res);
if( Jobs.cancelled(job._key) )
break;
}
UKV.put(dest, res);

Jobs.remove(job._key);
}

public static class Sqr extends MRTask {
Expand All @@ -176,7 +205,7 @@ public void map(Key key) {

for( int row = 0; row < rows; row++ ) {
for( int column = 0; column < _cols.length; column++ )
values[column] = datad(va,bits, row, va._cols[_cols[column]]);
values[column] = datad(va, bits, row, va._cols[_cols[column]]);

_sqr += minSqr(_clusters, _clusters.length, values);
}
Expand Down Expand Up @@ -217,7 +246,7 @@ public void map(Key key) {

for( int row = 0; row < rows; row++ ) {
for( int column = 0; column < _cols.length; column++ )
values[column] = datad(va,bits, row, va._cols[_cols[column]]);
values[column] = datad(va, bits, row, va._cols[_cols[column]]);

double sqr = minSqr(_clusters, _clusters.length, values);

Expand Down Expand Up @@ -274,7 +303,7 @@ public void map(Key key) {
// Find closest cluster for each row
for( int row = 0; row < rows; row++ ) {
for( int column = 0; column < _cols.length; column++ )
values[column] = datad(va,bits, row, va._cols[_cols[column]]);
values[column] = datad(va, bits, row, va._cols[_cols[column]]);

int cluster = closest(_clusters, values);

Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/Confusion.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import water.*;
import water.ValueArray.Column;
import water.util.Utils;

import com.google.common.primitives.Ints;

Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/DABuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import water.*;
import water.util.Utils;

class DABuilder {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/DRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import water.*;
import water.ValueArray.Column;
import water.util.Utils;
import water.Timer;

/** Distributed RandomForest */
Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/Data.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.*;

import water.MemoryManager;
import water.util.Utils;

public class Data implements Iterable<Row> {
/** Use stratified sampling */
Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/DataAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import water.*;
import water.ValueArray.Column;
import water.util.Utils;

/**A DataAdapter maintains an encoding of the original data. Every raw value (of type float)
* is represented by a short value. When the number of unique raw value is larger that binLimit,
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/hex/rf/EntropyStatistic.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import java.util.Random;

import water.util.Utils;


/**The entropy formula is the classic Shannon entropy gain, which is:
*
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/hex/rf/GiniStatistic.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import java.util.Random;

import water.util.Utils;

/** Computes the gini split statistics.
*
* The Gini fitness is calculated as a probability that the element will be
Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/RandomForest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import water.*;
import water.Timer;
import water.util.TestUtil;
import water.util.Utils;

/**
* A RandomForest can be used for growing or validation. The former starts with a known target number of trees,
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/hex/rf/Statistic.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import java.util.*;

import water.util.Utils;

/** Keeps track of the column distributions and analyzes the column splits in the
* end producing the single split that will be used for the node. */
abstract class Statistic {
Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/StratifiedDABuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jsr166y.RecursiveAction;
import water.*;
import water.ValueArray.Column;
import water.util.Utils;

public class StratifiedDABuilder extends DABuilder {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/rf/Tree.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jsr166y.RecursiveTask;
import water.*;
import water.Timer;
import water.util.Utils;

public class Tree extends CountedCompleter {
static public enum StatType { ENTROPY, GINI };
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/hex/rng/MersenneTwisterRNG.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// ============================================================================
package hex.rng;

import hex.rf.Utils;

import java.util.Random;
import java.util.concurrent.locks.ReentrantLock;

import water.util.Utils;

/**
* <p>
* Random number generator based on the <a
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/water/H2O.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import water.hdfs.HdfsLoader;
import water.nbhm.NonBlockingHashMap;
import water.store.s3.PersistS3;
import water.util.Utils;

import com.google.common.base.Strings;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -248,7 +249,7 @@ private static InetAddress guessInetAddress(List<InetAddress> ips) {
} catch( Throwable t ) {
return null;
} finally {
try { if(s != null) s.close(); } catch( Throwable t ) { }
Utils.close(s);
}
}

Expand Down
Loading

0 comments on commit 8138483

Please sign in to comment.