posterior_predict() returning incorrect fitted values for rstanarm multivariate regression models #271
Comments
Thanks for this! Yes, I'd definitely like to support this model type, and a PR would be welcome. The ideal solution I think would be to pattern the solution after how tidybayes handles multivariate models for |
Yes, your comment makes sense and thanks for the pointers. I'll take a stab at it in the next few days using |
Sweet, thanks! It would be much appreciated. Let me know if you need any pointers about the internals; the fitted/predicted_draws stuff for brms is particularly hairy.
Note to self: pretty sure the |
This should now be fixed in the GitHub version: you can now pass the m argument, as in the add_predicted_draws() calls here:
grid = multivariate_data %>%
  modelr::data_grid(x, group)
preds = bind_rows(
  add_predicted_draws(mutate(grid, obs = "y1"), stanmvreg_model, m = 1),
  add_predicted_draws(mutate(grid, obs = "y2"), stanmvreg_model, m = 2),
  add_predicted_draws(mutate(grid, obs = "y3"), stanmvreg_model, m = 3)
)
preds %>%
  median_qi()
# # A tibble: 600 x 10
# x group obs .row .prediction .lower .upper .width .point .interval
# <int> <int> <chr> <int> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
# 1 1 1 y1 1 0.616 -1.48 2.68 0.95 median qi
# 2 1 1 y2 1 30.4 28.4 32.3 0.95 median qi
# 3 1 1 y3 1 60.3 58.3 62.5 0.95 median qi
# 4 1 2 y1 2 0.390 -1.61 2.60 0.95 median qi
# 5 1 2 y2 2 30.4 28.4 32.3 0.95 median qi
# 6 1 2 y3 2 60.4 58.4 62.2 0.95 median qi
# 7 1 3 y1 3 0.452 -1.82 2.48 0.95 median qi
# 8 1 3 y2 3 30.4 28.4 32.4 0.95 median qi
# 9 1 3 y3 3 60.6 58.5 62.4 0.95 median qi
# 10 1 4 y1 4 0.566 -1.53 2.64 0.95 median qi
# # ... with 590 more rows
With the new rvar-based workflow that will be coming in the next version (using the rvar datatype from {posterior} when it hits CRAN, which should be soon), you will also be able to easily create rvar columns of predictions using:
multivariate_data %>%
  modelr::data_grid(x, group) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y1", m = 1) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y2", m = 2) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y3", m = 3)
# # A tibble: 200 x 5
# x group y1 y2 y3
# <int> <int> <rvar> <rvar> <rvar>
# 1 1 1 0.57 ± 1.0 30 ± 1.00 60 ± 1.0
# 2 1 2 0.48 ± 1.1 30 ± 1.04 60 ± 1.0
# 3 1 3 0.47 ± 1.1 30 ± 1.03 61 ± 1.0
# 4 1 4 0.52 ± 1.1 30 ± 0.99 60 ± 1.0
# 5 2 1 0.98 ± 1.0 31 ± 1.03 61 ± 1.0
# 6 2 2 0.88 ± 1.1 31 ± 1.03 61 ± 1.0
# 7 2 3 0.87 ± 1.1 31 ± 1.02 61 ± 1.0
# 8 2 4 0.91 ± 1.1 31 ± 1.03 61 ± 1.0
# 9 3 1 1.33 ± 1.1 31 ± 1.02 61 ± 1.1
# 10 3 2 1.29 ± 1.1 31 ± 1.01 61 ± 1.0
# # ... with 190 more rows
For more on the rvar stuff you can check out |
Awesome! Thanks for doing this!
Hello, thank you for a wonderful package. If only it existed while I was in grad school!
add_predicted_draws() seems to be returning incorrect fitted values for rstanarm::stan_mvmer() multivariate GLM models. I've narrowed down the issue to a missing m argument that's passed to rstanarm::posterior_predict() for multivariate stanmvreg models, but not for regular stanreg models.
I'm happy to take a stab at a PR, but first wanted to check in and see if you wanted to go down the stanmvreg rabbit hole, or if you'd prefer throwing a "we don't support these types of models" warning like you do here for "ulam", "quap", "map", "map2stan" models.
Here is a reprex:
Create 3-dim outcome multivariate example data set
Note the group means:
Fit multivariate regression model & get (incorrect) posterior fitted values
As you can see, the posterior means are off for y2 and y3.
Posterior fitted values using rstanarm package
Using the root rstanarm::posterior_predict() function that's being wrapped by add_predicted_draws(), it seems that, again, only the posterior mean for y1 is returned correctly.
However, looking at the help file ?rstanarm::posterior_predict -> Usage, there is an extra argument "m" needed for models of class stanmvreg that defaults to m = 1. If you specify m = 1, 2, 3 for the (y1, y2, y3) multivariate outcome we have, we get the correct posterior means:
Attempting to pass an m argument to add_predicted_draws() does throw an error, but for a model of type "numeric"
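As a rough sketch of the per-outcome calls described above (m = 1, 2, 3 for y1, y2, y3), something like the following loops over the submodels and calls posterior_predict() once per outcome. The helper name predict_each_outcome, its n_outcomes and predict_fun arguments, and the y1, y2, ... result names are hypothetical and only for illustration; the only rstanarm behaviour relied on is that posterior_predict() accepts m for stanmvreg models.

```r
# Hypothetical helper (not part of tidybayes or rstanarm): draw posterior
# predictions for each outcome of a multivariate model by varying `m`.
predict_each_outcome <- function(model, n_outcomes, newdata = NULL,
                                 predict_fun = rstanarm::posterior_predict) {
  draws <- lapply(seq_len(n_outcomes), function(m) {
    # For stanmvreg models, `m` selects the submodel (outcome); it defaults
    # to 1, which is why only y1 comes back correct when `m` is omitted.
    predict_fun(model, m = m, newdata = newdata)
  })
  # Assumed naming scheme y1, y2, ... to match the outcomes in this issue
  names(draws) <- paste0("y", seq_len(n_outcomes))
  draws
}

# Usage, assuming `stanmvreg_model` is a fitted stan_mvmer() model
# with three outcomes:
# draws <- predict_each_outcome(stanmvreg_model, n_outcomes = 3)
# sapply(draws, mean)  # one overall posterior predictive mean per outcome
```

The predict_fun argument is there only so the loop can be tried with a stand-in function without fitting a model; in practice the default would be used.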
sessionInfo()