Skip to content

Commit

Permalink
Merge branch 'master' of github.com:jgustave/h2o into jgustave-master
Browse files Browse the repository at this point in the history
Conflicts:
	h2o-scala/build.sbt
  • Loading branch information
mmalohlava committed Feb 26, 2014
2 parents 6069c1d + a62bd2c commit 2d67ed1
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 18 deletions.
4 changes: 3 additions & 1 deletion h2o-scala/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "1.1.0"

libraryDependencies += "joda-time" % "joda-time" % "2.3"

// NOTE(review): "latest.integration" is a dynamic revision, so the scala-chart version that gets
// resolved can change from one build to the next - consider pinning an explicit version.
libraryDependencies += "com.github.wookietreiber" %% "scala-chart" % "latest.integration"

// scala-library and scala-compiler are added at whatever version scalaVersion is set to.
libraryDependencies <+= scalaVersion { v => "org.scala-lang" % "scala-library" % v }

libraryDependencies <+= scalaVersion { v => "org.scala-lang" % "scala-compiler" % v }
Expand Down Expand Up @@ -64,7 +66,7 @@ unmanagedClasspath in Compile += h2oSources.value

// Append the value of h2oClasses to the unmanaged classpath of the Runtime configuration.
unmanagedClasspath in Runtime += h2oClasses.value

// Setup run
// - Fork in run: the `run` task is executed in a separate JVM instead of inside sbt's own JVM
fork in run := true

Expand Down
122 changes: 122 additions & 0 deletions h2o-scala/src/main/java/water/api/dsl/util/Reservoir.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package water.api.dsl.util;

import water.AutoBuffer;
import water.Iced;

import java.util.PriorityQueue;

/**
 * Simple utility that performs reservoir sampling so that huge data can be shrunk to something
 * small enough to bring into local memory and display.
 *
 * <p>Every value is wrapped in a {@link SampleItem} carrying an integer order tag. The reservoir
 * keeps at most {@code reservoirSize} items; once it is full, a newcomer replaces the item with the
 * smallest tag only if the newcomer's tag is larger.
 *
 * <p>Written in Java rather than Scala because of uncertainty about serialization/transient
 * handling in Scala.
 *
 * <p>Threading: {@link #add(SampleItem)}, {@link #merge(Reservoir)}, {@link #getValues()} and
 * {@link #getNumValues()} synchronize on this instance. {@link #write(AutoBuffer)} and
 * {@link #read(AutoBuffer)} do not, and {@code merge} iterates the other reservoir without holding
 * its lock. NOTE(review): whether this is sufficient for the surrounding framework is unconfirmed.
 */
public class Reservoir extends Iced {

  /** Min-heap ordered by {@link SampleItem#compareTo}, so the head is the item with the smallest tag. */
  private transient final PriorityQueue<SampleItem> minHeap = new PriorityQueue<SampleItem>();

  /** Maximum number of items retained; values {@code <= 0} mean nothing is ever retained. */
  public final int reservoirSize;

  /**
   * No-arg constructor, presumably required by the serialization framework (TODO confirm).
   * Sets {@code reservoirSize} to -1, so an instance built this way never retains any item.
   */
  public Reservoir() {
    this.reservoirSize = -1;
  }

  /**
   * @param reservoirSize maximum number of items this reservoir will hold
   */
  public Reservoir(int reservoirSize) {
    this.reservoirSize = reservoirSize;
  }

  /**
   * Writes {@code reservoirSize}, then the array of order tags, then the array of values.
   * The two arrays are parallel: entry {@code i} of each describes the same item.
   *
   * @param bb buffer to write into
   * @return the same buffer, for chaining
   */
  @Override
  public water.AutoBuffer write(AutoBuffer bb) {
    int[] order = new int[minHeap.size()];
    double[] vals = new double[minHeap.size()]; //TODO: Till we figure out nulls

    int x = 0;
    for( SampleItem item : minHeap ) {
      order[x] = item.getRandomOrder();
      vals[x] = item.getValue(); //TODO: till we figure out nulls
      x++;
    }
    bb.put4(reservoirSize);
    bb.putA4(order);
    bb.putA8d(vals);
    return( bb );
  }

