Skip to content

Commit

Permalink
further improve read_csv_as_stanfit
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Mar 19, 2024
1 parent 85030d9 commit 05a0eb0
Showing 1 changed file with 29 additions and 40 deletions.
69 changes: 29 additions & 40 deletions R/backends.R
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ fit_model <- function(model, backend, ...) {
}

out <- read_csv_as_stanfit(
out$output_files(), variables = out$metadata()$variables,
out$output_files(), variables = out$metadata()$stan_variables,
model = model, exclude = exclude, algorithm = algorithm
)

Expand Down Expand Up @@ -691,8 +691,6 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N
variables <- repair_variable_names(variables)
variables <- unique(sub("\\[.+", "", variables))
variables <- setdiff(variables, exclude)
# cmdstanr deals with special variables inconsistently
# below is an attempt to deal with this somehow (part 1)
# temp fix for cmdstanr not recognizing the variable names it produces #1473
if (algorithm %in% c("meanfield", "fullrank")) {
variables <- ifelse(variables == "lp_approx__", "log_g__", variables)
Expand All @@ -713,41 +711,23 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N
model_name = gsub(".csv", "", basename(files[[1]]))

# @model_pars
svars <- variables %||% csfit$metadata$stan_variables
# cmdstanr deals with special variables inconsistently
# below is an attempt to deal with this somehow (part 2)
special_vars <- c("lp__", "lp_approx__", "log_g__")
special_vars <- intersect(special_vars, svars)
# by default just assume all special vars are present in draws
vars_in_draws <- svars
if ("post_warmup_draws" %in% names(csfit)) {
vars_in_draws <- variables(csfit$post_warmup_draws)
} else if ("draws" %in% names(csfit)) {
vars_in_draws <- variables(csfit$draws)
}
for (v in special_vars) {
if (v %in% vars_in_draws) {
# put special vars at the end
svars <- c(setdiff(svars, v), v)
} else {
# remove special vars as they do not seem to be stored in draws
svars <- setdiff(svars, v)
}
model_pars <- csfit$metadata$stan_variables
if (!is.null(variables)) {
model_pars <- intersect(model_pars, variables)
}
pars_oi <- svars
par_names <- csfit$metadata$model_params
# special variables will be added back later on
special_vars <- c("lp__", "lp_approx__")
model_pars <- setdiff(model_pars, special_vars)

# @par_dims
par_dims <- vector("list", length(svars))

names(par_dims) <- svars
par_dims <- lapply(par_dims, function(x) x <- integer(0))

pdims_num <- ulapply(
svars, function(x) sum(grepl(paste0("^", x, "\\[.*\\]$"), par_names))
par_dims <- vector("list", length(model_pars))
names(par_dims) <- model_pars
par_dims <- lapply(par_dims, function(x) integer(0))
pdims_num <- ulapply(model_pars, function(x)
sum(grepl(paste0("^", x, "\\[.*\\]$"), csfit$metadata$model_params))
)
par_dims[pdims_num != 0] <-
csfit$metadata$stan_variable_sizes[svars][pdims_num != 0]
csfit$metadata$stan_variable_sizes[model_pars][pdims_num != 0]

# @mode
mode <- 0L
Expand Down Expand Up @@ -787,8 +767,11 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N
csfit$post_warmup_draws <- NULL

# prepare sampler diagnostics
diagnostics <- rbind(csfit$warmup_sampler_diagnostics,
csfit$post_warmup_sampler_diagnostics)
diagnostics <- rbind(
csfit$warmup_sampler_diagnostics,
csfit$post_warmup_sampler_diagnostics
)

# manage memory
csfit$warmup_sampler_diagnostics <- NULL
csfit$post_warmup_sampler_diagnostics <- NULL
Expand Down Expand Up @@ -820,10 +803,16 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N
samples <- as.data.frame(samples)
chain_ids <- samples$.chain
samples[res_vars] <- NULL
if ("lp__" %in% colnames(samples)) {
samples <- move2end(samples, "lp__")
}

# only add special variables to dims if they are present in samples
# this ensures that dims_oi, pars_oi, and fnames_oi match with samples
for (p in special_vars) {
if (p %in% colnames(samples)) {
samples <- move2end(samples, p)
par_dims[[p]] <- integer(0)
}
}
model_pars <- names(par_dims)
fnames_oi <- colnames(samples)

# split samples into chains
Expand Down Expand Up @@ -913,7 +902,7 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N
n_save = rep(n_iter_sample + n_iter_warmup, n_chains),
warmup2 = rep(n_iter_warmup, n_chains),
permutation = perm_lst,
pars_oi = pars_oi,
pars_oi = model_pars,
dims_oi = par_dims,
fnames_oi = fnames_oi,
n_flatnames = length(fnames_oi)
Expand Down Expand Up @@ -1005,7 +994,7 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N
out <- new(
"stanfit",
model_name = model_name,
model_pars = svars,
model_pars = model_pars,
par_dims = par_dims,
mode = mode,
sim = sim,
Expand Down

0 comments on commit 05a0eb0

Please sign in to comment.