From 8d6764aa6271333a50b36358de6942d84af39b55 Mon Sep 17 00:00:00 2001 From: Tom Kraljevic Date: Sun, 1 Mar 2015 17:29:30 -0800 Subject: [PATCH] Minor test tweaks to make it easier to demo offline. --- .../testdir_demos/runit_demo_VI_all_algos.R | 18 +++++++---- R/tests/testdir_demos/runit_demo_tk_cm_roc.R | 31 ++++++++++++++----- R/tests/testdir_demos/runit_demo_tk_steam.R | 11 ++++--- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/R/tests/testdir_demos/runit_demo_VI_all_algos.R b/R/tests/testdir_demos/runit_demo_VI_all_algos.R index bebc71a30a..7519fdd1fc 100644 --- a/R/tests/testdir_demos/runit_demo_VI_all_algos.R +++ b/R/tests/testdir_demos/runit_demo_VI_all_algos.R @@ -13,7 +13,13 @@ test <- function(h) { # Parse data into H2O print("Parsing data into H2O") # From an h2o git workspace. -data.hex = h2o.importFile(h, normalizePath(locate("smalldata/bank-additional-full.csv")), key="data.hex") +if (FALSE) { + h = h2o.init() + data.hex = h2o.importFile(h, "/Users/tomk/0xdata/ws/h2o/smalldata/bank-additional-full.csv", key="data.hex") +} +else { + data.hex = h2o.importFile(h, normalizePath(locate("smalldata/bank-additional-full.csv")), key="data.hex") +} # Or directly from github. # data.hex = h2o.importFile(h, path = "https://raw.github.com/0xdata/h2o/master/smalldata/bank-additional-full.csv", key="data.hex") @@ -31,7 +37,7 @@ myY="y" # Run GBM with variable importance my.gbm <- h2o.gbm(x = myX, y = myY, distribution = "bernoulli", data = data.hex, n.trees =100, - interaction.depth = 2, shrinkage = 0.01, importance = T) + interaction.depth = 2, shrinkage = 0.01, importance = T) # Access Variable Importance from the built model gbm.VI = my.gbm@model$varimp @@ -73,18 +79,18 @@ my.glm = h2o.glm(x=myX, y=myY, data=data.hex, family="binomial",standardize=T,us # Select the best model picked by glm best_model = my.glm@best_model -# Get the normalized coefficients of the best model -n_coeff = abs(my.glm@models[[best_model]]@model$normalized_coefficients) +# Get the normalized coefficients of the best model +n_coeff = abs(my.glm@models[[best_model]]@model$normalized_coefficients) # Access Variable Importance by removing the intercept term -VI = abs(n_coeff[-length(n_coeff)]) +VI = abs(n_coeff[-length(n_coeff)]) glm.VI = VI[order(VI,decreasing=T)] print("Variable importance from GLM") print(glm.VI) # Plot variable importance from glm -barplot(glm.VI[1:20],las=2,main="VI from GLM") +barplot(glm.VI[1:20],las=2,main="VI from GLM") #-------------------------------------------------- # Run deeplearning with variable importance diff --git a/R/tests/testdir_demos/runit_demo_tk_cm_roc.R b/R/tests/testdir_demos/runit_demo_tk_cm_roc.R index 5287cc6759..9495fc8306 100644 --- a/R/tests/testdir_demos/runit_demo_tk_cm_roc.R +++ b/R/tests/testdir_demos/runit_demo_tk_cm_roc.R @@ -15,6 +15,8 @@ if (TRUE) { if (FALSE) { setwd("/Users/tomk/0xdata/ws/h2o/R/tests/testdir_demos") + filePath <- "/Users/tomk/0xdata/ws/h2o/smalldata/airlines/AirlinesTrain.csv.zip" + testFilePath <- "/Users/tomk/0xdata/ws/h2o/smalldata/airlines/AirlinesTest.csv.zip" } source('../findNSourceUtils.R') @@ -23,8 +25,8 @@ if (TRUE) { testFilePath <- normalizePath(locate("smalldata/airlines/AirlinesTest.csv.zip")) } else { stop("need to hardcode ip and port") - # myIP = "127.0.0.1" - # myPort = 54321 + myIP = "127.0.0.1" + myPort = 54321 library(h2o) PASS_BANNER <- function() { cat("\nPASS\n\n") } @@ -48,7 +50,7 @@ myX = c("Origin", "Dest", "Distance", "UniqueCarrier", "fMonth", "fDayofMonth", myY="IsDepDelayed" #gbm -air.gbm = h2o.gbm(x = myX, y = myY, distribution = "multinomial", data = air.train, n.trees = 10, +air.gbm = h2o.gbm(x = myX, y = myY, distribution = "multinomial", data = air.train, n.trees = 10, interaction.depth = 3, shrinkage = 0.01, n.bins = 100, validation = air.valid, importance = T) print(air.gbm@model) air.gbm@model$auc @@ -62,7 +64,7 @@ air.test=h2o.importFile(conn,testFilePath,key="air.test") model_object=air.rf #air.glm air.rf air.dl -#predicting on test file +#predicting on test file pred = h2o.predict(model_object,air.test) head(pred) @@ -80,14 +82,27 @@ plot(perf,type="roc") PASS_BANNER() -if (FALSE) { - h = h2o.init(ip="mr-0xb1", port=60024) +BIGDATA = FALSE +if (BIGDATA) { + h = h2o.init(ip="172.16.2.190", port=60024) df = h2o.importFile(h, "/home/tomk/airlines_all.csv") nrow(df) ncol(df) head(df) + + s = h2o.runif(df) # Useful when number of rows too large for R to handle + air.train = df[s <= 0.8,] + air.test = df[s > 0.8,] + myX = c("Origin", "Dest", "Distance", "UniqueCarrier", "Month", "DayofMonth", "DayOfWeek") myY = "IsDepDelayed" - air.glm = h2o.glm(x = myX, y = myY, data = df, family = "binomial", nfolds = 10, alpha = 0.25, lambda = 0.001) - air.glm@model$confusion + air.glm = h2o.glm(x = myX, y = myY, data = air.train, + family = "binomial", nfolds = 1, alpha = 0.25, lambda = 0.001) + + pred = h2o.predict(air.glm, air.test) + dim(pred) + head(pred) + perf = h2o.performance(pred$YES,air.test$IsDepDelayed) + perf + plot(perf,type="roc") } diff --git a/R/tests/testdir_demos/runit_demo_tk_steam.R b/R/tests/testdir_demos/runit_demo_tk_steam.R index 9cab06b053..42e59cfe36 100644 --- a/R/tests/testdir_demos/runit_demo_tk_steam.R +++ b/R/tests/testdir_demos/runit_demo_tk_steam.R @@ -11,19 +11,20 @@ if (TRUE) { # Set working directory so that the source() below works. setwd(normalizePath(dirname(R.utils::commandArgs(asValues=TRUE)$"f"))) - + if (FALSE) { setwd("/Users/tomk/0xdata/ws/h2o/R/tests/testdir_demos") + filePath <- "/Users/tomk/0xdata/ws/h2o/smalldata/airlines/allyears2k_headers.zip" } - + source('../findNSourceUtils.R') options(echo=TRUE) filePath <- normalizePath(locate("smalldata/airlines/allyears2k_headers.zip")) } else { stop("need to hardcode ip and port") - # myIP = "127.0.0.1" - # myPort = 54321 - + myIP = "127.0.0.1" + myPort = 54321 + library(h2o) PASS_BANNER <- function() { cat("\nPASS\n\n") } filePath <- "https://raw.github.com/0xdata/h2o/master/smalldata/airlines/allyears2k_headers.zip"