Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/test_p_robyn123' into test_p_rob…
Browse files Browse the repository at this point in the history
…yn123
  • Loading branch information
sinanfb91 committed Jul 11, 2022
2 parents 1a5e8c5 + 8552caf commit 9b58bb1
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
}
if (is.null(paid_media_signs)) {
paid_media_signs <- rep("positive", mediaVarCount)
# message("'paid_media_signs' were not provided. Using 'positive'")
}
if (!all(paid_media_signs %in% opts_pnd)) {
stop("Allowed values for 'paid_media_signs' are: ", paste(opts_pnd, collapse = ", "))
Expand All @@ -212,15 +211,11 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
if (spendVarCount != mediaVarCount) {
stop("'paid_media_spends' must have same length as 'paid_media_vars'")
}
#if (any(dt_input[, unique(c(paid_media_vars, paid_media_spends)), with = FALSE] < 0)) {
get_cols <- any(as.data.frame(dt_input)[, unique(c(paid_media_vars, paid_media_spends))]<0)
get_cols <- any(dt_input[, unique(c(paid_media_vars, paid_media_spends))] < 0)
if (get_cols) {
check_media_names <- unique(c(paid_media_vars, paid_media_spends))
df_check <- as.data.frame(dt_input)[, check_media_names]
#check_media_val <- sapply(dt_input[, check_media_names, with = FALSE], function(X) {
check_media_val <- sapply(df_check, function(X) {
any(X < 0)
})
df_check <- dt_input[, check_media_names]
check_media_val <- sapply(df_check, function(x) any(x < 0))
stop(
paste(names(check_media_val)[check_media_val], collapse = ", "),
" contains negative values. Media must be >=0"
Expand Down Expand Up @@ -278,8 +273,9 @@ check_datadim <- function(dt_input, all_ind_vars, rel = 10) {
}

check_windows <- function(dt_input, date_var, all_media, window_start, window_end) {

dates_vec <- as.Date(dt_input[, date_var][[1]])
window_start <- as.Date(as.character(window_start))
window_end <- as.Date(as.character(window_end))

if (is.null(window_start)) {
window_start <- min(dates_vec)
Expand Down Expand Up @@ -440,8 +436,10 @@ check_calibration <- function(dt_input, date_var, calibration_input, dayInterval
all_media <- c(paid_media_spends, organic_vars)
if (!all(calibration_input$channel %in% all_media)) {
these <- unique(calibration_input$channel[which(!calibration_input$channel %in% all_media)])
stop(sprintf("All channels from 'calibration_input' must be any of: %s.\n Check: %s",
v2t(all_media), v2t(these)))
stop(sprintf(
"All channels from 'calibration_input' must be any of: %s.\n Check: %s",
v2t(all_media), v2t(these)
))
}
for (i in 1:nrow(calibration_input)) {
temp <- calibration_input[i, ]
Expand Down Expand Up @@ -673,16 +671,18 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
stop("Input 'scenario' must be one of: ", paste(opts, collapse = ", "))
}

if (length(channel_constr_low) != 1 & length(channel_constr_low) != length(paid_media_spends))
if (length(channel_constr_low) != 1 & length(channel_constr_low) != length(paid_media_spends)) {
stop(paste(
"Input 'channel_constr_low' have to contain either only 1",
"value or have same length as 'InputCollect$paid_media_spends':", length(paid_media_spends)
))
if (length(channel_constr_up) != 1 & length(channel_constr_up) != length(paid_media_spends))
}
if (length(channel_constr_up) != 1 & length(channel_constr_up) != length(paid_media_spends)) {
stop(paste(
"Input 'channel_constr_up' have to contain either only 1",
"value or have same length as 'InputCollect$paid_media_spends':", length(paid_media_spends)
))
}

if ("max_response_expected_spend" %in% scenario) {
if (any(is.null(expected_spend), is.null(expected_spend_days))) {
Expand All @@ -697,12 +697,21 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen

check_metric_value <- function(metric_value, media_metric) {
if (!is.null(metric_value)) {
if (length(metric_value) != 1) stop(sprintf(
"Input 'metric_value' for %s (%s) must be a valid numerical value", media_metric, metric_value))
if (!is.numeric(metric_value)) stop(sprintf(
"Input 'metric_value' for %s (%s) must be a numerical value", media_metric, metric_value))
if (metric_value <= 0) stop(sprintf(
"Input 'metric_value' for %s (%s) must be a positive value", media_metric, metric_value))
if (length(metric_value) != 1) {
stop(sprintf(
"Input 'metric_value' for %s (%s) must be a valid numerical value", media_metric, metric_value
))
}
if (!is.numeric(metric_value)) {
stop(sprintf(
"Input 'metric_value' for %s (%s) must be a numerical value", media_metric, metric_value
))
}
if (metric_value <= 0) {
stop(sprintf(
"Input 'metric_value' for %s (%s) must be a positive value", media_metric, metric_value
))
}
}
}

Expand Down

0 comments on commit 9b58bb1

Please sign in to comment.