Skip to content

Commit

Permalink
Add N-fold CV, use GBM for MNIST.
Browse files Browse the repository at this point in the history
  • Loading branch information
arnocandel committed Jan 17, 2015
1 parent 8759a96 commit 60a1957
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions R/examples/H2ODeepLearningNextML.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ adult_hex <- h2o.importFile(h2oServer, "adult.gz")
dim(adult_hex)
summary(adult_hex)


## Set column names manually
colnames(adult_hex) <- c("age","workclass","fnlwgt","education","education-num","marital-status",
"occupation","relationship","race","sex","capital-gain","capital-loss",
Expand All @@ -82,11 +83,11 @@ best_model@model$params$activation
best_model@model$params$l1
best_model@model$params$hidden


## Compare to GBM
gbm_model <- h2o.gbm(data=adult_hex, x=pred, y=resp)
gbm_model

## Compute N-fold CV error for the best model
cv_model <- h2o.deeplearning(data=adult_hex, x=pred, y="income",
activation=c("Rectifier"), l1=best_model@model$params$l1,
hidden=best_model@model$params$hidden, epochs=10, nfold=5)
cv_model


## MNIST
Expand All @@ -97,8 +98,11 @@ mnist_test_hex <- h2o.importFile(h2oServer, "mnist.test.csv.gz")
dim(mnist_train_hex)
dim(mnist_test_hex)

mnist_model <- h2o.deeplearning(data=mnist_train_hex, validation=mnist_test_hex, hidden=c(20,20,20), epochs=1, x=1:784, y=785)
mnist_model
dl_model <- h2o.deeplearning(data=mnist_train_hex, validation=mnist_test_hex, hidden=c(20,20,20), epochs=1, x=1:784, y=785)
dl_model

## Compare to GBM
gbm_model <- h2o.gbm(data=mnist_train_hex, validation=mnist_test_hex, x=1:784, y=785, distribution="multinomial", n.trees=5)
gbm_model

## For more examples, see http://learn.h2o.ai/content/

0 comments on commit 60a1957

Please sign in to comment.