  /**
   * Reads the layout produced by {@link #write(AutoBuffer)} and rebuilds the heap from the stored
   * (tag, value) pairs. Note that the data is loaded into a newly created reservoir, which is
   * returned; this instance is left untouched.
   *
   * @param bb buffer to read from
   * @return a new reservoir holding the deserialized items
   */
  @SuppressWarnings("unchecked")
  @Override
  public water.api.dsl.util.Reservoir read(AutoBuffer bb) {
    int rSize = bb.get4();
    int[] order = bb.getA4();
    double[] vals = bb.getA8d();
    Reservoir reservoir = new Reservoir(rSize);

    for( int x = 0; x < order.length; x++ ) {
      reservoir.minHeap.add(new SampleItem(order[x], vals[x]));
    }

    return( reservoir );
  }

  /**
   * Wraps the value in a new {@link SampleItem} (which assigns it a random order tag) and offers
   * it to the reservoir.
   *
   * @param item value to sample
   */
  public void add(double item) {
    add( new SampleItem(item) );
  }

  /**
   * Offers an already-tagged item to the reservoir. {@code null} is ignored. While below capacity
   * the item is always kept; at capacity it replaces the current minimum only if its tag is larger.
   *
   * @param item item to offer, may be {@code null}
   */
  synchronized public void add(SampleItem item) {
    if( item == null ) {
      return;
    }
    if( minHeap.size() < reservoirSize ) {
      minHeap.add(item);
      return;
    }
    SampleItem head = minHeap.peek();
    // peek() returns null only for an empty heap, which at this point means reservoirSize <= 0
    // (e.g. the -1 set by the no-arg constructor). Such a reservoir retains nothing, so drop the
    // item instead of dereferencing null.
    if( head == null ) {
      return;
    }
    // If the item is greater than the least item in the heap, swap them out.
    if( item.getRandomOrder() > head.getRandomOrder() ) {
      minHeap.poll();
      minHeap.add(item);
    }
  }

  /**
   * Builds a new reservoir of this reservoir's capacity and offers it every item of this reservoir
   * followed by every item of {@code other}. Neither input is modified.
   *
   * @param other reservoir to combine with; {@code null} is treated as empty
   * @return the combined reservoir
   */
  synchronized public Reservoir merge( Reservoir other ) {
    Reservoir result = new Reservoir(this.reservoirSize);
    for( SampleItem item : this.minHeap ) {
      result.add(item);
    }
    if( other != null ) {
      for( SampleItem item : other.minHeap ) {
        result.add(item);
      }
    }
    return( result );
  }

  /**
   * @return the retained values, in the heap's iteration order (not sorted)
   */
  synchronized public double[] getValues() {
    double[] result = new double[minHeap.size()];

    int x = 0;
    for( SampleItem item : minHeap ) {
      result[x] = item.getValue();
      x++;
    }
    return( result );
  }

  /**
   * @return number of items currently retained
   */
  synchronized public int getNumValues() {
    return( minHeap.size() );
  }
}
51 changes: 51 additions & 0 deletions h2o-scala/src/main/java/water/api/dsl/util/SampleItem.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package water.api.dsl.util;

import water.Iced;

import java.util.Random;

/**
 * Wraps a value together with an int order tag (used in reservoir sampling).
 * Items are ordered by their tag only; the wrapped value plays no part in comparisons.
 */
public class SampleItem extends Iced implements Comparable<SampleItem> {
  /** Shared source of order tags for items created without an explicit tag. */
  private static final Random rand = new Random();
  public final int randomOrder;
  public final double value;

  /**
   * Creates an item with an explicit order tag.
   *
   * @param randomOrder order tag to use
   * @param value       wrapped value
   */
  public SampleItem(int randomOrder, double value) {
    this.randomOrder = randomOrder;
    this.value = value;
  }

  /**
   * Creates an item whose order tag is drawn from the shared {@link Random}.
   *
   * @param value wrapped value
   */
  public SampleItem(double value) {
    this.value = value;
    this.randomOrder = rand.nextInt();
  }

  /** @return the order tag */
  public int getRandomOrder() {
    return randomOrder;
  }

