Skip to content

Commit

Permalink
bug fixes & add plot_folder_sub
Browse files Browse the repository at this point in the history
- fix robyn_refresh bug when input data using read.csv
- fix plan(multicore(workers = ...)) bug
- sort data by datet
- add plot_folder_sub to allow custom folder for results
- add check_robyn_object function
- fix calibration output bug in robyn_input
  • Loading branch information
Gufeng Zhou committed Nov 17, 2021
1 parent fca1d04 commit 5b69457
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 16 deletions.
4 changes: 2 additions & 2 deletions R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
Package: Robyn
Type: Package
Title: Automated Marketing Mix Modeling (MMM) Open Source Beta Project from Facebook Marketing Science
Version: 3.4.6
Version: 3.4.7
Authors@R: c(
person("Gufeng", "Zhou", , "[email protected]", c("aut")),
person("Leonel", "Sentana", , "[email protected]", c("aut")),
person("Igor", "Skokan", , "[email protected]", c("aut")),
person("Bernardo", "Lares", , "[email protected]", c("cre")))
Maintainer: Bernardo Lares <[email protected]>
Maintainer: Gufeng Zhou <[email protected]>, Bernardo Lares <[email protected]>
Description: Automated Marketing Mix Modeling (MMM) package that aims to reduce human bias by means of ridge regression and evolutionary algorithms, enables actionable decision making providing a budget allocator and diminishing returns curves and allows ground-truth calibration to account for causation.
Depends:
R (>= 3.5)
Expand Down
10 changes: 9 additions & 1 deletion R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ check_datevar <- function(dt_input, date_var = "auto") {
stop("You must provide only 1 correct date variable name for 'date_var'")
}
dt_input <- as.data.table(dt_input)
dt_input <- dt_input[order(get(date_var))]
date_var_idate <- as.IDate(dt_input[, get(date_var)])
dt_input[, (date_var):= date_var_idate]
inputLen <- length(date_var_idate)
Expand Down Expand Up @@ -104,7 +105,8 @@ check_datevar <- function(dt_input, date_var = "auto") {
invisible(return(list(
date_var = date_var,
dayInterval = dayInterval,
intervalType = intervalType
intervalType = intervalType,
dt_input = dt_input
)))
}

Expand Down Expand Up @@ -398,6 +400,12 @@ check_InputCollect <- function(list) {
}
}

check_robyn_object <- function(robyn_object) {
file_end <- substr(robyn_object, nchar(robyn_object)-5, nchar(robyn_object))
if (file_end == ".RData") {stop("robyn_object must has format .RDS, not .RData")}
}


