forked from paul-buerkner/brms
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path exclude_pars.R
113 lines (107 loc) · 3.11 KB
/
exclude_pars.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
exclude_pars <- function(bterms, data = NULL, ranef = empty_ranef(),
                         save_ranef = TRUE, save_mevars = FALSE,
                         save_all_pars = FALSE) {
  # list parameters NOT to be saved by Stan
  # Args:
  #   bterms: object of class brmsterms
  #   data: data passed by the user
  #   ranef: output of tidy_ranef
  #   save_ranef: should group-level effects be saved?
  #   save_mevars: should samples of noise-free variables be saved?
  #   save_all_pars: should all parameters be saved, including the
  #     auxiliary ones that are excluded by default?
  # Returns:
  #   a character vector of unique parameter names to be excluded,
  #   carrying 'save_ranef', 'save_mevars' and 'save_all_pars' as attributes
  save_ranef <- as_one_logical(save_ranef)
  save_mevars <- as_one_logical(save_mevars)
  save_all_pars <- as_one_logical(save_all_pars)
  out <- exclude_pars_internal(
    bterms, data = data, save_all_pars = save_all_pars,
    save_mevars = save_mevars
  )
  meef <- tidy_meef(bterms, data)
  if (has_rows(meef)) {
    # I: one index per unique 'grname' of meef; K: one index per row of meef
    I <- seq_along(unique(meef$grname))
    K <- seq_len(nrow(meef))
    # build each name family for every index in I; the former
    # paste0(c("Xme", "Corme_"), I) recycled the two prefixes along I
    # and silently skipped names as soon as there was more than one group
    c(out) <- c(paste0("Xme", I), paste0("Corme_", I))
    if (!save_all_pars) {
      c(out) <- c(paste0("zme_", K), paste0("Lme_", I))
    }
    if (!save_mevars) {
      c(out) <- paste0("Xme_", K)
    }
  }
  if (has_rows(ranef)) {
    rm_re_pars <- c(if (!save_all_pars) c("z", "L"), "Cor", "r")
    for (id in unique(ranef$id)) {
      c(out) <- paste0(rm_re_pars, "_", id)
    }
    if (!save_ranef) {
      p <- usc(combine_prefix(ranef))
      c(out) <- paste0("r_", ranef$id, p, "_", ranef$cn)
    }
    tranef <- get_dist_groups(ranef, "student")
    if (!save_all_pars && has_rows(tranef)) {
      c(out) <- paste0("udf_", tranef$ggn)
    }
  }
  att <- nlist(save_ranef, save_mevars, save_all_pars)
  do.call(structure, c(list(unique(out)), att))
}
# S3 generic returning the names of Stan parameters that should not be
# saved for the model object 'x'; methods dispatch on the class of 'x'
# and receive every further argument through '...'
exclude_pars_internal <- function(x, ...) UseMethod("exclude_pars_internal", x)
#' @export
exclude_pars_internal.default <- function(x, ...) {
  # objects without a dedicated method contribute nothing to exclude;
  # 'x' and '...' are intentionally ignored
  NULL
}
#' @export
exclude_pars_internal.mvbrmsterms <- function(x, save_all_pars, ...) {
  # parameters to exclude for a multivariate model
  # Args:
  #   x: object of class mvbrmsterms; each element of 'x$terms' is passed
  #     to exclude_pars_internal in turn
  #   save_all_pars: should all parameters be saved?
  #   ...: further arguments passed on for each element of 'x$terms'
  # Returns:
  #   a character vector of parameter names
  out <- c("Rescor", "Sigma")
  if (!save_all_pars) {
    out <- c(out, "Lrescor", "LSigma")
  }
  # pass 'save_all_pars' by name: the methods of exclude_pars_internal do
  # not share one argument order (e.g. the btl method takes 'data' second),
  # so positional forwarding could bind the flag to the wrong argument
  term_pars <- lapply(x$terms, function(term) {
    exclude_pars_internal(term, save_all_pars = save_all_pars, ...)
  })
  c(out, unlist(term_pars, use.names = FALSE))
}
#' @export
exclude_pars_internal.brmsterms <- function(x, save_all_pars, save_mevars, ...) {
  # parameters to exclude for a univariate model (part)
  # Args:
  #   x: object of class brmsterms
  #   save_all_pars: should all parameters be saved?
  #   save_mevars: should latent values of 'mi' responses be saved?
  #   ...: passed on for each element of 'x$dpars'
  # Returns:
  #   a character vector of parameter names
  prefix <- usc(combine_prefix(x))
  dpar_names <- names(x$dpars)
  excluded <- paste0(c("res_cov_matrix", dpar_names), prefix)
  if (!save_all_pars) {
    aux_pars <- c(
      paste0("temp", prefix, "_Intercept1"),
      paste0("ordered", prefix, "_Intercept"),
      paste0("theta", prefix),
      paste0("zcar", prefix)
    )
    # collect what each distributional parameter wants excluded
    dpar_pars <- lapply(dpar_names, function(dp) {
      exclude_pars_internal(x$dpars[[dp]], ...)
    })
    excluded <- c(excluded, aux_pars, unlist(dpar_pars))
  }
  drop_latent_y <- !save_mevars && is.formula(x$adforms$mi)
  if (drop_latent_y) {
    excluded <- c(excluded, paste0("Yl", prefix))
  }
  excluded
}
#' @export
exclude_pars_internal.btnl <- function(x, ...) {
  # parameters to exclude for a non-linear model formula: the union of
  # what each element of 'x$nlpars' excludes ('...' is forwarded as is)
  per_nlpar <- lapply(names(x$nlpars), function(nlp) {
    exclude_pars_internal(x$nlpars[[nlp]], ...)
  })
  unlist(per_nlpar)
}
#' @export
exclude_pars_internal.btl <- function(x, data, ...) {
  # parameters to exclude for a linear predictor term
  # Args:
  #   x: object of class btl
  #   data: data passed on to tidy_smef
  # Returns:
  #   a character vector of parameter names
  prefix <- usc(combine_prefix(x))
  smef <- tidy_smef(x, data)
  # one 'zs' name per basis function of every row of smef
  zs_names <- lapply(seq_len(nrow(smef)), function(row) {
    paste0("zs", prefix, "_", row, "_", seq_len(smef$nbases[row]))
  })
  c(
    paste0("temp", prefix, "_Intercept"),
    paste0(c("hs_local", "hs_global", "zb"), prefix),
    unlist(zs_names)
  )
}