Skip to content

Commit 2d67ed1

Browse files
committed
Merge branch 'master' of github.com:jgustave/h2o into jgustave-master
Conflicts: h2o-scala/build.sbt
2 parents 6069c1d + a62bd2c commit 2d67ed1

File tree

5 files changed

+323
-18
lines changed

5 files changed

+323
-18
lines changed

h2o-scala/build.sbt

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "1.1.0"
3333

3434
libraryDependencies += "joda-time" % "joda-time" % "2.3"
3535

36+
libraryDependencies += "com.github.wookietreiber" %% "scala-chart" % "latest.integration"
37+
3638
libraryDependencies <+= scalaVersion { v => "org.scala-lang" % "scala-library" % v }
3739

3840
libraryDependencies <+= scalaVersion { v => "org.scala-lang" % "scala-compiler" % v }
@@ -64,7 +66,7 @@ unmanagedClasspath in Compile += h2oSources.value
6466

6567
unmanagedClasspath in Runtime += h2oClasses.value
6668

67-
// Setup run
69+
// Setup run
6870
// - Fork in run
6971
fork in run := true
7072

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package water.api.dsl.util;
2+
3+
import water.AutoBuffer;
4+
import water.Iced;
5+
6+
import java.util.PriorityQueue;
7+
8+
/**
9+
* Simple utility to perform reservoir sampling in order to make huge data small enough to
10+
* bring in to local memory and display/etc.
11+
* I did this in Java instead of Scala, because I wasn't sure about Serialization/transient issues in Scala.
12+
*
13+
* Not Entirely Sure about threading/synchronization in this model
14+
*/
15+
public class Reservoir extends Iced {
16+
17+
private transient final PriorityQueue<SampleItem> minHeap = new PriorityQueue<SampleItem>();
18+
public final int reservoirSize;
19+
20+
//Needed For Serialization?
21+
public Reservoir() {
22+
this.reservoirSize = -1;
23+
}
24+
25+
public Reservoir(int reservoirSize) {
26+
this.reservoirSize = reservoirSize;
27+
}
28+
29+
@Override
30+
public water.AutoBuffer write(AutoBuffer bb) {
31+
int[] order = new int[minHeap.size()];
32+
double[] vals = new double[minHeap.size()]; //TODO: Till we figure out nulls
33+
34+
int x=0;
35+
for(SampleItem item : minHeap) {
36+
order[x] = item.getRandomOrder();
37+
vals[x] = item.getValue();//==null?0:item.getValue(); //TODO: till we figure out nulls
38+
x++;
39+
}
40+
bb.put4(reservoirSize);
41+
bb.putA4(order);
42+
bb.putA8d(vals);
43+
return( bb );
44+
}
45+
46+
@SuppressWarnings("unchecked")
47+
@Override
48+
public water.api.dsl.util.Reservoir read(AutoBuffer bb) {
49+
int rSize = bb.get4();
50+
int[] order = bb.getA4();
51+
double[] vals = bb.getA8d();
52+
Reservoir reservoir = new Reservoir(rSize);
53+
54+
for( int x=0;x<order.length;x++) {
55+
reservoir.minHeap.add(new SampleItem(order[x], vals[x]));
56+
}
57+
58+
return(reservoir);
59+
}
60+
61+
public void add(double item) {
62+
add( new SampleItem(item) );
63+
}
64+
65+
synchronized public void add(SampleItem item) {
66+
if( item != null ) {
67+
if( minHeap.size() < reservoirSize) {
68+
minHeap.add(item);
69+
}else {
70+
SampleItem head = minHeap.peek();
71+
//If Item is > than the lest item in the heap.. then swap them out.
72+
if( item.getRandomOrder() > head.getRandomOrder() ) {
73+
minHeap.poll();
74+
minHeap.add(item);
75+
}
76+
}
77+
}
78+
}
79+
80+
// synchronized public void merge( Reservoir other ) {
81+
// if( other != null ) {
82+
// for( SampleItem item : other.minHeap) {
83+
// add(item);
84+
// }
85+
// }
86+
// }
87+
synchronized public Reservoir merge( Reservoir other ) {
88+
Reservoir result = new Reservoir(this.reservoirSize);
89+
for(SampleItem item : this.minHeap ) {
90+
result.add(item);
91+
}
92+
if( other != null ) {
93+
for(SampleItem item : other.minHeap ) {
94+
result.add(item);
95+
}
96+
}
97+
return( result );
98+
}
99+
100+
101+
synchronized public double[] getValues() {
102+
double[] result = new double[minHeap.size()];
103+
104+
int x=0;
105+
for(SampleItem item : minHeap) {
106+
result[x] = item.getValue();
107+
x++;
108+
}
109+
return( result );
110+
}
111+
112+
synchronized public int getNumValues(){
113+
return( minHeap.size() );
114+
}
115+
116+
// public static void main(String[] args) {
117+
// Reservoir reservoir = new Reservoir(10);
118+
// for( int x=0;x<15;x++) {
119+
// reservoir.add((double)x);
120+
// }
121+
// }
122+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package water.api.dsl.util;
2+
3+
import water.Iced;
4+
5+
import java.util.Random;
6+
7+
/**
8+
* Wraps a value with a random int (used in reservoir sampling).
9+
*/
10+
public class SampleItem extends Iced implements Comparable<SampleItem> {
11+
private static final Random rand = new Random();
12+
public final int randomOrder;
13+
public final double value;
14+
15+
public SampleItem(int randomOrder, double value) {
16+
this.randomOrder = randomOrder;
17+
this.value = value;
18+
}
19+
20+
public SampleItem(double value) {
21+
this.value = value;
22+
this.randomOrder = rand.nextInt();
23+
}
24+
25+
public int getRandomOrder() {
26+
return randomOrder;
27+
}
28+
29+
public double getValue() {
30+
return value;
31+
}
32+
33+
@Override
34+
public int compareTo(SampleItem that) {
35+
if(this == that) {
36+
return(0);
37+
}else if(that == null ) {
38+
return(1);
39+
}else {
40+
return( java.lang.Integer.compare(this.randomOrder,that.randomOrder) );
41+
}
42+
}
43+
44+
@Override
45+
public String toString() {
46+
return "SampleItemStub{" +
47+
"randomOrder=" + randomOrder +
48+
", value=" + value +
49+
'}';
50+
}
51+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package water.api.dsl.util
2+
3+
import water.api.dsl.T_T_Collect
4+
5+
/**
6+
* Reservoir sampler that extracts a column from a DataFrame and brings it into the local context.
7+
*
8+
*
9+
* Some Scala REPL foo:
10+
*
11+
* val g = parse ("/Users/jerdavis/temp/export1.gz")
12+
* import water.api.dsl.util._
13+
* val smallData = g(100) collect ( new Reservoir(1000), new Sampler() )
14+
*
15+
*
16+
* import scalax.chart._
17+
* import scalax.chart.Charting._
18+
*
19+
* val data = (1 to smallData.getNumValues) zip smallData.getValues
20+
* val dataset = data.toXYSeriesCollection("some points")
21+
* val chart = XYLineChart(dataset)
22+
* chart.show
23+
*
24+
* Not entirely sure about threading / synchronization in this model
25+
*/
26+
class Sampler extends T_T_Collect[Reservoir,scala.Double] {

