Skip to content

Commit

Permalink
Merge branch 'master' of github.com:h2oai/h2o
Browse files Browse the repository at this point in the history
  • Loading branch information
jessica0xdata committed Jan 13, 2015
2 parents a29909d + 53c7872 commit 29f40e9
Show file tree
Hide file tree
Showing 42 changed files with 557 additions and 237 deletions.
49 changes: 49 additions & 0 deletions R/examples/manycols.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Install and Launch H2O R package
if ("package:h2o" %in% search()) { detach("package:h2o", unload=TRUE) }
if ("h2o" %in% rownames(installed.packages())) { remove.packages("h2o") }
install.packages("h2o", repos=(c("http://h2o-release.s3.amazonaws.com/h2o/h2o-parsemanycols/8/R", getOption("repos"))))
library(h2o)

# Connect to cluster (8 nodes with -Xmx 40g each)

# Launch H2O Cluster with YARN on HDP2.1
#wget http://h2o-release.s3.amazonaws.com/h2o/h2o-parsemanycols/8/h2o-2.9.0.8.zip
#unzip h2o-2.9.0.8.zip
#cd h2o-2.9.0.8/hadoop
#hadoop fs -rm -r myDir
#hadoop jar h2odriver_hdp2.1.jar water.hadoop.h2odriver -libjars ../h2o.jar -n 8 -mapperXmx 40g -output myDir -baseport 61111 -data_max_factor_levels 65000 -chunk_bits 24

h2oCluster <- h2o.init(ip="mr-0xd1", port=61111)

# Read data from HDFS
data.hex <- h2o.importFile(h2oCluster, "hdfs://mr-0xd6/datasets/15Mx2.2k.csv")

# Create 80/20 train/validation split
random <- h2o.runif(data.hex, seed = 123456789)
train <- h2o.assign(data.hex[random < .8,], "X15Mx2_2k_part0.hex")
valid <- h2o.assign(data.hex[random >= .8,], "X15Mx2_2k_part1.hex")

# Delete full training data and temporaries - only needed if memory is tight
h2o.rm(h2oCluster, "15Mx2_2k.hex") # optional
h2o.rm(h2oCluster, grep(pattern = "Last.value", x = h2o.ls(h2oCluster)$Key, value = TRUE))

response=2 #1:1000 imbalance
predictors=c(3:ncol(data.hex))

# Start modeling

# GLM
mdl.glm <- h2o.glm(x=predictors, y=response, data=train, lambda_search=T, family="binomial", max_predictors=100) #nfolds=5 is optional
mdl.glm

# compute validation error for GLM
pred.glm <- h2o.predict(mdl.glm, valid)
h2o.performance(pred.glm[,3], valid[,response], measure="F1")

# Gradient Boosted Trees
mdl.gbm <- h2o.gbm(x=predictors, y=response, data=train, validation=valid, importance=T, balance.classes = T, class.sampling.factors = c(1,250))
mdl.gbm

