Skip to content

Commit

Permalink
fix initial xbcf hsk model and pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
socket778 committed Nov 17, 2022
1 parent 5e965c1 commit ff38943
Show file tree
Hide file tree
Showing 5 changed files with 533 additions and 119 deletions.
4 changes: 2 additions & 2 deletions src/XBCF_discrete_heterosk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Rcpp::List XBCF_discrete_heterosk_cpp(arma::mat y, arma::mat Z, arma::mat X_con,
}

// define model
XBCFDiscreteModel *model = new XBCFDiscreteModel(kap, s, tau_con, tau_mod, alpha_con, beta_con, alpha_mod, beta_mod, sampling_tau, tau_con_kap, tau_con_s, tau_mod_kap, tau_mod_s);
hskXBCFDiscreteModel *model = new hskXBCFDiscreteModel(kap, s, tau_con, tau_mod, alpha_con, beta_con, alpha_mod, beta_mod, sampling_tau, tau_con_kap, tau_con_s, tau_mod_kap, tau_mod_s);
model->setNoSplitPenalty(no_split_penalty);

// State settings
Expand All @@ -151,7 +151,7 @@ Rcpp::List XBCF_discrete_heterosk_cpp(arma::mat y, arma::mat Z, arma::mat X_con,
X_struct x_struct_mod(Xpointer_mod, &y_std, N, Xorder_std_mod, p_categorical_mod, p_continuous_mod, &initial_theta_mod, num_trees_mod);

////////////////////////////////////////////////////////////////
mcmc_loop_xbcf_discrete(Xorder_std_con, Xorder_std_mod, verbose, sigma0_draw_xinfo, sigma1_draw_xinfo, a_xinfo, b_xinfo, trees_con, trees_mod, no_split_penalty, state, model, x_struct_con, x_struct_mod);
mcmc_loop_xbcf_discrete_heteroskedastic(Xorder_std_con, Xorder_std_mod, verbose, sigma0_draw_xinfo, sigma1_draw_xinfo, a_xinfo, b_xinfo, trees_con, trees_mod, no_split_penalty, state, model, x_struct_con, x_struct_mod);

// R Objects to Return
Rcpp::NumericMatrix sigma0_draw(num_trees_con + num_trees_mod, num_sweeps); // save predictions of each tree
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc_loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ void mcmc_loop_xbcf_discrete_heteroskedastic(matrix<size_t> &Xorder_std_con,
vector<vector<tree>> &trees_mod,
double no_split_penalty,
State &state,
XBCFDiscreteModel *model,
hskXBCFDiscreteModel *model,
X_struct &x_struct_con,
X_struct &x_struct_mod);
2 changes: 1 addition & 1 deletion src/mcmc_loop_xbcf_hsk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void mcmc_loop_xbcf_discrete_heteroskedastic(matrix<size_t> &Xorder_std_con,
vector<vector<tree>> &trees_mod,
double no_split_penalty,
State &state,
XBCFDiscreteModel *model,
hskXBCFDiscreteModel *model,
X_struct &x_struct_con,
X_struct &x_struct_mod)
{
Expand Down
93 changes: 55 additions & 38 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -691,80 +691,97 @@ class logNormalModel : public Model
//////////////////////////////////////////////////////////////////////////////////////
// Heteroskedastic XBCF discrete (binary) Model
//////////////////////////////////////////////////////////////////////////////////////

// Model subclass for the heteroskedastic XBCF discrete (binary) setting. It carries
// prior settings for two forests, suffixed `con` and `mod` throughout.
// NOTE(review): `con` / `mod` presumably mean the prognostic and treatment forests
// (cf. predict_std's prognostic_xinfo / treatment_xinfo outputs) -- confirm.
// NOTE(review): this listing looks like a rendered diff hunk with the +/- markers
// stripped: several declarations appear twice (old form first, new form right
// after), so it is not compilable as shown. Confirm every note below against the
// real model.h before relying on it.
class hskXBCFDiscreteModel : public Model
{
public:
// NOTE(review): two initializers follow (2, then 4). The 4 agrees with the
// Model(1, 4) base-constructor calls further down, so 2 is presumably the
// pre-commit value.
size_t dim_suffstat = 2;
size_t dim_suffstat = 4;

// model prior
// prior on sigma
double kap;
double s;
// NOTE(review): tau_kap / tau_s are only assigned by the 8-argument constructor
// below; the 13-argument constructor assigns the per-forest tau_{con,mod}_{kap,s}
// fields instead. Presumably the unsuffixed pair is the pre-commit form.
double tau_kap;
double tau_s;
double tau_con_kap;
double tau_con_s;
double tau_mod_kap;
double tau_mod_s;
// prior on leaf parameter
// NOTE(review): tau / tau_mean are likewise only touched by the 8- and 5-argument
// constructors; tau_con / tau_mod and their *_mean copies are the per-forest form.
double tau; // might be updated if sampling tau
double tau_mean; // copy of the original value
double tau_con; // might be updated if sampling tau
double tau_mod;
double tau_con_mean; // copy of the original value
double tau_mod_mean;

// Per-forest alpha / beta values, stored verbatim from the 13-argument constructor.
double alpha_con;
double alpha_mod;
double beta_con;
double beta_mod;
bool sampling_tau;

// NOTE(review): 8-argument, single-forest constructor signature. Its opening line
// is immediately followed by the 13-argument signature and shares one brace pair
// with it, which is what an old/new diff pair looks like -- presumably removed.
hskXBCFDiscreteModel(double kap, double s, double tau, double alpha, double beta, bool sampling_tau, double tau_kap, double tau_s) : Model(1, 2)
// Two-forest constructor: copies every argument into the member of the same name,
// seeds tau_con_mean / tau_mod_mean with the initial tau_con / tau_mod, and calls
// the base constructor as Model(1, 4).
hskXBCFDiscreteModel(double kap, double s, double tau_con, double tau_mod, double alpha_con, double beta_con, double alpha_mod, double beta_mod, bool sampling_tau, double tau_con_kap, double tau_con_s, double tau_mod_kap, double tau_mod_s) : Model(1, 4)
{
this->kap = kap;
this->s = s;
// NOTE(review): the next seven assignments use names (tau_kap, tau_s, tau, alpha,
// beta) that are not parameters of the 13-argument signature; they match the
// 8-argument signature above and end with dim_residual = 3, which is set again
// to 1 further down.
this->tau_kap = tau_kap;
this->tau_s = tau_s;
this->tau = tau;
this->tau_mean = tau;
this->alpha = alpha;
this->beta = beta;
this->dim_residual = 3;
this->tau_con_kap = tau_con_kap;
this->tau_con_s = tau_con_s;
this->tau_mod_kap = tau_mod_kap;
this->tau_mod_s = tau_mod_s;
this->tau_con = tau_con;
this->tau_mod = tau_mod;
this->tau_con_mean = tau_con;
this->tau_mod_mean = tau_mod;
this->alpha_con = alpha_con;
this->alpha_mod = alpha_mod;
this->beta_con = beta_con;
this->beta_mod = beta_mod;
// alpha / beta are not declared in this class (presumably inherited from Model);
// they start out holding the `con` forest's values.
this->alpha = alpha_con;
this->beta = beta_con;
this->dim_residual = 1;
this->class_operating = 0;
this->sampling_tau = sampling_tau;
}

// NOTE(review): 5-argument, single-forest constructor. It writes only the
// unsuffixed tau / tau_mean fields, sets dim_residual = 3 and forces
// sampling_tau = true. Presumably a pre-commit overload -- confirm.
hskXBCFDiscreteModel(double kap, double s, double tau, double alpha, double beta) : Model(1, 2)
{
this->kap = kap;
this->s = s;
this->tau = tau;
this->tau_mean = tau;
this->alpha = alpha;
this->beta = beta;
this->dim_residual = 3;
this->class_operating = 0;
this->sampling_tau = true;
}

// Default constructor: forwards fixed arguments to Model and assigns no members of
// this class in its body.
// NOTE(review): shown twice, as Model(1, 2) and then Model(1, 4); only one can exist.
hskXBCFDiscreteModel() : Model(1, 2) {}
hskXBCFDiscreteModel() : Model(1, 4) {}

// Returns a copy of this object allocated with `new` via the copy constructor.
Model *clone() { return new hskXBCFDiscreteModel(*this); }

// redefined functions
// Everything below is declaration-only in this listing; no definitions are shown.
// NOTE(review): ini_residual_std, initialize_root_suffstat, update_tau_per_forest
// and predict_std each appear twice below, consistent with old/new diff lines.
void ini_residual_std(State &state);

void initialize_root_suffstat(State &state, std::vector<double> &suff_stat);

void incSuffStat(State &state, size_t index_next_obs, std::vector<double> &suffstats);

void samplePars(State &state, std::vector<double> &suff_stat, std::vector<double> &theta_vector, double &prob_leaf);

void update_tau_per_forest(State &state, size_t sweeps, vector<vector<tree>> & trees);
void update_state(State &state, size_t tree_ind, X_struct &x_struct, size_t ind);

void update_tau(State &state, size_t tree_ind, size_t sweeps, vector<vector<tree>> &trees);

void update_tau_per_forest(State &state, size_t sweeps, vector<vector<tree>> &trees);

void initialize_root_suffstat(State &state, std::vector<double> &suff_stat);

void updateNodeSuffStat(State &state, std::vector<double> &suff_stat, matrix<size_t> &Xorder_std, size_t &split_var, size_t row_ind);

void calculateOtherSideSuffStat(std::vector<double> &parent_suff_stat, std::vector<double> &lchild_suff_stat, std::vector<double> &rchild_suff_stat, size_t &N_parent, size_t &N_left, size_t &N_right, bool &compute_left_side);

// The commented-out line under this declaration is an alternative parameter order
// (State first) kept for reference; only the (tree_ind, M, state, x_struct) form is live.
void state_sweep(size_t tree_ind, size_t M, State &state, X_struct &x_struct) const;
// void state_sweep(State&state, size_t tree_ind, size_t M, X_struct &x_struct) const;

double likelihood(std::vector<double> &temp_suff_stat, std::vector<double> &suff_stat_all, size_t N_left, bool left_side, bool no_split, State &state) const;

// NOTE(review): single-forest predict_std signature (one X pointer, one tree set);
// the two-forest signature appears a few lines further down.
void predict_std(const double *Xtestpointer, size_t N_test, size_t p, size_t num_trees, size_t num_sweeps, matrix<double> &yhats_test_xinfo, vector<vector<tree>> &trees);
void ini_tau_mu_fit(State &state);

void switch_state_params(State &state);
void ini_residual_std(State &state);

void store_residual(State &state);
// Two-forest predict_std: takes separate test matrices, dimensions and tree sets
// for `con` and `mod`, plus three output matrices (yhats, prognostic, treatment).
void predict_std(matrix<double> &Ztestpointer, const double *Xtestpointer_con, const double *Xtestpointer_mod, size_t N_test, size_t p_con, size_t p_mod, size_t num_trees_con, size_t num_trees_mod, size_t num_sweeps, matrix<double> &yhats_test_xinfo, matrix<double> &prognostic_xinfo, matrix<double> &treatment_xinfo, vector<vector<tree>> &trees_con, vector<vector<tree>> &trees_mod);

void set_treatmentflag(State &state, bool value);

void subtract_old_tree_fit(size_t tree_ind, State &state, X_struct &x_struct);

void add_new_tree_fit(size_t tree_ind, State &state, X_struct &x_struct);

void update_partial_residuals(size_t tree_ind, State &state, X_struct &x_struct);

void update_split_counts(State &state, size_t tree_ind);

void update_a(State &state);

void update_b(State &state);
};

#endif
Loading

0 comments on commit ff38943

Please sign in to comment.