  /**
   * Offers every element of `rhs` to the accumulator, in array order,
   * and hands back that same accumulator.
   */
  override def apply(acc:Reservoir, rhs:Array[scala.Double]):Reservoir = {
    rhs.foreach { value => acc.add(value) }
    acc
  }

  /** Combines two partial reservoirs by delegating to `Reservoir.merge` on the left one. */
  override def reduce(lhs:Reservoir,rhs:Reservoir):Reservoir = lhs.merge(rhs)
}

src/main/java/hex/LinearRegression.java

+109-17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
public abstract class LinearRegression {
66

77
static public JsonObject run( ValueArray ary, int colA, int colB ) {
8+
return( exec(ary,colA,colB).toJson() );
9+
}
10+
static public LRResult exec( ValueArray ary, int colA, int colB ) {
811
// Pass 1: compute sums & sums-of-squares
912
long start = System.currentTimeMillis();
1013
CalcSumsTask lr1 = new CalcSumsTask();
@@ -42,23 +45,22 @@ static public JsonObject run( ValueArray ary, int colA, int colB ) {
4245
double svar1 = svar / lr2._XXbar;
4346
double svar0 = svar/n + lr2._Xbar*lr2._Xbar*svar1;
4447

45-
JsonObject res = new JsonObject();
46-
res.addProperty("Key", ary._key.toString());
47-
res.addProperty("ColA", ary._cols[colA]._name);
48-
res.addProperty("ColB", ary._cols[colB]._name);
49-
res.addProperty("Pass1Msecs", pass1 - start);
50-
res.addProperty("Pass2Msecs", pass2-pass1);
51-
res.addProperty("Pass3Msecs", pass3-pass2);
52-
res.addProperty("Rows", n);
53-
res.addProperty("Beta0", lr3._beta0);
54-
res.addProperty("Beta1", lr3._beta1);
55-
res.addProperty("RSquared", R2);
56-
res.addProperty("Beta0StdErr", Math.sqrt(svar0));
57-
res.addProperty("Beta1StdErr", Math.sqrt(svar1));
58-
res.addProperty("SSTO", lr2._YYbar);
59-
res.addProperty("SSE", lr3._rss);
60-
res.addProperty("SSR", lr3._ssr);
61-
return res;
48+
LRResult result = new LRResult(ary._key.toString(),
49+
ary._cols[colA]._name,
50+
ary._cols[colB]._name,
51+
pass1-start,
52+
pass2-pass1,
53+
pass3-pass2,
54+
n,
55+
lr3._beta0,
56+
lr3._beta1,
57+
R2,
58+
Math.sqrt(svar0),
59+
Math.sqrt(svar1),
60+
lr2._YYbar,
61+
lr3._rss,
62+
lr3._ssr );
63+
return( result );
6264
}
6365

6466
public static class CalcSumsTask extends MRTask {
@@ -220,4 +222,94 @@ public void reduce( DRemoteTask rt ) {
220222
_ssr += lr3._ssr;
221223
}
222224
}
225+
public static class LRResult{
226+
public final String key;
227+
public final String colA;
228+
public final String colB;
229+
public final long pass1Msecs;
230+
public final long pass2Msecs;
231+
public final long pass3Msecs;
232+
public final long rows;
233+
public final double beta0;
234+
public final double beta1;
235+
public final double rSquared;
236+
public final double beta0StdErr;
237+
public final double beta1StdErr;
238+
public final double ssto;
239+
public final double sse;
240+
public final double ssr;
241+
242+
public LRResult (String key,
243+
String colA,
244+
String colB,
245+
long pass1Msecs,
246+
long pass2Msecs,
247+
long pass3Msecs,
248+
long rows,
249+
double beta0,
250+
double beta1,
251+
double rSquared,
252+
double beta0StdErr,
253+
double beta1StdErr,
254+
double ssto,
255+
double sse,
256+
double ssr) {
257+
this.key = key;
258+
this.colA = colA;
259+
this.colB = colB;
260+
this.pass1Msecs = pass1Msecs;
261+
this.pass2Msecs = pass2Msecs;
262+
this.pass3Msecs = pass3Msecs;
263+
this.rows = rows;
264+
this.beta0 = beta0;
265+
this.beta1 = beta1;
266+
this.rSquared = rSquared;
267+
this.beta0StdErr = beta0StdErr;
268+
this.beta1StdErr = beta1StdErr;
269+
this.ssto = ssto;
270+
this.sse = sse;
271+
this.ssr = ssr;
272+
}
273+
274+
@Override
275+
public String toString () {
276+
return "LRResult{" +
277+
"key='" + key + '\'' +
278+
", colA='" + colA + '\'' +
279+
", colB='" + colB + '\'' +
280+
", pass1Msecs=" + pass1Msecs +
281+
", pass2Msecs=" + pass2Msecs +
282+
", pass3Msecs=" + pass3Msecs +
283+
", rows=" + rows +
284+
", beta0=" + beta0 +
285+
", beta1=" + beta1 +
286+
", rSquared=" + rSquared +
287+
", beta0StdErr=" + beta0StdErr +
288+
", beta1StdErr=" + beta1StdErr +
289+
", ssto=" + ssto +
290+
", sse=" + sse +
291+
", ssr=" + ssr +
292+
'}';
293+
}
294+
295+
public JsonObject toJson() {
296+
JsonObject res = new JsonObject();
297+
res.addProperty("Key", key);
298+
res.addProperty("ColA", colA);
299+
res.addProperty("ColB", colB);
300+
res.addProperty("Pass1Msecs", pass1Msecs);
301+
res.addProperty("Pass2Msecs", pass2Msecs);
302+
res.addProperty("Pass3Msecs", pass3Msecs);
303+
res.addProperty("Rows", rows);
304+
res.addProperty("Beta0", beta0);
305+
res.addProperty("Beta1", beta1);
306+
res.addProperty("RSquared", rSquared);
307+
res.addProperty("Beta0StdErr", beta0StdErr);
308+
res.addProperty("Beta1StdErr", beta1StdErr);
309+
res.addProperty("SSTO", ssto);
310+
res.addProperty("SSE", sse);
311+
res.addProperty("SSR", ssr);
312+
return( res );
313+
}
314+
}
223315
}

0 commit comments

Comments
 (0)