diff --git a/h2o-scala/build.sbt b/h2o-scala/build.sbt index c5894342e4..0bc5081a49 100644 --- a/h2o-scala/build.sbt +++ b/h2o-scala/build.sbt @@ -33,6 +33,8 @@ libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "1.1.0" libraryDependencies += "joda-time" % "joda-time" % "2.3" +libraryDependencies += "com.github.wookietreiber" %% "scala-chart" % "latest.integration" + libraryDependencies <+= scalaVersion { v => "org.scala-lang" % "scala-library" % v } libraryDependencies <+= scalaVersion { v => "org.scala-lang" % "scala-compiler" % v } @@ -64,7 +66,7 @@ unmanagedClasspath in Compile += h2oSources.value unmanagedClasspath in Runtime += h2oClasses.value -// Setup run +// Setup run // - Fork in run fork in run := true diff --git a/h2o-scala/src/main/java/water/api/dsl/util/Reservoir.java b/h2o-scala/src/main/java/water/api/dsl/util/Reservoir.java new file mode 100644 index 0000000000..4712b36f2d --- /dev/null +++ b/h2o-scala/src/main/java/water/api/dsl/util/Reservoir.java @@ -0,0 +1,122 @@ +package water.api.dsl.util; + +import water.AutoBuffer; +import water.Iced; + +import java.util.PriorityQueue; + +/** + * Simple utility to perform reservoir sampling in order to make huge data small enough to + * bring in to local memory and display/etc. + * I did this in Java instead of Scala, because I wasn't sure about Serialization/transient issues in Scala. + * + * Not Entirely Sure about threading/synchronization in this model + */ +public class Reservoir extends Iced { + + private transient final PriorityQueue minHeap = new PriorityQueue(); + public final int reservoirSize; + + //Needed For Serialization? + public Reservoir() { + this.reservoirSize = -1; + } + + public Reservoir(int reservoirSize) { + this.reservoirSize = reservoirSize; + } + + @Override + public water.AutoBuffer write(AutoBuffer bb) { + int[] order = new int[minHeap.size()]; + double[] vals = new double[minHeap.size()]; //TODO: Till we figure out nulls + + int x=0; + for(SampleItem item : minHeap) { + order[x] = item.getRandomOrder(); + vals[x] = item.getValue();//==null?0:item.getValue(); //TODO: till we figure out nulls + x++; + } + bb.put4(reservoirSize); + bb.putA4(order); + bb.putA8d(vals); + return( bb ); + } + + @SuppressWarnings("unchecked") + @Override + public water.api.dsl.util.Reservoir read(AutoBuffer bb) { + int rSize = bb.get4(); + int[] order = bb.getA4(); + double[] vals = bb.getA8d(); + Reservoir reservoir = new Reservoir(rSize); + + for( int x=0;x than the lest item in the heap.. then swap them out. + if( item.getRandomOrder() > head.getRandomOrder() ) { + minHeap.poll(); + minHeap.add(item); + } + } + } + } + +// synchronized public void merge( Reservoir other ) { +// if( other != null ) { +// for( SampleItem item : other.minHeap) { +// add(item); +// } +// } +// } + synchronized public Reservoir merge( Reservoir other ) { + Reservoir result = new Reservoir(this.reservoirSize); + for(SampleItem item : this.minHeap ) { + result.add(item); + } + if( other != null ) { + for(SampleItem item : other.minHeap ) { + result.add(item); + } + } + return( result ); + } + + + synchronized public double[] getValues() { + double[] result = new double[minHeap.size()]; + + int x=0; + for(SampleItem item : minHeap) { + result[x] = item.getValue(); + x++; + } + return( result ); + } + + synchronized public int getNumValues(){ + return( minHeap.size() ); + } + +// public static void main(String[] args) { +// Reservoir reservoir = new Reservoir(10); +// for( int x=0;x<15;x++) { +// reservoir.add((double)x); +// } +// } +} diff --git a/h2o-scala/src/main/java/water/api/dsl/util/SampleItem.java b/h2o-scala/src/main/java/water/api/dsl/util/SampleItem.java new file mode 100644 index 0000000000..d80c3b6df5 --- /dev/null +++ b/h2o-scala/src/main/java/water/api/dsl/util/SampleItem.java @@ -0,0 +1,51 @@ +package water.api.dsl.util; + +import water.Iced; + +import java.util.Random; + +/** + * Wraps an value, with a random int (Used in reservoir sampling) + */ +public class SampleItem extends Iced implements Comparable { + private static final Random rand = new Random(); + public final int randomOrder; + public final double value; + + public SampleItem(int randomOrder, double value) { + this.randomOrder = randomOrder; + this.value = value; + } + + public SampleItem(double value) { + this.value = value; + this.randomOrder = rand.nextInt(); + } + + public int getRandomOrder() { + return randomOrder; + } + + public double getValue() { + return value; + } + + @Override + public int compareTo(SampleItem that) { + if(this == that) { + return(0); + }else if(that == null ) { + return(1); + }else { + return( java.lang.Integer.compare(this.randomOrder,that.randomOrder) ); + } + } + + @Override + public String toString() { + return "SampleItemStub{" + + "randomOrder=" + randomOrder + + ", value=" + value + + '}'; + } +} diff --git a/h2o-scala/src/main/scala/water/api/dsl/util/Sampler.scala b/h2o-scala/src/main/scala/water/api/dsl/util/Sampler.scala new file mode 100644 index 0000000000..edc6933c5c --- /dev/null +++ b/h2o-scala/src/main/scala/water/api/dsl/util/Sampler.scala @@ -0,0 +1,38 @@ +package water.api.dsl.util + +import water.api.dsl.T_T_Collect + +/** + * Resevoir Sampler to extract a column from a DataFrame and bring it to the local context. + * + * + * Some Scala REPL foo: + * + * val g = parse ("/Users/jerdavis/temp/export1.gz") + * import water.api.dsl.util._ + * val smallData = g(100) collect ( new Reservoir(1000), new Sampler() ) + * + * + * import scalax.chart._ + * import scalax.chart.Charting._ + * + * val data = (1 to smallData.getNumValues) zip smallData.getValues + * val dataset = data.toXYSeriesCollection("some points") + * val chart = XYLineChart(dataset) + * chart.show + * + * Not entirely sure about threading / synchronization in this model + */ +class Sampler extends T_T_Collect[Reservoir,scala.Double] { + + override def apply(acc:Reservoir, rhs:Array[scala.Double]):Reservoir = { + for( x <- rhs ) { + acc.add(x) + } + acc + } + + override def reduce(lhs:Reservoir,rhs:Reservoir) = { + lhs.merge(rhs) + } +} \ No newline at end of file diff --git a/src/main/java/hex/LinearRegression.java b/src/main/java/hex/LinearRegression.java index 665f39e71b..d44e169508 100644 --- a/src/main/java/hex/LinearRegression.java +++ b/src/main/java/hex/LinearRegression.java @@ -5,6 +5,9 @@ public abstract class LinearRegression { static public JsonObject run( ValueArray ary, int colA, int colB ) { + return( exec(ary,colA,colB).toJson() ); + } + static public LRResult exec( ValueArray ary, int colA, int colB ) { // Pass 1: compute sums & sums-of-squares long start = System.currentTimeMillis(); CalcSumsTask lr1 = new CalcSumsTask(); @@ -42,23 +45,22 @@ static public JsonObject run( ValueArray ary, int colA, int colB ) { double svar1 = svar / lr2._XXbar; double svar0 = svar/n + lr2._Xbar*lr2._Xbar*svar1; - JsonObject res = new JsonObject(); - res.addProperty("Key", ary._key.toString()); - res.addProperty("ColA", ary._cols[colA]._name); - res.addProperty("ColB", ary._cols[colB]._name); - res.addProperty("Pass1Msecs", pass1 - start); - res.addProperty("Pass2Msecs", pass2-pass1); - res.addProperty("Pass3Msecs", pass3-pass2); - res.addProperty("Rows", n); - res.addProperty("Beta0", lr3._beta0); - res.addProperty("Beta1", lr3._beta1); - res.addProperty("RSquared", R2); - res.addProperty("Beta0StdErr", Math.sqrt(svar0)); - res.addProperty("Beta1StdErr", Math.sqrt(svar1)); - res.addProperty("SSTO", lr2._YYbar); - res.addProperty("SSE", lr3._rss); - res.addProperty("SSR", lr3._ssr); - return res; + LRResult result = new LRResult(ary._key.toString(), + ary._cols[colA]._name, + ary._cols[colB]._name, + pass1-start, + pass2-pass1, + pass3-pass2, + n, + lr3._beta0, + lr3._beta1, + R2, + Math.sqrt(svar0), + Math.sqrt(svar1), + lr2._YYbar, + lr3._rss, + lr3._ssr ); + return( result ); } public static class CalcSumsTask extends MRTask { @@ -220,4 +222,94 @@ public void reduce( DRemoteTask rt ) { _ssr += lr3._ssr; } } + public static class LRResult{ + public final String key; + public final String colA; + public final String colB; + public final long pass1Msecs; + public final long pass2Msecs; + public final long pass3Msecs; + public final long rows; + public final double beta0; + public final double beta1; + public final double rSquared; + public final double beta0StdErr; + public final double beta1StdErr; + public final double ssto; + public final double sse; + public final double ssr; + + public LRResult (String key, + String colA, + String colB, + long pass1Msecs, + long pass2Msecs, + long pass3Msecs, + long rows, + double beta0, + double beta1, + double rSquared, + double beta0StdErr, + double beta1StdErr, + double ssto, + double sse, + double ssr) { + this.key = key; + this.colA = colA; + this.colB = colB; + this.pass1Msecs = pass1Msecs; + this.pass2Msecs = pass2Msecs; + this.pass3Msecs = pass3Msecs; + this.rows = rows; + this.beta0 = beta0; + this.beta1 = beta1; + this.rSquared = rSquared; + this.beta0StdErr = beta0StdErr; + this.beta1StdErr = beta1StdErr; + this.ssto = ssto; + this.sse = sse; + this.ssr = ssr; + } + + @Override + public String toString () { + return "LRResult{" + + "key='" + key + '\'' + + ", colA='" + colA + '\'' + + ", colB='" + colB + '\'' + + ", pass1Msecs=" + pass1Msecs + + ", pass2Msecs=" + pass2Msecs + + ", pass3Msecs=" + pass3Msecs + + ", rows=" + rows + + ", beta0=" + beta0 + + ", beta1=" + beta1 + + ", rSquared=" + rSquared + + ", beta0StdErr=" + beta0StdErr + + ", beta1StdErr=" + beta1StdErr + + ", ssto=" + ssto + + ", sse=" + sse + + ", ssr=" + ssr + + '}'; + } + + public JsonObject toJson() { + JsonObject res = new JsonObject(); + res.addProperty("Key", key); + res.addProperty("ColA", colA); + res.addProperty("ColB", colB); + res.addProperty("Pass1Msecs", pass1Msecs); + res.addProperty("Pass2Msecs", pass2Msecs); + res.addProperty("Pass3Msecs", pass3Msecs); + res.addProperty("Rows", rows); + res.addProperty("Beta0", beta0); + res.addProperty("Beta1", beta1); + res.addProperty("RSquared", rSquared); + res.addProperty("Beta0StdErr", beta0StdErr); + res.addProperty("Beta1StdErr", beta1StdErr); + res.addProperty("SSTO", ssto); + res.addProperty("SSE", sse); + res.addProperty("SSR", ssr); + return( res ); + } + } }