Skip to content

Commit

Permalink
Merge branch 'master' of github.com:h2oai/h2o
Browse files Browse the repository at this point in the history
  • Loading branch information
tomkraljevic committed Dec 16, 2014
2 parents 32803e8 + cac5355 commit 4d5b1ef
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
8 changes: 4 additions & 4 deletions R/h2o-package/R/Internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -929,11 +929,11 @@ h2o.getModel <- function(h2o, key) {
if(!is.null(response$"_dataKey")) train_fr <- h2o.getFrame(h2o, response$"_dataKey")
params$importance <- !is.null(params$varimp)
if(!is.null(params$family) && model.type == "gbm_model") {
params$distribution <- "multinomial"
if(params$family == "AUTO") {
if(!is.null(json[[model.type]]$validAUC)) params$distribution <- "bernoulli"
if(params$classification == "false") {params$distribution <- "gaussian"
} else {
if(length(params$'_distribution') > 2) params$distribution <- "multinomial" else params$distribution <- "bernoulli"
}
}
}
if(algo == "model") {
newModel <- new(model_obj, key = dest_key, data = train_fr, model = results_fun(json[[model.type]], train_fr, params))
return(newModel)
Expand Down
21 changes: 15 additions & 6 deletions R/tests/testdir_jira/runit_hex_1775_save_load.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ test.hex_1775 <- function(conn) {
Log.info("Build GLM model")
prostate.glm = h2o.glm(y = "CAPSULE", x = c("AGE","RACE","PSA","DCAPS"), data = prostate.hex, family = "binomial", nfolds = 0, alpha = 0.5)
Log.info("Build GBM model")
prostate.gbm = h2o.gbm(y = 2, x = 3:9, data = prostate.hex, nfolds = 5)
prostate.gbm = h2o.gbm(y = 2, x = 3:9, data = prostate.hex, nfolds = 5, distribution = "bernoulli")
Log.info("Build GBM model with Gaussian distribution")
prostate.gbm2 = h2o.gbm(y = "PSA", x = c("AGE","RACE","CAPSULE","DCAPS"), data = prostate.hex, distribution = "gaussian")
Log.info("Build Speedy Random Forest Model")
iris.speedrf = h2o.randomForest(x = c(2,3,4), y = 5, data = iris.hex, ntree = 10, depth = 20, type = "fast")
Log.info("Build BigData Random Forest Model")
Expand All @@ -44,6 +46,7 @@ test.hex_1775 <- function(conn) {

glm.pred = pred_df(object = prostate.glm , newdata = prostate.hex)
gbm.pred = pred_df(object = prostate.gbm, newdata = prostate.hex)
gbm.pred_2 = pred_df(object = prostate.gbm2, newdata = prostate.hex)
speedrf.pred = pred_df(object = iris.speedrf, newdata = iris.hex)
rf.pred = pred_df(object = iris.rf, newdata = iris.hex)
nb.pred = pred_df(object = iris.nb, newdata = iris.hex)
Expand All @@ -53,10 +56,12 @@ test.hex_1775 <- function(conn) {
Log.info("Saving models to disk")
prostate.glm.path = h2o.saveModel(object = prostate.glm, dir = temp_subdir1, save_cv = FALSE, force = TRUE)
prostate.gbm.path = h2o.saveModel(object = prostate.gbm, dir = temp_subdir1, save_cv = TRUE, force = TRUE)
prostate.gbm.path2 = h2o.saveModel(object = prostate.gbm2, dir = temp_subdir1, save_cv = FALSE, force = TRUE)
iris.speedrf.path = h2o.saveModel(object = iris.speedrf, dir = temp_subdir1, save_cv = FALSE, force = TRUE)
iris.rf.path = h2o.saveModel(object = iris.rf, dir = temp_subdir1, save_cv = TRUE, force = TRUE)
iris.nb.path = h2o.saveModel(object = iris.nb, dir = temp_subdir1, save_cv = FALSE, force = TRUE)
iris.dl.path = h2o.saveModel(object = iris.dl, dir = temp_subdir1, save_cv = FALSE, force = TRUE)
Log.info("Finished saving models to disk")

# All keys removed to test that cross validation models are actually being loaded
h2o.removeAll(object = conn)
Expand All @@ -65,7 +70,7 @@ test.hex_1775 <- function(conn) {
Log.info(paste("Moving models from", temp_subdir1, "to", temp_subdir2))
file.rename(temp_subdir1, temp_subdir2)

model_paths = c(prostate.glm.path, prostate.gbm.path, iris.speedrf.path, iris.rf.path, iris.nb.path, iris.dl.path)
model_paths = c(prostate.glm.path, prostate.gbm.path, prostate.gbm.path2, iris.speedrf.path, iris.rf.path, iris.nb.path, iris.dl.path)
new_model_paths = {}
for (path in model_paths) {
new_path = paste(temp_subdir2,basename(path),sep = .Platform$file.sep)
Expand All @@ -88,13 +93,15 @@ test.hex_1775 <- function(conn) {
Log.info("Running Predictions for Loaded Models")
glm2 = reloaded_models[[1]]
gbm2 = reloaded_models[[2]]
speedrf2 = reloaded_models[[3]]
rf2 = reloaded_models[[4]]
nb2 = reloaded_models[[5]]
dl2 = reloaded_models[[6]]
gbm2_2 = reloaded_models[[3]]
speedrf2 = reloaded_models[[4]]
rf2 = reloaded_models[[5]]
nb2 = reloaded_models[[6]]
dl2 = reloaded_models[[7]]

glm.pred2 = pred_df(object = glm2, newdata = prostate.hex)
gbm.pred2 = pred_df(object = gbm2, newdata = prostate.hex)
gbm.pred2_2 = pred_df(object = gbm2_2, newdata = prostate.hex)
speedrf.pred2 = pred_df(object = speedrf2, newdata = iris.hex)
rf.pred2 = pred_df(object = rf2, newdata = iris.hex)
nb.pred2 = pred_df(object = nb2, newdata = iris.hex)
Expand All @@ -105,6 +112,8 @@ test.hex_1775 <- function(conn) {
expect_equal(glm.pred, glm.pred2)
expect_equal(nrow(gbm.pred), 380)
expect_equal(gbm.pred, gbm.pred2)
expect_equal(nrow(gbm.pred_2), 380)
expect_equal(gbm.pred_2, gbm.pred2_2)
expect_equal(nrow(rf.pred), 150)
expect_equal(rf.pred, rf.pred2)
expect_equal(nrow(nb.pred), 150)
Expand Down

0 comments on commit 4d5b1ef

Please sign in to comment.