Skip to content

Commit

Permalink
Add Scala API for deeplearning.
Browse files Browse the repository at this point in the history
- simple deeplearning call
- add an example of call usage
  • Loading branch information
mmalohlava committed Apr 3, 2014
1 parent 7d003e0 commit 8ee8ee6
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 29 deletions.
20 changes: 16 additions & 4 deletions h2o-scala/src/main/scala/water/api/dsl/DslLegos.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import water.fvec.ParseDataset2
import water.Job
import hex.drf.DRF
import water.fvec.NewChunk
import hex.Quantiles
import water.api.QuantilesPage
import hex.deeplearning.{DeepLearning, DeepLearningModel}

trait TRef {}

Expand Down Expand Up @@ -244,10 +244,12 @@ abstract trait T_MR[T <: DFrame] {
trait T_H2O_Env[K<:HexKey, VT <: DFrame] { // Operating with only given representation of key

// Parse a dataset
def parse(s:String):DFrame = parse(s, s+".hex")
def parse(s:String, destKey:String):DFrame = {
// Parse the file at path `s`: wraps the path in a File and delegates to the File-based overload.
def parse(s:String):DFrame = parse(new File(s))
// Parse `file`, using the file's name (without its directory) plus ".hex" as the destination key name.
def parse(file:File):DFrame = parse(file, file.getName+".hex")
// Parse the file at path `s` into the destination key named `destKey`.
def parse(s:String, destKey:String):DFrame = parse(new File(s), destKey)
def parse(file:File, destKey:String):DFrame = {
val dest: Key = Key.make(destKey)
val fkey = NFSFileVec.make(new File(s))
val fkey:Key = NFSFileVec.make(file)
val f = ParseDataset2.parse(dest, Array(fkey))
UKV.remove(fkey)
// Wrap the frame
Expand Down Expand Up @@ -298,6 +300,16 @@ trait T_H2O_Env[K<:HexKey, VT <: DFrame] { // Operating with only given represen
qp.invoke()
return qp.result
}

/** Configures and runs a deep learning model build on the given training frame.
  *
  * @param ftrain training frame
  * @param ftest  validation frame, or `null` to build without one
  * @param x      indices of the predictor columns in `ftrain`
  * @param y      index of the response column in `ftrain`
  * @param params callback that fills the remaining parameters on the prepared builder
  *               and returns the builder that should be run
  * @return the value fetched with `UKV.get` from the destination key of the builder that was run
  */
def deeplearning(ftrain: VT, ftest: VT, x:Seq[Int], y:Int, params: (DeepLearning)=>DeepLearning):DeepLearningModel = {
  val dl = new DeepLearning
  // The source frame is restricted to the predictor columns plus the response column
  dl.source = ftrain(x++Seq(y)).frame()
  dl.response = ftrain.frame().vec(y)
  dl.validation = if (ftest != null) ftest.frame() else null
  // Fill parameters, then run exactly the builder the callback handed back
  val job = params(dl)
  job.invoke()
  // Look the result up under the key of the builder that actually ran, so the
  // lookup stays correct even if the callback returned a different instance
  UKV.get(job.dest())
}
}

/** Trait representing provided global environment in R-like style.
Expand Down
104 changes: 79 additions & 25 deletions h2o-scala/src/main/scala/water/api/dsl/examples/Examples.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,21 @@ package water.api.dsl.examples

import water.{Boot, H2O, Iced}
import water.api.dsl.{Row, T_T_Collect, DFrame}
import hex.deeplearning.DeepLearning
import java.io.File
import java.util.Random
import hex.deeplearning.DeepLearning.ClassSamplingMethod

/**
* Simple example project.
*
* Ideas:
* GroupBy for H2O.
* Histogram.
* Matrix product.
* - GroupBy for H2O.
* - Histogram.
* - Matrix product.
*/
object Examples {

// Entry point, called in the context of the main classloader.
// Passes the Examples class and the command-line arguments on to Boot.
def main(args: Array[String]):Unit = {
  val entryClass = classOf[Examples]
  Boot.main(entryClass, args)
}

// Call in the context of H2O classloader.
// Starts H2O with the given arguments, runs the examples in order and then shuts the DSL down.
def userMain(args: Array[String]):Unit = {
  H2O.main(args)
  try {
    example1()
    example2()
    example3()
  } finally {
    // Runs on every exit path, so a failing example does not skip the shutdown call
    water.api.dsl.H2ODsl.shutdown()
  }
}

// Prints a banner for example number `id`: a title block with the quoted description,
// followed by a short closing rule.
def banner(id:Int, desc: String) = {
  val title = "\n==== Example #" + id + " ====\n== \"" + desc + "\""
  val rule = "====\n"
  Seq(title, rule).foreach(line => println(line))
}

/** Compute average for given column. */
def example1() = {
banner(1, "Compute average of 2nd column in cars dataset")
Expand All @@ -40,7 +25,7 @@ object Examples {
/** Mutable class */
class Avg(var sum:scala.Double, var cnt:Int) extends Iced;

val f = parse("../private/cars.csv")
val f = parse("../smalldata/cars.csv")
val r = f collect ( new Avg(0,0),
new T_T_Collect[Avg] {
override def apply(acc:Avg, rhs:Row):Avg = {
Expand All @@ -58,7 +43,7 @@ object Examples {
def example2() = {
banner(2, "Call DRF API and make a forest for cars dataset")
import water.api.dsl.H2ODsl._
val f = parse("../private/cars.csv")
val f = parse(ffind("smalldata/cars.csv"))
val source = f(1) ++ f(3 to 7)
val response = f(2)

Expand Down Expand Up @@ -86,7 +71,7 @@ object Examples {
def example3() = {
banner(3, "Call quantiles API and compute quantiles for all columns in cars dataset.")
import water.api.dsl.H2ODsl._
val f = parse("../private/cars.csv")
val f = parse("../smalldata/cars.csv")

// Iterate over columns, pick only non-enum column and compute quantile for the column
for (columnId <- 0 until f.ncol) {
Expand All @@ -99,6 +84,75 @@ object Examples {
}
}
}

/** Simple example of deep learning model builder.
  *
  * Parses the prostate dataset, prepares a parameter-filling callback and calls
  * `deeplearning` with columns 2 to 8 as predictors and column 1 as the response.
  * No validation frame is passed (the second argument is `null`).
  */
def example4() = {
  banner(4, "Call deep learning model builder and validate it on test data.")
  import water.api.dsl.H2ODsl._
  // Parse train dataset
  val ftrain = parse(ffind("smalldata/logreg/prostate.csv"))
  // Create parameters for deep learning
  val params = (p:DeepLearning) => {
    // A single generator supplies both layer sizes. Two generators created from the
    // same seed start in the same state, so their first draws are not independent.
    val rnd = new Random(p.seed)
    p.epochs = 1.0
    p.hidden = Array(1+rnd.nextInt(4), 1+rnd.nextInt(6))
    p.classification = true
    // The seed keeps the builder's own value; re-assigning it to itself had no effect
    p.mini_batch = 0
    p.force_load_balance = false
    p.replicate_training_data = true
    p.shuffle_training_data = true
    p.score_training_samples = 0
    p.score_validation_samples = 0
    p.balance_classes = true
    p.quiet_mode = false
    p.score_validation_sampling = ClassSamplingMethod.Stratified
    p
  }

  val dlModel = deeplearning(ftrain, null, 2 to 8, 1, params)
  println("Resulting model: " + dlModel)
}

// Entry point, called in the context of the main classloader.
// Passes the Examples class and the command-line arguments on to Boot.
def main(args: Array[String]):Unit = {
  val entryClass = classOf[Examples]
  Boot.main(entryClass, args)
}

// Call in the context of H2O classloader.
// Starts H2O with the given arguments, runs all examples in order and always shuts the DSL down.
def userMain(args: Array[String]):Unit = {
  import scala.util.control.NonFatal
  H2O.main(args)
  try {
    example1()
    example2()
    example3()
    example4()
  } catch {
    // Simple debug: report ordinary failures, but let fatal errors
    // (e.g. VirtualMachineError, InterruptedException) propagate after shutdown
    case NonFatal(t) => t.printStackTrace()
  } finally {
    water.api.dsl.H2ODsl.shutdown()
  }
}

// Print simple banner: a header with the example number and its description,
// followed by a separator line.
private def banner(id:Int, desc: String) = {
  val header = "\n==== Example #" + id + " ====\n * " + desc
  val separator = "====================\n"
  Seq(header, separator).foreach(line => println(line))
}

// Find a given filename: the name is tried as given, then one directory up,
// then two directories up. Yields the first candidate that exists, or null
// when none of the three does.
private def ffind(fname: String):File = {
  val candidates = Seq(fname, "../" + fname, "../../" + fname).map(path => new File(path))
  candidates.find(candidate => candidate.exists()).orNull
}


}

// Companion class
Expand Down

0 comments on commit 8ee8ee6

Please sign in to comment.