  /** @return the wrapped value */
  public double getValue() {
    return value;
  }

  /**
   * Compares by order tag. An item compares equal to itself and greater than {@code null}.
   *
   * @param that item to compare against, may be {@code null}
   * @return negative, zero or positive as this tag is less than, equal to or greater than that tag
   */
  @Override
  public int compareTo(SampleItem that) {
    if( this == that ) {
      return( 0 );
    } else if( that == null ) {
      return( 1 );
    } else {
      return( java.lang.Integer.compare(this.randomOrder, that.randomOrder) );
    }
  }

  /** @return a debug representation labelled with this class's name */
  @Override
  public String toString() {
    // The label previously read "SampleItemStub", which does not match this class.
    return "SampleItem{" +
           "randomOrder=" + randomOrder +
           ", value=" + value +
           '}';
  }
}
38 changes: 38 additions & 0 deletions h2o-scala/src/main/scala/water/api/dsl/util/Sampler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package water.api.dsl.util

import water.api.dsl.T_T_Collect

/**
* Reservoir sampler that extracts a column from a DataFrame and brings it into the local context.
*
*
* Some Scala REPL foo:
*
* val g = parse ("/Users/jerdavis/temp/export1.gz")
* import water.api.dsl.util._
* val smallData = g(100) collect ( new Reservoir(1000), new Sampler() )
*
*
* import scalax.chart._
* import scalax.chart.Charting._
*
* val data = (1 to smallData.getNumValues) zip smallData.getValues
* val dataset = data.toXYSeriesCollection("some points")
* val chart = XYLineChart(dataset)
* chart.show
*
* Not entirely sure about threading / synchronization in this model
*/
class Sampler extends T_T_Collect[Reservoir,scala.Double] {

  /** Offers every element of `rhs`, in order, to the accumulator and returns that same accumulator. */
  override def apply(acc:Reservoir, rhs:Array[scala.Double]):Reservoir = {
    var idx = 0
    while( idx < rhs.length ) {
      acc.add(rhs(idx))
      idx += 1
    }
    acc
  }

