-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathcross_val.R
80 lines (77 loc) · 2.53 KB
/
cross_val.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#' Define Cross-Validation Scheme and Training Parameters
#'
#' Creates the repeated k-fold resampling indices (grouped when `group` is
#' supplied), generates one set of seeds per resample with
#' \code{\link{get_seeds_trainControl}}, and bundles both into a
#' \link[caret]{trainControl} object.
#'
#' See \link[caret]{trainControl} for info on how the seed is being set.
#'
#' @param train_data Dataframe for training model
#' @param class_probs Whether the \link[caret]{trainControl} `classProbs` argument should be TRUE or FALSE (TRUE for classification, FALSE for regression)
#' @inheritParams run_ml
#' @inheritParams get_tuning_grid
#'
#' @return Caret object for trainControl that controls cross-validation
#' @export
#' @author Begüm Topçuoğlu, \email{topcuoglu.begum@@gmail.com}
#'
#'
#' @examples
#' define_cv(train_data_sm,
#'   outcome_colname = "dx",
#'   hyperparams_list = get_hyperparams_list(otu_small, "regLogistic"),
#'   perf_metric_function = caret::twoClassSummary,
#'   class_probs = TRUE,
#'   kfold = 5,
#'   seed = 2019
#' )
define_cv <- function(train_data,
                      outcome_colname,
                      hyperparams_list,
                      perf_metric_function,
                      class_probs,
                      kfold = 5,
                      cv_times = 100,
                      group = NULL,
                      seed = NA) {
  # Only reset the RNG when a usable seed was supplied; `NULL` and `NA` both
  # mean "leave the current RNG state alone".
  if (!is.null(seed) && !is.na(seed)) {
    set.seed(seed)
  }

  if (is.null(group)) {
    # Extract the outcome by name with `[[` rather than tidy evaluation, so a
    # data column that happens to be called "outcome_colname" cannot shadow
    # the argument.
    outcome <- train_data[[outcome_colname]]
    if (is.null(outcome)) {
      stop(
        "Outcome column `", outcome_colname, "` not found in `train_data`.",
        call. = FALSE
      )
    }
    cvIndex <- caret::createMultiFolds(factor(outcome), kfold, times = cv_times)
  } else {
    cvIndex <- groupKMultiFolds(group, kfold = kfold, cv_times = cv_times)
  }

  # The seeds are drawn after the folds so that both come from one RNG stream.
  seeds <- get_seeds_trainControl(
    hyperparams_list,
    kfold,
    cv_times,
    ncol(train_data)
  )

  # NOTE(review): `repeats` is not passed here; the resamples are given
  # explicitly through `index` -- confirm that caret ignores `repeats` then.
  caret::trainControl(
    method = "repeatedcv",
    number = kfold,
    index = cvIndex,
    returnResamp = "final",
    classProbs = class_probs,
    summaryFunction = perf_metric_function,
    indexFinal = NULL,
    savePredictions = TRUE,
    seeds = seeds
  )
}
#' Get seeds for caret::trainControl
#'
#' Adapted from \href{https://stackoverflow.com/a/32598959}{this Stack Overflow post}
#' and the \link[caret]{trainControl} documentation
#'
#' Produces a list with `kfold * cv_times + 1` elements. Each of the first
#' `kfold * cv_times` elements is an integer vector with one seed per
#' combination of hyperparameter values; the final element is a single seed
#' for the last model. Seeds are drawn from `1:(ncol_train * 1000)` using the
#' current RNG state, so call `set.seed()` beforehand for reproducibility.
#'
#' @param ncol_train number of columns in training data
#' @inheritParams run_ml
#' @inheritParams define_cv
#'
#' @return seeds for `caret::trainControl`
#' @export
#'
#' @examples
#' get_seeds_trainControl(get_hyperparams_list(otu_small, "regLogistic"), 5, 100, 60)
get_seeds_trainControl <- function(hyperparams_list, kfold, cv_times, ncol_train) {
  n_resamples <- kfold * cv_times
  sample_from <- ncol_train * 1000

  # Size of the full tuning grid: the product of the number of values per
  # hyperparameter. `lengths()` is type-stable, so an empty list gives 1.
  n_tuning_combos <- prod(lengths(hyperparams_list))

  # Draw without replacement whenever possible (keeps the RNG stream the same
  # as before); only allow repeats when the grid is larger than the pool of
  # candidate seeds, where sampling without replacement would error.
  allow_repeats <- n_tuning_combos > sample_from

  # `seq_len()` yields an empty sequence when there are no resamples.
  resample_seeds <- lapply(seq_len(n_resamples), function(i) {
    sample.int(n = sample_from, size = n_tuning_combos, replace = allow_repeats)
  })

  ## For the last model:
  final_seed <- sample.int(n = sample_from, size = 1)

  c(resample_seeds, list(final_seed))
}