Skip to content

Commit

Permalink
Add variable importance results to DRF
Browse files Browse the repository at this point in the history
  • Loading branch information
anqi committed Mar 25, 2014
1 parent dbfd28e commit 4c34271
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
11 changes: 9 additions & 2 deletions R/h2o-package/R/Algorithms.R
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ h2o.randomForest.FV <- function(x, y, data, ntree=50, depth=50, sample.rate=2/3,
# while(!.h2o.__isDone(data@h2o, "RF2", res)) { Sys.sleep(1) }
res2 = .h2o.__remoteSend(data@h2o, .h2o.__PAGE_DRFModelView, '_modelKey'=res$destination_key)

result = .h2o.__getDRFResults(res2$drf_model, params)
result = .h2o.__getDRFResults(res2$drf_model, params, importance)
new("H2ODRFModel", key=res$destination_key, data=data, model=result, valid=validation)
} else {
# .h2o.gridsearch.internal("RF", data, res$job_key, res$destination_key, validation, args$y_i)
Expand All @@ -862,7 +862,7 @@ h2o.randomForest.FV <- function(x, y, data, ntree=50, depth=50, sample.rate=2/3,
return(mySum)
}

.h2o.__getDRFResults <- function(res, params) {
.h2o.__getDRFResults <- function(res, params, importance = FALSE) {
result = list()
params$ntree = res$N
params$depth = res$max_depth
Expand All @@ -888,6 +888,13 @@ h2o.randomForest.FV <- function(x, y, data, ntree=50, depth=50, sample.rate=2/3,
class_names = res$'cmDomain' # tail(res$'_domains', 1)[[1]]
result$confusion = .build_cm(tail(res$'cms', 1)[[1]]$'_arr', class_names) #res$'_domains'[[length(res$'_domains')]])
}

if(importance) {
result$varimp = data.frame(rbind(res$varimp$varimp, res$varimp$varimpSD))
result$varimp[3,] = sqrt(params$ntree)*result$varimp[1,]/result$varimp[2,] # Compute z-scores
colnames(result$varimp) = res$'_names'[-length(res$'_names')] # res$varimp$variables
rownames(result$varimp) = c(res$varimp$method, "Standard Deviation", "Z-Scores")
}
return(result)
}

Expand Down
3 changes: 3 additions & 0 deletions R/h2o-package/R/Classes.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ setMethod("show", "H2ODRFModel", function(object) {
if(!is.null(model$auc) && !is.null(model$gini))
cat("\nAUC:", model$auc, "\nGini:", model$gini, "\n")
}
if(!is.null(model$varimp)) {
cat("\nVariable importance:\n"); print(model$varimp)
}
cat("\nMean-squared Error by tree:\n"); print(model$mse)
})

Expand Down

0 comments on commit 4c34271

Please sign in to comment.