  /** Combines two partial results by delegating to `lhs.merge(rhs)`. */
  override def reduce(lhs:Reservoir,rhs:Reservoir) = lhs.merge(rhs)
}
126 changes: 109 additions & 17 deletions src/main/java/hex/LinearRegression.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
public abstract class LinearRegression {

/**
 * Runs the regression through {@code exec} and returns its result converted to JSON.
 *
 * @param ary  data to regress over
 * @param colA index of the first column
 * @param colB index of the second column
 * @return the result of {@code exec(ary, colA, colB)} rendered via {@code toJson()}
 */
public static JsonObject run( ValueArray ary, int colA, int colB ) {
  final LRResult outcome = exec(ary, colA, colB);
  return outcome.toJson();
}
static public LRResult exec( ValueArray ary, int colA, int colB ) {
// Pass 1: compute sums & sums-of-squares
long start = System.currentTimeMillis();
CalcSumsTask lr1 = new CalcSumsTask();
Expand Down Expand Up @@ -42,23 +45,22 @@ static public JsonObject run( ValueArray ary, int colA, int colB ) {
double svar1 = svar / lr2._XXbar;
double svar0 = svar/n + lr2._Xbar*lr2._Xbar*svar1;

JsonObject res = new JsonObject();
res.addProperty("Key", ary._key.toString());
res.addProperty("ColA", ary._cols[colA]._name);
res.addProperty("ColB", ary._cols[colB]._name);
res.addProperty("Pass1Msecs", pass1 - start);
res.addProperty("Pass2Msecs", pass2-pass1);
res.addProperty("Pass3Msecs", pass3-pass2);
res.addProperty("Rows", n);
res.addProperty("Beta0", lr3._beta0);
res.addProperty("Beta1", lr3._beta1);
res.addProperty("RSquared", R2);
res.addProperty("Beta0StdErr", Math.sqrt(svar0));
res.addProperty("Beta1StdErr", Math.sqrt(svar1));
res.addProperty("SSTO", lr2._YYbar);
res.addProperty("SSE", lr3._rss);
res.addProperty("SSR", lr3._ssr);
return res;
LRResult result = new LRResult(ary._key.toString(),
ary._cols[colA]._name,
ary._cols[colB]._name,
pass1-start,
pass2-pass1,
pass3-pass2,
n,
lr3._beta0,
lr3._beta1,
R2,
Math.sqrt(svar0),
Math.sqrt(svar1),
lr2._YYbar,
lr3._rss,
lr3._ssr );
return( result );
}

public static class CalcSumsTask extends MRTask {
Expand Down Expand Up @@ -220,4 +222,94 @@ public void reduce( DRemoteTask rt ) {
_ssr += lr3._ssr;
}
}
/**
 * Immutable holder for the outputs of a linear-regression run. Every field is set once by the
 * constructor; {@link #toString()} and {@link #toJson()} render the same fifteen values.
 */
public static class LRResult {
  public final String key;
  public final String colA;
  public final String colB;
  public final long pass1Msecs;
  public final long pass2Msecs;
  public final long pass3Msecs;
  public final long rows;
  public final double beta0;
  public final double beta1;
  public final double rSquared;
  public final double beta0StdErr;
  public final double beta1StdErr;
  public final double ssto;
  public final double sse;
  public final double ssr;

  /** Stores each argument in the field of the same name. */
  public LRResult( String key, String colA, String colB,
                   long pass1Msecs, long pass2Msecs, long pass3Msecs,
                   long rows,
                   double beta0, double beta1,
                   double rSquared,
                   double beta0StdErr, double beta1StdErr,
                   double ssto, double sse, double ssr ) {
    this.key = key;
    this.colA = colA;
    this.colB = colB;
    this.pass1Msecs = pass1Msecs;
    this.pass2Msecs = pass2Msecs;
    this.pass3Msecs = pass3Msecs;
    this.rows = rows;
    this.beta0 = beta0;
    this.beta1 = beta1;
    this.rSquared = rSquared;
    this.beta0StdErr = beta0StdErr;
    this.beta1StdErr = beta1StdErr;
    this.ssto = ssto;
    this.sse = sse;
    this.ssr = ssr;
  }

  /** @return all fields as {@code LRResult{name=value, ...}}, with the three strings single-quoted */
  @Override
  public String toString() {
    final StringBuilder text = new StringBuilder("LRResult{");
    text.append("key='").append(key).append('\'');
    text.append(", colA='").append(colA).append('\'');
    text.append(", colB='").append(colB).append('\'');
    text.append(", pass1Msecs=").append(pass1Msecs);
    text.append(", pass2Msecs=").append(pass2Msecs);
    text.append(", pass3Msecs=").append(pass3Msecs);
    text.append(", rows=").append(rows);
    text.append(", beta0=").append(beta0);
    text.append(", beta1=").append(beta1);
    text.append(", rSquared=").append(rSquared);
    text.append(", beta0StdErr=").append(beta0StdErr);
    text.append(", beta1StdErr=").append(beta1StdErr);
    text.append(", ssto=").append(ssto);
    text.append(", sse=").append(sse);
    text.append(", ssr=").append(ssr);
    return text.append('}').toString();
  }

  /** @return a new JSON object with one property per field, added in declaration order */
  public JsonObject toJson() {
    final JsonObject json = new JsonObject();
    // Identification of the input
    json.addProperty("Key", key);
    json.addProperty("ColA", colA);
    json.addProperty("ColB", colB);
    // Per-pass timings
    json.addProperty("Pass1Msecs", pass1Msecs);
    json.addProperty("Pass2Msecs", pass2Msecs);
    json.addProperty("Pass3Msecs", pass3Msecs);
    // Fit results
    json.addProperty("Rows", rows);
    json.addProperty("Beta0", beta0);
    json.addProperty("Beta1", beta1);
    json.addProperty("RSquared", rSquared);
    json.addProperty("Beta0StdErr", beta0StdErr);
    json.addProperty("Beta1StdErr", beta1StdErr);
    json.addProperty("SSTO", ssto);
    json.addProperty("SSE", sse);
    json.addProperty("SSR", ssr);
    return json;
  }
}
}

0 comments on commit 2d67ed1

Please sign in to comment.