Skip to content

Commit

Permalink
Unified DRF API with deep learning API.
Browse files Browse the repository at this point in the history
- Scala DRF API now supports the same style of
passing arguments via anynymous function filling
algorithm object
  • Loading branch information
mmalohlava committed Apr 3, 2014
1 parent 8ee8ee6 commit dbf8d41
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
12 changes: 4 additions & 8 deletions h2o-scala/src/main/scala/water/api/dsl/DslLegos.scala
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,11 @@ trait T_H2O_Env[K<:HexKey, VT <: DFrame] { // Operating with only given represen
def shutdown() = H2O.CLOUD.shutdown()

// DRF API call
def drf(f: VT, r: VT, ntrees: Int = 50, classification:Boolean = false): DRF.DRFModel = {
def drf(ftrain: VT, ftest:VT, x:Seq[Int], y:Int, params: (DRF)=>DRF ): DRF.DRFModel = {
val drf:DRF = new DRF()
val response = r.frame().vecs()(0)
response.rollupStats()
drf.source = new Frame(f.frame().names() ++ Array("response"), f.frame.vecs()++Array(response))
drf.response = response
drf.classification = classification
drf.ntrees = ntrees;
drf.invoke()
drf.source = ftrain(x++Seq(y)).frame()
drf.response = ftrain.frame().vec(y)
params(drf).invoke()
return UKV.get(drf.dest())
}

Expand Down
28 changes: 17 additions & 11 deletions h2o-scala/src/main/scala/water/api/dsl/examples/Examples.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import hex.deeplearning.DeepLearning
import java.io.File
import java.util.Random
import hex.deeplearning.DeepLearning.ClassSamplingMethod
import hex.drf.DRF

/**
* Simple example project.
Expand Down Expand Up @@ -41,21 +42,26 @@ object Examples {

/** Call DRF, make a model, predict on a train data, compute MSE. */
def example2() = {
banner(2, "Call DRF API and make a forest for cars dataset")
banner(2, """The example calls DRF API and produces a forest of regression trees for cars dataset.
The response column represents number of cylinders for each car included in train data. Example then
makes a prediction over train data and compute MSE of prediction."""")

import water.api.dsl.H2ODsl._
val f = parse(ffind("smalldata/cars.csv"))
val source = f(1) ++ f(3 to 7)
val response = f(2)

val params = (p:DRF) => { import p._
ntrees = 10
classification = false
p
}
// build a model
val model = drf(source, response, 10, false) // doing regression
val model = drf(f, null, 3 to 7, 2, params) // doing regression
println("The DRF model is: \n" + model)
// make a prediction
val predict:DFrame = model.score(source.frame())
val predict:DFrame = model.score(f.frame())

println("Prediction on train data: \n" + predict)

// compute squared errors
// compute mean squared errors
val serr = (response - predict)^2
println("Errors per row: " + serr)
// make a sum
Expand Down Expand Up @@ -123,10 +129,10 @@ object Examples {
def userMain(args: Array[String]):Unit = {
H2O.main(args)
try {
example1()
//example1()
example2()
example3()
example4()
//example3()
//example4()
} catch {
case t:Throwable => t.printStackTrace() // Simple debug
} finally {
Expand All @@ -136,7 +142,7 @@ object Examples {

// Print simple banner
private def banner(id:Int, desc: String) = {
println("\n==== Example #"+id+" ====\n * "+desc )
println("\n==== Example #"+id+" ====\n "+desc )
println( "====================\n")
}

Expand Down

0 comments on commit dbf8d41

Please sign in to comment.