diff --git a/h2o-scala/src/main/scala/water/api/dsl/DslLegos.scala b/h2o-scala/src/main/scala/water/api/dsl/DslLegos.scala index 82eb0886f2..66da04888f 100644 --- a/h2o-scala/src/main/scala/water/api/dsl/DslLegos.scala +++ b/h2o-scala/src/main/scala/water/api/dsl/DslLegos.scala @@ -281,15 +281,11 @@ trait T_H2O_Env[K<:HexKey, VT <: DFrame] { // Operating with only given represen def shutdown() = H2O.CLOUD.shutdown() // DRF API call - def drf(f: VT, r: VT, ntrees: Int = 50, classification:Boolean = false): DRF.DRFModel = { + def drf(ftrain: VT, ftest:VT, x:Seq[Int], y:Int, params: (DRF)=>DRF ): DRF.DRFModel = { val drf:DRF = new DRF() - val response = r.frame().vecs()(0) - response.rollupStats() - drf.source = new Frame(f.frame().names() ++ Array("response"), f.frame.vecs()++Array(response)) - drf.response = response - drf.classification = classification - drf.ntrees = ntrees; - drf.invoke() + drf.source = ftrain(x++Seq(y)).frame() + drf.response = ftrain.frame().vec(y) + params(drf).invoke() return UKV.get(drf.dest()) } diff --git a/h2o-scala/src/main/scala/water/api/dsl/examples/Examples.scala b/h2o-scala/src/main/scala/water/api/dsl/examples/Examples.scala index 42a1a6be49..c038e09f9b 100644 --- a/h2o-scala/src/main/scala/water/api/dsl/examples/Examples.scala +++ b/h2o-scala/src/main/scala/water/api/dsl/examples/Examples.scala @@ -6,6 +6,7 @@ import hex.deeplearning.DeepLearning import java.io.File import java.util.Random import hex.deeplearning.DeepLearning.ClassSamplingMethod +import hex.drf.DRF /** * Simple example project. @@ -41,21 +42,26 @@ object Examples { /** Call DRF, make a model, predict on a train data, compute MSE. */ def example2() = { - banner(2, "Call DRF API and make a forest for cars dataset") + banner(2, """The example calls DRF API and produces a forest of regression trees for cars dataset. +The response column represents number of cylinders for each car included in train data. Example then +makes a prediction over train data and compute MSE of prediction."""") + import water.api.dsl.H2ODsl._ val f = parse(ffind("smalldata/cars.csv")) - val source = f(1) ++ f(3 to 7) - val response = f(2) - + val params = (p:DRF) => { import p._ + ntrees = 10 + classification = false + p + } // build a model - val model = drf(source, response, 10, false) // doing regression + val model = drf(f, null, 3 to 7, 2, params) // doing regression println("The DRF model is: \n" + model) // make a prediction - val predict:DFrame = model.score(source.frame()) + val predict:DFrame = model.score(f.frame()) println("Prediction on train data: \n" + predict) - // compute squared errors + // compute mean squared errors val serr = (response - predict)^2 println("Errors per row: " + serr) // make a sum @@ -123,10 +129,10 @@ object Examples { def userMain(args: Array[String]):Unit = { H2O.main(args) try { - example1() + //example1() example2() - example3() - example4() + //example3() + //example4() } catch { case t:Throwable => t.printStackTrace() // Simple debug } finally { @@ -136,7 +142,7 @@ object Examples { // Print simple banner private def banner(id:Int, desc: String) = { - println("\n==== Example #"+id+" ====\n * "+desc ) + println("\n==== Example #"+id+" ====\n "+desc ) println( "====================\n") }