check_filedir <- function(plot_folder) {
file_end <- substr(plot_folder, nchar(plot_folder)-3, nchar(plot_folder))
if (file_end == ".RDS") {
Expand Down
8 changes: 5 additions & 3 deletions R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,10 @@ robyn_inputs <- function(dt_input = NULL,

## check date input (and set dayInterval and intervalType)
date_input <- check_datevar(dt_input, date_var)
dt_input <- date_input$dt_input # sort date by ascending
date_var <- date_input$date_var # when date_var = "auto"
dayInterval <- date_input$dayInterval
intervalType <- date_input$intervalType
setorderv(dt_input, date_var)

## check dependent var
check_depvar(dt_input, dep_var, dep_var_type)
Expand Down Expand Up @@ -310,8 +310,10 @@ robyn_inputs <- function(dt_input = NULL,
)
## update calibration_input
if (!is.null(calibration_input)) InputCollect$calibration_input <- calibration_input
if (!(is.null(InputCollect$hyperparameters) & is.null(hyperparameters))) {
### conditional output 2.2
if (is.null(InputCollect$hyperparameters) & is.null(hyperparameters)) {
stop("must provide hyperparameters in robyn_inputs()")
} else {
### conditional output 2.1
## 'hyperparameters' provided --> run robyn_engineering()
## update & check hyperparameters
if (is.null(InputCollect$hyperparameters)) InputCollect$hyperparameters <- hyperparameters
Expand Down
20 changes: 15 additions & 5 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#' @inheritParams robyn_allocator
#' @param plot_folder Character. Path for saving plots. Default
#' to \code{robyn_object} and saves plot in the same directory as \code{robyn_object}.
#' @param plot_folder_sub Character. Customize sub path to save plots. The total
#' path is created with \code{dir.create(file.path(plot_folder, plot_folder_sub))}.
#' For example, plot_folder_sub = "sub_dir".
#' @param dt_hyper_fixed data.frame. Only provide when loading old model results.
#' It consumes hyperparameters from saved csv \code{pareto_hyperparameters.csv}.
#' @param pareto_fronts Integer. Number of Pareto fronts for the output.
Expand Down Expand Up @@ -43,6 +46,7 @@
#' @export
robyn_run <- function(InputCollect,
plot_folder = getwd(),
plot_folder_sub = NULL,
pareto_fronts = 1,
plot_pareto = TRUE,
calibration_constraint = 0.1,
Expand All @@ -63,6 +67,7 @@ robyn_run <- function(InputCollect,
t0 <- Sys.time()

# check path
check_robyn_object(plot_folder)
plot_folder <- check_filedir(plot_folder)

dt_mod <- copy(InputCollect$dt_mod)
Expand Down Expand Up @@ -237,7 +242,7 @@ robyn_run <- function(InputCollect,
## get mean_response
registerDoFuture()
if (.Platform$OS.type == "unix") {
plan(multicore(workers = InputCollect$cores))
plan(multicore, workers = InputCollect$cores)
} else {
plan(sequential)
}
Expand Down Expand Up @@ -289,11 +294,11 @@ robyn_run <- function(InputCollect,
#### Plot overview

## set folder to save plat
if (!exists("plot_folder_sub")) {
if (is.null(plot_folder_sub)) {
folder_var <- ifelse(!refresh, "init", paste0("rf", InputCollect$refreshCounter))
plot_folder_sub <- paste0(format(Sys.time(), "%Y-%m-%d %H.%M"), " ", folder_var)
plotPath <- dir.create(file.path(plot_folder, plot_folder_sub))
}
plotPath <- dir.create(file.path(plot_folder, plot_folder_sub))

# pareto_fronts_vec <- ifelse(!hyper_fixed, c(1,2,3), 1)
if (!hyper_fixed) {
Expand Down Expand Up @@ -452,6 +457,12 @@ robyn_run <- function(InputCollect,
pbplot <- txtProgressBar(max = num_pareto123, style = 3)
}

all_fronts <- unique(xDecompAgg$robynPareto)
all_fronts <- sort(all_fronts[!is.na(all_fronts)])
if (!all(pareto_fronts_vec %in% all_fronts)) {
pareto_fronts_vec <- all_fronts
}

cnt <- 0
mediaVecCollect <- list()
xDecompVecCollect <- list()
Expand Down Expand Up @@ -1111,13 +1122,12 @@ robyn_mmm <- function(hyper_collect,

registerDoFuture()
if (.Platform$OS.type == "unix") {
plan(multicore(workers = InputCollect$cores))
plan(multicore, workers = cores)
} else {
plan(sequential)
}

# nbrOfWorkers()

getDoParWorkers()
doparCollect <- suppressPackageStartupMessages(
foreach(i = 1:iterPar) %dorng% { # i = 1
Expand Down
14 changes: 11 additions & 3 deletions R/R/refresh.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ robyn_save <- function(robyn_object,
select_model,
InputCollect,
OutputCollect) {
check_robyn_object(robyn_object)

if (!(select_model %in% OutputCollect$resultHypParam$solID)) {
stop(paste0("'select_model' must be one of these values: ", paste(
OutputCollect$resultHypParam$solID,
Expand Down Expand Up @@ -78,6 +80,7 @@ robyn_save <- function(robyn_object,
#' spend level. It returns aggregated result with all previous builds for
#' reporting purpose and produces reporting plots.
#'
#' @inheritParams robyn_run
#' @inheritParams robyn_allocator
#' @param dt_input A data.frame. Should include all previous data and newly added
#' data for the refresh.
Expand Down Expand Up @@ -142,6 +145,7 @@ robyn_save <- function(robyn_object,
#' }
#' @export
robyn_refresh <- function(robyn_object,
plot_folder_sub = NULL,
dt_input = dt_input,
dt_holidays = dt_holidays,
refresh_steps = 4,
Expand All @@ -153,6 +157,8 @@ robyn_refresh <- function(robyn_object,
while (refreshControl) {

## load inital model
if (!exists("robyn_object")) stop("Must speficy robyn_object")
check_robyn_object(robyn_object)
if (!file.exists(robyn_object)) {
stop("File does not exist or is somewhere else. Check: ", robyn_object)
} else {
Expand Down Expand Up @@ -210,8 +216,9 @@ robyn_refresh <- function(robyn_object,

## load new data
dt_input <- as.data.table(dt_input)
date_input <- check_datevar(dt_input, InputCollectRF$date_var)
dt_input <- date_input$dt_input # sort date by ascending
dt_holidays <- as.data.table(dt_holidays)
setorderv(dt_input, InputCollectRF$date_var, order = 1L)
InputCollectRF$dt_input <- dt_input
InputCollectRF$dt_holidays <- dt_holidays

Expand Down Expand Up @@ -290,6 +297,7 @@ robyn_refresh <- function(robyn_object,
OutputCollectRF <- robyn_run(
InputCollect = InputCollectRF,
plot_folder = objectPath,
plot_folder_sub = plot_folder_sub,
pareto_fronts = 1,
refresh = TRUE,
plot_pareto = plot_pareto
Expand Down Expand Up @@ -332,15 +340,15 @@ robyn_refresh <- function(robyn_object,
listOutputPrev$mediaVecCollect[
bestModRF == TRUE & ds >= (refreshStart - InputCollectRF$dayInterval * refresh_steps) &
ds <= (refreshEnd - InputCollectRF$dayInterval * refresh_steps)
],
][, ds := as.IDate(ds)],
OutputCollectRF$mediaVecCollect[
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
ds <= refreshEnd
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
)
mediaVecReport <- mediaVecReport[order(type, ds, refreshStatus)]
xDecompVecReport <- rbind(
listOutputPrev$xDecompVecCollect[bestModRF == TRUE],
listOutputPrev$xDecompVecCollect[bestModRF == TRUE][, ds := as.IDate(ds)],
OutputCollectRF$xDecompVecCollect[
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
ds <= refreshEnd
Expand Down
4 changes: 4 additions & 0 deletions R/man/robyn_refresh.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions R/man/robyn_run.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion demo/debug.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ seed = 123L
## debug robyn_run
# prep input para
plot_folder = robyn_object
plot_folder_sub = NULL
pareto_fronts = 1
plot_pareto = TRUE
calibration_constraint = 0.1
Expand All @@ -27,8 +28,9 @@ seed = 123

## debug robyn_refresh
# robyn_object
plot_folder_sub = NULL
dt_input = dt_input
dt_holidays = dt_holidays
dt_holidays = dt_prophet_holidays
refresh_steps = 14
refresh_mode = "auto" # "auto", "manual"
refresh_iters = 100
Expand Down
2 changes: 1 addition & 1 deletion demo/demo.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

#############################################################################################
#################### Facebook MMM Open Source - Robyn 3.4.6 ######################
#################### Facebook MMM Open Source - Robyn 3.4.7 ######################
#################### Quick guide #######################
#############################################################################################

Expand Down

0 comments on commit 5b69457

Please sign in to comment.