Skip to content

Commit

Permalink
added logging robyn_engineering and prophet_decomp function
Browse files Browse the repository at this point in the history
  • Loading branch information
sinanfb91 committed Jul 26, 2022
1 parent 5781a5f commit b0b00c6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
39 changes: 18 additions & 21 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ robyn_engineering <- function(x, ...) {
colnames(dt_transform)[colnames(dt_transform) == InputCollect$date_var] <- "ds"
colnames(dt_transform)[colnames(dt_transform) == InputCollect$dep_var] <- "dep_var"
dt_transform <- arrange(dt_transform, .data$ds)
message(paste("robyn_engineering dt_transform:", paste(head(dt_transform,2), collapse = ", ")))


# dt_transformRollWind
dt_transformRollWind <- dt_transform[rollingWindowStartWhich:rollingWindowEndWhich, ]
Expand Down Expand Up @@ -627,6 +629,7 @@ robyn_engineering <- function(x, ...) {
## transform all factor variables
if (length(factor_vars) > 0) {
dt_transform <- mutate_at(dt_transform, factor_vars, as.factor)
message(paste("robyn_engineering mute_at(dt_transform):", paste(head(dt_transform,2), collapse = ", ")))
}

################################################################
Expand Down Expand Up @@ -710,6 +713,7 @@ prophet_decomp <- function(dt_transform, dt_holidays,
use_weekday <- "weekday" %in% prophet_vars | "weekly.seasonality" %in% prophet_vars

dt_regressors <- cbind(recurrence, subset(dt_transform, select = c(context_vars, paid_media_spends)))
message(paste("prophet_decomp dt_regressors:", paste(dt_regressors, collapse = ", ")))

prophet_params <- list(
holidays = if (use_holiday) holidays[holidays$country == prophet_country, ] else NULL,
Expand All @@ -724,7 +728,9 @@ prophet_decomp <- function(dt_transform, dt_holidays,
daily.seasonality = FALSE # No hourly models allowed
)
prophet_params <- append(prophet_params, custom_params)
message(paste("prophet_decomp prophet_params:", paste(head(prophet_params,1), collapse = ", ")))
modelRecurrence <- do.call(prophet, as.list(prophet_params))
message(paste("prophet_decomp modelRecurrence <- do.call(prophet, as.list(prophet_params)):", paste(head(modelRecurrence,1), collapse = ", ")))

if (!is.null(factor_vars) && length(factor_vars) > 0) {
dt_ohe <- dt_regressors %>%
Expand All @@ -733,8 +739,11 @@ prophet_decomp <- function(dt_transform, dt_holidays,
ohe_names <- names(dt_ohe)
for (addreg in ohe_names) modelRecurrence <- add_regressor(modelRecurrence, addreg)
dt_ohe <- select(dt_regressors, -all_of(factor_vars)) %>% bind_cols(dt_ohe)
message(paste("prophet_decomp select(dt_regressors, -all_of(factor_vars)) %>% bind_cols(dt_ohe):", paste(head(dt_ohe,1), collapse = ", ")))
mod_ohe <- fit.prophet(modelRecurrence, dt_ohe)
message(paste("prophet_decomp fit.prophet(modelRecurrence, dt_ohe):", paste(head(mod_ohe,1), collapse = ", ")))
dt_forecastRegressor <- predict(mod_ohe, dt_ohe)
message(paste("prophet_decomp dt_forecastRegressor:", paste(head(dt_forecastRegressor,1), collapse = ", ")))
forecastRecurrence <- select(dt_forecastRegressor, -contains("_lower"), -contains("_upper"))
for (aggreg in factor_vars) {
oheRegNames <- grep(paste0("^", aggreg, ".*"), names(forecastRecurrence), value = TRUE)
Expand Down

0 comments on commit b0b00c6

Please sign in to comment.