# Random Forest
mdl.rf <- h2o.randomForest(x=predictors, y=response, data=train, validation=valid, type="BigData", depth=15, importance=T, balance.classes = T, class.sampling.factors = c(1,250))
mdl.rf
69 changes: 39 additions & 30 deletions R/h2o-package/R/Algorithms.R

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions R/h2o-package/R/Classes.R
Original file line number Diff line number Diff line change
Expand Up @@ -1164,12 +1164,11 @@ rbind.H2OParsedData <- function(..., deparse.level = 1) {
# l_dep <- sapply(substitute(placeholderFunction(...))[-1], deparse)
if(length(l) == 0) stop('rbind requires an H2O parsed dataset')

klass <- 'H2OParsedData'
# klass <- 'H2OParsedData'
h2o <- l[[1]]@h2o
nrows <- nrow(l[[1]])
m <- Map(function(elem){ inherits(elem, klass) & elem@h2o@ip == h2o@ip & elem@h2o@port == h2o@port & nrows == nrow(elem) }, l)
compatible <- Reduce(function(l,r) l & r, x=m, init=T)
if(!compatible){ stop(paste('rbind: all elements must be of type', klass, 'and in the same H2O instance'))}
# m <- Map(function(elem){ inherits(elem, klass) & elem@h2o@ip == h2o@ip & elem@h2o@port == h2o@port & nrows == nrow(elem) }, l)
# compatible <- Reduce(function(l,r) l & r, x=m, init=T)
# if(!compatible){ stop(paste('rbind: all elements must be of type', klass, 'and in the same H2O instance'))}

# If cbind(x,x), dupe colnames will automatically be renamed by H2O
if(is.null(names(l)))
Expand Down
6 changes: 4 additions & 2 deletions R/h2o-package/R/ParseImport.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ h2o.assign <- function(data, key) {
.h2o.exec2(expr = data@key, h2o = data@h2o, dest_key = key)
}

h2o.createFrame <- function(object, key, rows, cols, seed, randomize, value, real_range, categorical_fraction, factors, integer_fraction, integer_range, missing_fraction, response_factors) {
h2o.createFrame <- function(object, key, rows, cols, seed, randomize, value, real_range, categorical_fraction, factors, integer_fraction, integer_range, binary_fraction=0, binary_ones_fraction=0.5, missing_fraction, response_factors) {
if(!is.numeric(rows)) stop("rows must be a numeric value")
if(!is.numeric(cols)) stop("cols must be a numeric value")
if(!is.numeric(seed)) stop("seed must be a numeric value")
Expand All @@ -118,9 +118,11 @@ h2o.createFrame <- function(object, key, rows, cols, seed, randomize, value, rea
if(!is.numeric(integer_range)) stop("integer_range must be a numeric value")
if(!is.numeric(missing_fraction)) stop("missing_fraction must be a numeric value")
if(!is.numeric(response_factors)) stop("response_factors must be a numeric value")
if(!is.numeric(binary_fraction)) stop("binary_fraction must be a numeric value")
if(!is.numeric(binary_ones_fraction)) stop("binary_ones_fraction must be a numeric value")

res <- .h2o.__remoteSend(object, .h2o.__PAGE_CreateFrame, key = key, rows = rows, cols = cols, seed = seed, randomize = as.numeric(randomize), value = value, real_range = real_range,
categorical_fraction = categorical_fraction, factors = factors, integer_fraction = integer_fraction, integer_range = integer_range, missing_fraction = missing_fraction, response_factors = response_factors)
categorical_fraction = categorical_fraction, factors = factors, integer_fraction = integer_fraction, integer_range = integer_range, binary_fraction = binary_fraction, binary_ones_fraction=binary_ones_fraction, missing_fraction = missing_fraction, response_factors = response_factors)
.h2o.exec2(expr = key, h2o = object, dest_key = key)
}

Expand Down
11 changes: 7 additions & 4 deletions R/h2o-package/man/h2o.SpeeDRF.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ H2O: Single-Node Random Forest
Performs single-node random forest classification on a data set.
}
\usage{
h2o.SpeeDRF(x, y, data, key = "", classification = TRUE, nfolds = 0, validation,
mtries = -1, ntree = 50, depth = 20, sample.rate = 2/3, oobee = TRUE,
importance = FALSE, nbins = 1024, seed = -1, stat.type = "ENTROPY",
balance.classes = FALSE, verbose = FALSE)
h2o.SpeeDRF(x, y, data, key = "", classification = TRUE, nfolds = 0, validation,
holdout.fraction = 0, mtries = -1, ntree = 50, depth = 20, sample.rate = 2/3,
oobee = TRUE, importance = FALSE, nbins = 1024, seed = -1,
stat.type = "ENTROPY", balance.classes = FALSE, verbose = FALSE)
}
%- maybe also 'usage' for other objects documented here.
\arguments{
Expand All @@ -34,6 +34,9 @@ An \code{\linkS4class{H2OParsedData}} object containing the variables in the mod
}
\item{validation}{
(Optional) An \code{\linkS4class{H2OParsedData}} object indicating the validation dataset used to construct confusion matrix. If left blank, this defaults to the training data when \code{nfolds = 0}.}

\item{holdout.fraction}{ (Optional) Fraction of the training data to hold out for validation.}

\item{mtries}{
(Optional) Number of features to randomly select at each split in the tree. If set to the default of -1, this will be set to \code{sqrt(ncol(data))}, rounded down to the nearest integer.
}
Expand Down
10 changes: 7 additions & 3 deletions R/h2o-package/man/h2o.createFrame.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Create an H2O data frame from scratch, with optional randomization. Supports cat
}
\usage{
h2o.createFrame(object, key, rows, cols, seed, randomize, value, real_range,
categorical_fraction, factors, integer_fraction, integer_range,
categorical_fraction, factors, integer_fraction, integer_range,
binary_fraction, binary_ones_fraction,
missing_fraction, response_factors)
}
%- maybe also 'usage' for other objects documented here.
Expand All @@ -27,6 +28,8 @@ h2o.createFrame(object, key, rows, cols, seed, randomize, value, real_range,
\item{factors}{Factor levels for categorical variables}
\item{integer_fraction}{Fraction of integer columns (for randomize=true)}
\item{integer_range}{Range for integer variables (-range ... range)}
\item{binary_fraction}{Fraction of binary columns (for randomize=true)}
\item{binary_ones_fraction}{Fraction of 1's in binary columns (for randomize=true)}
\item{missing_fraction}{Fraction of missing values}
\item{response_factors}{Number of factor levels of the first column (1=real, 2=binomial, N=multinomial)}
}
Expand All @@ -39,8 +42,9 @@ localH2O = h2o.init(beta = TRUE)
myframe = h2o.createFrame(localH2O, 'myframekey', rows = 1000, cols = 10,
seed = -12301283, randomize = TRUE, value = 0, real_range = 2.0,
categorical_fraction = 0.2, factors = 100,
integer_fraction = 0.2, integer_range = 100, missing_fraction = 0.1,
response_factors = 2)
integer_fraction = 0.2, integer_range = 100,
binary_fraction = 0.1, binary_ones_fraction = 0.01,
missing_fraction = 0.1, response_factors = 2)
head(myframe)
summary(myframe)
h2o.shutdown(localH2O)
Expand Down
9 changes: 5 additions & 4 deletions R/h2o-package/man/h2o.deeplearning.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ Performs Deep Learning neural networks on an \code{\linkS4class{H2OParsedData}}
}
\usage{
h2o.deeplearning(x, y, data, key = "",override_with_best_model, classification = TRUE,
nfolds = 0, validation, checkpoint = "", autoencoder, use_all_factor_levels,
activation, hidden, epochs, train_samples_per_iteration, seed, adaptive_rate,
rho, epsilon, rate, rate_annealing, rate_decay, momentum_start,
momentum_ramp, momentum_stable, nesterov_accelerated_gradient,
nfolds = 0, validation, holdout_fraction = 0, checkpoint = "", autoencoder,
use_all_factor_levels, activation, hidden, epochs, train_samples_per_iteration,
seed, adaptive_rate, rho, epsilon, rate, rate_annealing, rate_decay,
momentum_start, momentum_ramp, momentum_stable, nesterov_accelerated_gradient,
input_dropout_ratio, hidden_dropout_ratios, l1, l2, max_w2,
initial_weight_distribution, initial_weight_scale, loss,
score_interval, score_training_samples, score_validation_samples,
Expand All @@ -33,6 +33,7 @@ h2o.deeplearning(x, y, data, key = "",override_with_best_model, classification =
\item{classification}{ (Optional) A logical value indicating whether the algorithm should conduct classification. }
\item{nfolds}{(Optional) Number of folds for cross-validation. If \code{nfolds >= 2}, then \code{validation} must remain empty.}
\item{validation}{(Optional) An \code{\linkS4class{H2OParsedData}} object indicating the validation dataset used to construct confusion matrix. If left blank, this defaults to the training data when \code{nfolds = 0}.}
\item{holdout_fraction}{ (Optional) Fraction of the training data to hold out for validation.}
\item{checkpoint}{"Model checkpoint (either key or H2ODeepLearningModel) to resume training with."}
\item{activation}{A string indicating the activation function to use. Must be either "Tanh", "TanhWithDropout", "Rectifier", "RectifierWithDropout", "Maxout" or "MaxoutWithDropout".}
\item{hidden}{ Hidden layer sizes (e.g. c(100,100)}
Expand Down
9 changes: 5 additions & 4 deletions R/h2o-package/man/h2o.gbm.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ H2O: Gradient Boosted Machines
\usage{
h2o.gbm(x, y, distribution = "multinomial", data, key = "", n.trees = 10,
interaction.depth = 5, n.minobsinnode = 10, shrinkage = 0.1, n.bins = 20,
group_split = TRUE, importance = FALSE, nfolds = 0, validation, balance.classes = FALSE,
max.after.balance.size = 5)
group_split = TRUE, importance = FALSE, nfolds = 0, validation, holdout.fraction = 0,
balance.classes = FALSE, max.after.balance.size = 5, class.sampling.factors = NULL)
}
\arguments{
\item{x}{
Expand Down Expand Up @@ -56,10 +56,11 @@ An \code{\linkS4class{H2OParsedData}} object containing the variables in the mod
\item{nfolds}{
(Optional) Number of folds for cross-validation. If \code{nfolds >= 2}, then \code{validation} must remain empty.
}
\item{validation}{
(Optional) An \code{\linkS4class{H2OParsedData}} object indicating the validation dataset used to construct confusion matrix. If left blank, this defaults to the training data when \code{nfolds = 0}.}
\item{validation}{ (Optional) An \code{\linkS4class{H2OParsedData}} object indicating the validation dataset used to construct confusion matrix. If left blank, this defaults to the training data when \code{nfolds = 0}.}
\item{holdout.fraction}{ (Optional) Fraction of the training data to hold out for validation.}
\item{balance.classes}{(Optional) Balance training data class counts via over/under-sampling (for imbalanced data)}
\item{max.after.balance.size}{Maximum relative size of the training data after balancing class counts (can be less than 1.0)}
\item{class.sampling.factors}{ Desired over/under-sampling ratios per class (lexicographic order). }
}
\value{
An object of class \code{\linkS4class{H2OGBMModel}} with slots key, data, valid (the validation dataset) and model, where the last is a list of the following components:
Expand Down
4 changes: 3 additions & 1 deletion R/h2o-package/man/h2o.interaction.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ localH2O = h2o.init()
myframe = h2o.createFrame(localH2O, 'framekey', rows = 20, cols = 5,
seed = -12301283, randomize = TRUE, value = 0,
categorical_fraction = 0.8, factors = 10, real_range = 1,
integer_fraction = 0.2, integer_range = 10, missing_fraction = 0.2,
integer_fraction = 0.2, integer_range = 10,
binary_fraction = 0, binary_ones_fraction = 0.5,
missing_fraction = 0.2,
response_factors = 1)[,-1]
myframe[,3] <- as.factor(myframe[,3])
summary(myframe)
Expand Down
11 changes: 8 additions & 3 deletions R/h2o-package/man/h2o.randomForest.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ Performs random forest classification on a data set.
\usage{
h2o.randomForest(x, y, data, key = "", classification = TRUE, ntree = 50,
depth = 20, mtries = -1, sample.rate = 2/3, nbins = 20, seed = -1,
importance = FALSE, nfolds = 0, validation, nodesize = 1,
balance.classes = FALSE, max.after.balance.size = 5, doGrpSplit = TRUE,
verbose = FALSE, oobee = TRUE, stat.type = "ENTROPY", type = "fast")
importance = FALSE, nfolds = 0, validation, holdout.fraction = 0, nodesize = 1,
balance.classes = FALSE, max.after.balance.size = 5, class.sampling.factors = NULL,
doGrpSplit = TRUE, verbose = FALSE, oobee = TRUE, stat.type = "ENTROPY",
type = "fast")
}
%- maybe also 'usage' for other objects documented here.
\arguments{
Expand Down Expand Up @@ -60,12 +61,16 @@ An \code{\linkS4class{H2OParsedData}} object containing the variables in the mod
\item{validation}{
(Optional) An \code{\linkS4class{H2OParsedData}} object indicating the validation dataset used to construct
confusion matrix. If left blank, this defaults to the training data when \code{nfolds = 0}.}

\item{holdout.fraction}{ (Optional) Fraction of the training data to hold out for validation.}

\item{nodesize}{
(Optional) Number of nodes to use for computation.
}
\item{balance.classes}{(Optional) Balance training data class counts via over/under-sampling (for imbalanced data)}
\item{max.after.balance.size}{Maximum relative size of the training data after balancing
class counts (can be less than 1.0)}
\item{class.sampling.factors}{ Desired over/under-sampling ratios per class (lexicographic order). }
\item{doGrpSplit}{Check non-contiguous group splits for categorical predictors}
\item{verbose}{(Optional) A logical value indicating whether verbose results should be returned.}
\item{stat.type}{(Optional) Type of statistic to use, equal to either "ENTROPY" or "GINI" or "TWOING".}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ check.deeplearning_autoencoder <- function(conn) {
cm <- h2o.confusionMatrix(test_preds[,1], test_labels)
cm

checkTrue(cm[length(cm)] == 0.104) #10% test set error
checkTrue(cm[length(cm)] == 0.1085) #10% test set error

testEnd()
}
Expand Down
3 changes: 2 additions & 1 deletion R/tests/testdir_demos/runit_demo_random_data_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ for(i in 1:length(rows)){ # changing number of rows
seed = 12345, randomize = T, value = 0, real_range = 100,
categorical_fraction = 0.0, factors = 10,
integer_fraction = 0.4, integer_range = 100,
missing_fraction = 0, response_factors = 1) )
missing_fraction = 0, response_factors = 1,
binary_fraction = 0, binary_ones_fraction = 0.5) )
create_frm_time[i,j] = as.numeric(sst[3])
mem = h2o.ls(conn,"myframe")
frm_size[i,j] = as.numeric(mem[2])
Expand Down
5 changes: 3 additions & 2 deletions R/tests/testdir_demos/runit_demo_random_data_pca.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ for(i in 1:length(rows)){ # changing number of rows
seed = 12345, randomize = T, value = 0, real_range = 100,
categorical_fraction = 0.0, factors = 10,
integer_fraction = 0.4, integer_range = 100,
missing_fraction = 0, response_factors = 1) )

missing_fraction = 0, response_factors = 1,
binary_fraction = 0, binary_ones_fraction = 0.5) )

