Skip to content

Commit

Permalink
Fix generate.soothsayer (fixing bootstrapped forecasts as well)
Browse files Browse the repository at this point in the history
  • Loading branch information
JSzitas committed May 27, 2022
1 parent a16d9a4 commit 4c8be5e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 19 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: soothsayer
Type: Package
Title: Tidy Time Series Forecast Model Selection
Version: 0.6.2
Version: 0.6.3
Author: Juraj Szitas
Maintainer: Juraj Szitas <[email protected]>
Description: More about what it does (maybe more than one line)
Expand Down
35 changes: 22 additions & 13 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ generate.soothsayer <- function (x, new_data = NULL, h = NULL, specials = NULL,
names(weights) <- names(x[["model_fits"]])
generated_distrs <- purrr::imap( x[["model_fits"]],
function(model, name) {
safe_gen <- purrr::possibly(generate,
otherwise = data.frame( .sim = NA))
dplyr::bind_cols(
generate(
safe_gen(
model[[1]],
new_data = new_data,
h = h,
Expand All @@ -18,25 +20,38 @@ generate.soothsayer <- function (x, new_data = NULL, h = NULL, specials = NULL,
...),
model = name)
})
# check weight consistency - not all methods implement the generate() function
valid_dists <- purrr::map_lgl( generated_distrs, ~ !all(is.na(.x[[".sim"]])) )

if( !all(valid_dists) ) {
warning(paste0("Generation failed for following models:\n",
paste0(names(x[["model_fits"]][!valid_dists]), collapse = ", "),
"\nThese models will be ignored when creating combined samples."
)
)
}
# recompute weights
weights <- weights[ valid_dists ]/sum(weights[valid_dists])
generated_distrs <- generated_distrs[valid_dists]

dists <- purrr::imap( generated_distrs, function(dist, name) {
dists <- dplyr::mutate(dist, .sim = .data$.sim * weights[name])
tsibble::as_tibble(dists)
})
# have to group by first column - which is, happily, the time index variable
index_var <- as.name(names(dists[[1]])[1])
dists <- dplyr::group_by( dplyr::bind_rows(dists), !!index_var)
dists <- dplyr::group_by( dplyr::bind_rows(dists), !!index_var, .data$.rep)
dists <- dplyr::summarise( dists,
.sim = sum(.data$.sim),
.rep = unique(.data$.rep))
tsibble::as_tsibble(dists, index = names(dists)[1])
.groups = "keep")
tsibble::tsibble( dists, index = rlang::as_string(index_var), key = c(.data$.rep))
}
# to be fair, this is a method over a fable, but I do not want to write a generic for it
get_distribution <- function(x) {
distr <- purrr::map_lgl( x, distributional::is_distribution )
distr <- names(distr)[ which(distr) ]
c(x[,distr])
}

#' @importFrom fabletools forecast
#' @export
forecast.soothsayer <- function( object,
Expand All @@ -50,8 +65,8 @@ forecast.soothsayer <- function( object,
fcst <- fabletools::forecast(
.x[[1]],
new_data = new_data,
bootstrap = bootstrap,
times = times,
bootstrap = FALSE,
times = 0,
...)
get_distribution(fcst)
}
Expand All @@ -67,12 +82,6 @@ forecast.soothsayer <- function( object,
fcst_means <- as.matrix(dplyr::bind_cols(fcst_means))
# and compute the final mean
fcst_means <- c( fcst_means %*% weights )
# also get forecast variances

# I am quite unsure at this point that computing the variances is very sensible
# for non-normal distributions its bad, but even for normal distributions... there is
# no guarantee that the end result is normal. if its bimodal... then we are a bit
# screwed either way. lets not do that.
distributional::dist_degenerate( fcst_means )
}
#' @importFrom stats residuals
Expand Down
11 changes: 10 additions & 1 deletion experiments/helpers/multitarget_weight_oracle.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ full_tbl <- tbl %>%
dplyr::rename_with( .fn = ~paste0("target_",.x) ) %>%
dplyr::rename(key = target_key)



top_5 <- tbl %>%
dplyr::mutate( weight = 1/RMSSE ) %>%
dplyr::mutate( rank = dplyr::min_rank(weight) ) %>%
Expand All @@ -37,6 +39,12 @@ top_5 <- tbl %>%
dplyr::rename_with( .fn = ~paste0("target_",.x) ) %>%
dplyr::rename(key = target_key)

best <- tbl %>%
dplyr::mutate( weight = 1/RMSSE ) %>%
dplyr::mutate( rank = dplyr::min_rank(weight) ) %>%
dplyr::ungroup() %>%
dplyr::filter( rank == 1 ) %>%
dplyr::select( key, .model)

qs::qsave( dplyr::full_join( full_tbl, features, by = c("key") ),
"oracle_weighed.qs")
Expand All @@ -45,7 +53,8 @@ qs::qsave( dplyr::full_join( top_5,
features, by = c("key") ),
"oracle_weighed_top_5.qs")


qs::qsave( dplyr::full_join( best, features, by = c("key") ),
"oracle_best.qs" )

# ignore individual forecast hs for now

Expand Down
16 changes: 12 additions & 4 deletions experiments/tests/test_xreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,24 @@ test <- dplyr::filter( ex_data, Month > tsibble::yearmonth("2017 Jan") )

fabletools::model(
train,
ar = fable::AR(Lambs ~ order(0) + 1 + Sheep),
arima = fable::ARIMA(Lambs),
# ar = fable::AR(Lambs ~ order(0) + 1 + Sheep),
soothsayer = soothsayer(Lambs ~ rules(
arima ~ .length > 12,
ar ~ TRUE
ets ~ TRUE,
# ar ~ TRUE,
theta ~ TRUE
) +
model_aliases(
ar = fix_model_parameters(fable::AR, order(0:3) + 1 + Sheep),
arima = fable::ARIMA) +
# ar = fix_model_parameters(fable::AR, order(0:3) + 1 + Sheep),
ets = fable::ETS,
arima = fable::ARIMA,
theta = fable::THETA) +
combiner(combiner_greedy_stacking) +
Sheep )
) -> fitted

fcsts <- forecast( fitted, test )
gens <- generate( fitted, test, bootstrap = TRUE, times = 5 )


0 comments on commit 4c8be5e

Please sign in to comment.