posterior_predict() returning incorrect fitted values for rstanarm multivariate regression models #271

Closed
rudeboybert opened this issue Oct 5, 2020 · 6 comments

rudeboybert commented Oct 5, 2020

Hello, thank you for a wonderful package. If only it existed while I was in grad school!

add_predicted_draws() seems to be returning incorrect fitted values for rstanarm::stan_mvmer() multivariate GLM models. I've narrowed the issue down to the m argument, which rstanarm::posterior_predict() accepts for multivariate stanmvreg models (but not for regular stanreg models) and which add_predicted_draws() does not pass along.

I'm happy to take a stab at a PR, but first wanted to check in and see if you wanted to go down the stanmvreg rabbit hole, or if you'd prefer throwing a "we don't support these types of models" warning like you do here for "ulam", "quap", "map", "map2stan" models.

Here is a reprex:

Create 3-dim outcome multivariate example data set

library(tidyverse)
library(rstanarm)
library(tidybayes)

multivariate_data <- bind_rows(
  tibble(x = rep(1:50, times = 4)) %>% mutate(obs = "y1", y = 0 + 0.4*x),
  tibble(x = rep(1:50, times = 4)) %>% mutate(obs = "y2", y = 30 + 0.4*x),
  tibble(x = rep(1:50, times = 4)) %>% mutate(obs = "y3", y = 60 + 0.4*x)
) %>%
  mutate(
    group = rep(1:4, each = 50) %>% rep(times = 3),
    y = y + rnorm(n())
  )

ggplot(multivariate_data, aes(x = x, y = y, col = obs)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE)

Note the group means:

multivariate_data %>%
  group_by(obs) %>%
  summarize(mean_y = mean(y))
#> `summarise()` ungrouping output (override with `.groups` argument)
#> # A tibble: 3 x 2
#>   obs   mean_y
#>   <chr>  <dbl>
#> 1 y1      10.2
#> 2 y2      40.1
#> 3 y3      70.3

Fit multivariate regression model & get (incorrect) posterior fitted values

# Convert data to wide format
multivariate_data_wide <- multivariate_data %>%
  pivot_wider(names_from = obs, values_from = y)

# Fit model
stanmvreg_model <- stan_mvmer(
  formula = list(
    y1 ~ x + (1|group),
    y2 ~ x + (1|group),
    y3 ~ x + (1|group)
  ),
  data = multivariate_data_wide,
  seed = 76,
  chains = 1,
  iter = 2000
)

# Get posterior means
multivariate_data %>%
  add_predicted_draws(stanmvreg_model) %>%
  group_by(obs) %>%
  summarize(mean_y = mean(y), mean_y_hat = mean(.prediction))
#> `summarise()` ungrouping output (override with `.groups` argument)
#> # A tibble: 3 x 3
#>   obs   mean_y mean_y_hat
#>   <chr>  <dbl>      <dbl>
#> 1 y1      10.2       10.2
#> 2 y2      40.1       10.2
#> 3 y3      70.3       10.2

As you can see, the posterior means are off for y2 and y3: both equal the y1 mean (about 10.2) instead of roughly 40 and 70.

Posterior fitted values using rstanarm package

Calling rstanarm::posterior_predict() directly (the function that add_predicted_draws() wraps) gives the same result: only the posterior mean for y1 is returned.

stanmvreg_model %>%
  posterior_predict() %>%
  apply(1, mean) %>%
  mean()
#> [1] 10.20104

However, the Usage section of ?rstanarm::posterior_predict shows an extra argument, m, for models of class stanmvreg. It selects which submodel (outcome) to predict and defaults to m = 1. If you specify m = 1, 2, 3 for the (y1, y2, y3) multivariate outcome we have, you get the correct posterior means:

stanmvreg_model %>%
  posterior_predict(m = 1) %>%
  apply(1, mean) %>%
  mean()
#> [1] 10.20179
stanmvreg_model %>%
  posterior_predict(m = 2) %>%
  apply(1, mean) %>%
  mean()
#> [1] 40.1348
stanmvreg_model %>%
  posterior_predict(m = 3) %>%
  apply(1, mean) %>%
  mean()
#> [1] 70.29842

Attempting to pass an m argument to add_predicted_draws() does throw an error, but one that complains about a model of type "numeric":

multivariate_data %>%
  add_predicted_draws(stanmvreg_model, m = 2)
#> Error in predicted_draws.default(model, newdata, prediction, ..., n = n, : Models of type "numeric" are not currently supported by `predicted_draws`.
#> You might try using `add_draws()` for models that do not have explicit fit/prediction
#> support; see help("add_draws") for an example. See also help("tidybayes-models") for
#> more information on what functions are supported by what model types.

sessionInfo()

R version 4.0.1 (2020-06-06)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Catalina 10.15.6

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] tidybayes_2.1.1.9000 rstanarm_2.21.1      Rcpp_1.0.5           forcats_0.5.0       
 [5] stringr_1.4.0        dplyr_1.0.2          purrr_0.3.4          readr_1.3.1         
 [9] tidyr_1.1.2          tibble_3.0.3         ggplot2_3.3.2        tidyverse_1.3.0     

loaded via a namespace (and not attached):
  [1] minqa_1.2.4          colorspace_1.4-1     ellipsis_0.3.1       ggridges_0.5.2      
  [5] rsconnect_0.8.16     markdown_1.1         base64enc_0.1-3      fs_1.5.0            
  [9] rstudioapi_0.11      farver_2.0.3         rstan_2.21.2         svUnit_1.0.3        
 [13] DT_0.15              fansi_0.4.1          lubridate_1.7.9      xml2_1.3.2          
 [17] splines_4.0.1        codetools_0.2-16     knitr_1.30           shinythemes_1.1.2   
 [21] bayesplot_1.7.2      jsonlite_1.7.1       nloptr_1.2.2.2       packrat_0.5.0       
 [25] broom_0.7.0          dbplyr_1.4.4         ggdist_2.2.0         shiny_1.5.0         
 [29] clipr_0.7.0          compiler_4.0.1       httr_1.4.2           backports_1.1.10    
 [33] assertthat_0.2.1     Matrix_1.2-18        fastmap_1.0.1        cli_2.0.2           
 [37] later_1.1.0.1        htmltools_0.5.0      prettyunits_1.1.1    tools_4.0.1         
 [41] igraph_1.2.5         coda_0.19-3          gtable_0.3.0         glue_1.4.2          
 [45] reshape2_1.4.4       V8_3.2.0             cellranger_1.1.0     vctrs_0.3.4         
 [49] nlme_3.1-149         crosstalk_1.1.0.1    xfun_0.17            ps_1.3.4            
 [53] lme4_1.1-23          rvest_0.3.6          mime_0.9             miniUI_0.1.1.1      
 [57] lifecycle_0.2.0      gtools_3.8.2         statmod_1.4.34       MASS_7.3-53         
 [61] zoo_1.8-8            scales_1.1.1         colourpicker_1.1.0   hms_0.5.3           
 [65] promises_1.1.1       parallel_4.0.1       inline_0.3.16        shinystan_2.5.0     
 [69] yaml_2.2.1           curl_4.3             gridExtra_2.3        loo_2.3.1           
 [73] StanHeaders_2.21.0-6 stringi_1.5.3        dygraphs_1.1.1.6     boot_1.3-25         
 [77] pkgbuild_1.1.0       rlang_0.4.7          pkgconfig_2.0.3      matrixStats_0.57.0  
 [81] distributional_0.2.0 evaluate_0.14        lattice_0.20-41      labeling_0.3        
 [85] rstantools_2.1.1     htmlwidgets_1.5.1    tidyselect_1.1.0     processx_3.4.4      
 [89] plyr_1.8.6           magrittr_1.5         R6_2.4.1             generics_0.0.2      
 [93] DBI_1.1.0            whisker_0.4          mgcv_1.8-33          pillar_1.4.6        
 [97] haven_2.3.1          withr_2.3.0          xts_0.12.1           survival_3.2-7      
[101] modelr_0.1.8         crayon_1.3.4         arrayhelpers_1.1-0   utf8_1.1.4          
[105] rmarkdown_2.3        grid_4.0.1           readxl_1.3.1         blob_1.2.1          
[109] callr_3.4.4          threejs_0.3.3        reprex_0.3.0         digest_0.6.25       
[113] xtable_1.8-4         httpuv_1.5.4         RcppParallel_5.0.2   stats4_4.0.1        
[117] munsell_0.5.0        shinyjs_2.0.0     
mjskay (Owner) commented Oct 10, 2020

Thanks for this! Yes, I'd definitely like to support this model type, and a PR would be welcome.

The ideal solution I think would be to pattern the solution after how tidybayes handles multivariate models for brms, using the (slightly poorly named, for historical reasons) category argument. In the best case scenario, if it is possible to determine the names of the different y variables from the model object itself, the default approach should be to generate predictions from all of the response variables and return a tidy format dataframe with all response variables and a .category column indicating which y variable each row comes from. Does that make sense?
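A minimal base-R sketch of the tidy long format described above, assuming each response's predictions arrive as a draws-by-observations matrix (the shape rstanarm::posterior_predict() returns for a single m). The helper stack_predictions() and the toy matrices are hypothetical illustrations, not tidybayes internals:

```r
# Hypothetical helper: stack a named list of per-response prediction
# matrices (draws x observations) into one long data frame whose
# .category column records which response each row came from.
stack_predictions <- function(pred_list) {
  pieces <- lapply(names(pred_list), function(resp) {
    pred <- pred_list[[resp]]
    data.frame(
      # as.vector() is column-major: draws vary fastest within each observation
      .row = rep(seq_len(ncol(pred)), each = nrow(pred)),
      .draw = rep(seq_len(nrow(pred)), times = ncol(pred)),
      .category = resp,
      .prediction = as.vector(pred),
      stringsAsFactors = FALSE
    )
  })
  do.call(rbind, pieces)
}

# Toy stand-ins for posterior_predict(model, m = 1), m = 2, m = 3
# (4 draws x 3 observations each); values are simulated, not model output.
set.seed(1)
toy <- list(
  y1 = matrix(rnorm(4 * 3, mean = 10), nrow = 4),
  y2 = matrix(rnorm(4 * 3, mean = 40), nrow = 4),
  y3 = matrix(rnorm(4 * 3, mean = 70), nrow = 4)
)
long <- stack_predictions(toy)
nrow(long)  # 36 = 3 responses x 4 draws x 3 observations
tapply(long$.prediction, long$.category, mean)
```

Keeping the response label in its own column lets grouped summaries (e.g. by .category) work the same way for univariate and multivariate models.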

rudeboybert (Author) commented:

Yes, your comment makes sense and thanks for the pointers. I'll take a stab at it in the next few days using brms as a template.

mjskay (Owner) commented Oct 16, 2020

Sweet, thanks! It would be much appreciated. Let me know if you need any pointers about the internals, the fitted/predicted_draws stuff for brms is particularly hairy.

mjskay added this to the Next release milestone Jun 29, 2021
mjskay (Owner) commented Jul 7, 2021

Note to self: pretty sure the m parameter doesn't work because of partial argument matching. Should be able to fix this in the next round of iterations for predicted_draws.
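The partial-matching explanation can be checked in base R. Named arguments are matched exactly first, then by unique prefix against formals declared before ..., and only then positionally. So if a wrapper's formals begin (newdata, model, prediction, ...), which the "numeric" error message above suggests (an assumption about the tidybayes signature at the time), m = 2 binds to model and the model object slides into another slot. predict_wrapper() and predict_wrapper2() are hypothetical stand-ins:

```r
# Hypothetical wrapper whose formals mimic (newdata, model, prediction, ...)
predict_wrapper <- function(newdata, model, prediction = ".prediction", ...) {
  list(model_class = class(model)[1], dots = list(...))
}

# `m = 2` is a unique prefix of `model`, so it binds there; "my_fit" then
# falls through positionally to `prediction`, and nothing reaches `...`.
res <- predict_wrapper(data.frame(x = 1), "my_fit", m = 2)
res$model_class  # "numeric"
length(res$dots) # 0

# Hypothetical variant: no formal before `...` starts with "m",
# so `m = 2` is collected by `...` and can be forwarded downstream.
predict_wrapper2 <- function(newdata, object, prediction = ".prediction", ...) {
  list(object_class = class(object)[1], dots = list(...))
}
res2 <- predict_wrapper2(data.frame(x = 1), "my_fit", m = 2)
res2$object_class # "character"
res2$dots$m       # 2
```

Renaming the formal is one workaround; declaring optional arguments after ..., where only exact names match, is a sturdier one. Neither is necessarily what tidybayes adopted. Setting options(warnPartialMatchArgs = TRUE) makes R warn whenever this kind of prefix match happens.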

mjskay (Owner) commented Jul 9, 2021

This should now be fixed in the GitHub version: you can now pass the m parameter through properly. So you can do something like this (note that obs should not be in the prediction grid):

grid = multivariate_data %>%
  modelr::data_grid(x, group) 

preds = bind_rows(
  add_predicted_draws(mutate(grid, obs = "y1"), stanmvreg_model, m = 1),
  add_predicted_draws(mutate(grid, obs = "y2"), stanmvreg_model, m = 2),
  add_predicted_draws(mutate(grid, obs = "y3"), stanmvreg_model, m = 3)
)

preds %>%
  median_qi()
# # A tibble: 600 x 10
#        x group obs    .row .prediction .lower .upper .width .point .interval
#    <int> <int> <chr> <int>       <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
#  1     1     1 y1        1       0.616  -1.48   2.68   0.95 median qi       
#  2     1     1 y2        1      30.4    28.4   32.3    0.95 median qi       
#  3     1     1 y3        1      60.3    58.3   62.5    0.95 median qi       
#  4     1     2 y1        2       0.390  -1.61   2.60   0.95 median qi       
#  5     1     2 y2        2      30.4    28.4   32.3    0.95 median qi       
#  6     1     2 y3        2      60.4    58.4   62.2    0.95 median qi       
#  7     1     3 y1        3       0.452  -1.82   2.48   0.95 median qi       
#  8     1     3 y2        3      30.4    28.4   32.4    0.95 median qi       
#  9     1     3 y3        3      60.6    58.5   62.4    0.95 median qi       
# 10     1     4 y1        4       0.566  -1.53   2.64   0.95 median qi       
# # ... with 590 more rows

With the new rvar-based workflow that will be coming in the next version (using the rvar datatype from {posterior} when it hits CRAN, which should be soon), you will also be able to easily create rvar columns of predictions using add_predicted_rvars() instead of add_predicted_draws():

multivariate_data %>%
  modelr::data_grid(x, group) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y1", m = 1) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y2", m = 2) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y3", m = 3)
# # A tibble: 200 x 5
#        x group          y1         y2        y3
#    <int> <int>      <rvar>     <rvar>    <rvar>
#  1     1     1  0.57 ± 1.0  30 ± 1.00  60 ± 1.0
#  2     1     2  0.48 ± 1.1  30 ± 1.04  60 ± 1.0
#  3     1     3  0.47 ± 1.1  30 ± 1.03  61 ± 1.0
#  4     1     4  0.52 ± 1.1  30 ± 0.99  60 ± 1.0
#  5     2     1  0.98 ± 1.0  31 ± 1.03  61 ± 1.0
#  6     2     2  0.88 ± 1.1  31 ± 1.03  61 ± 1.0
#  7     2     3  0.87 ± 1.1  31 ± 1.02  61 ± 1.0
#  8     2     4  0.91 ± 1.1  31 ± 1.03  61 ± 1.0
#  9     3     1  1.33 ± 1.1  31 ± 1.02  61 ± 1.1
# 10     3     2  1.29 ± 1.1  31 ± 1.01  61 ± 1.0
# # ... with 190 more rows

For more on the rvar stuff you can check out vignette("rvar", package = "posterior") or vignette("tidy-posterior", package = "tidybayes") on the github versions of both packages.

mjskay closed this as completed Jul 9, 2021
rudeboybert (Author) commented:

Awesome! Thanks for doing this!