create_frm_time[i,j] = as.numeric(sst[3])
mem = h2o.ls(conn,"myframe")
frm_size[i,j] = as.numeric(mem[2])
Expand Down
18 changes: 18 additions & 0 deletions hadoop/src/main/java/water/hadoop/h2odriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public class h2odriver extends Configured implements Tool {
static int cloudFormationTimeoutSeconds = DEFAULT_CLOUD_FORMATION_TIMEOUT_SECONDS;
static int nthreads = -1;
static int basePort = -1;
static int chunk_bits;
static int data_max_factor_levels;
static boolean beta = false;
static boolean enableRandomUdpDrop = false;
static boolean enableExceptions = false;
Expand Down Expand Up @@ -389,6 +391,8 @@ static void usage() {
" -n | -nodes <number of H2O nodes (i.e. mappers) to create>\n" +
" [-nthreads <maximum typical worker threads, i.e. cpus to use>]\n" +
" [-baseport <starting HTTP port for H2O nodes; default is 54321>]\n" +
" [-chunk_bits <bits per chunk (e.g., 22 for 4MB chunks)>]\n" +
" [-data_max_factor_levels <max. number of factors per column (e.g., 65000)>]\n" +
" [-ea]\n" +
" [-verbose:gc]\n" +
" [-XX:+PrintGCDetails]\n" +
Expand Down Expand Up @@ -543,6 +547,14 @@ else if (s.equals("-nthreads")) {
i++; if (i >= args.length) { usage(); }
nthreads = Integer.parseInt(args[i]);
}
else if (s.equals("-chunk_bits")) {
i++; if (i >= args.length) { usage(); }
chunk_bits = Integer.parseInt(args[i]);
}
else if (s.equals("-data_max_factor_levels")) {
i++; if (i >= args.length) { usage(); }
data_max_factor_levels = Integer.parseInt(args[i]);
}
else if (s.equals("-baseport")) {
i++; if (i >= args.length) { usage(); }
basePort = Integer.parseInt(args[i]);
Expand Down Expand Up @@ -912,6 +924,12 @@ private int run2(String[] args) throws Exception {
if (beta) {
conf.set(h2omapper.H2O_BETA_KEY, "-beta");
}
if (chunk_bits > 0) {
conf.set(h2omapper.H2O_CHUNKBITS_KEY, Integer.toString(chunk_bits));
}
if (data_max_factor_levels > 0) {
conf.set(h2omapper.H2O_DATAMAXFACTORLEVELS_KEY, Integer.toString(data_max_factor_levels));
}
if (enableRandomUdpDrop) {
conf.set(h2omapper.H2O_RANDOM_UDP_DROP_KEY, "-random_udp_drop");
}
Expand Down
Loading

0 comments on commit 29f40e9

Please sign in to comment.