-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathoem_fb_big.cpp
258 lines (199 loc) · 7.31 KB
/
oem_fb_big.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#include "oem_big.h"
using Eigen::MatrixXf;
using Eigen::VectorXf;
using Eigen::MatrixXd;
using Eigen::VectorXd;
using Eigen::VectorXi;
using Eigen::ArrayXf;
using Eigen::ArrayXd;
using Eigen::ArrayXXf;
using Eigen::Map;
using Rcpp::wrap;
using Rcpp::as;
using Rcpp::List;
using Rcpp::Named;
using Rcpp::IntegerVector;
using Rcpp::CharacterVector;
// Shorthand names for the Eigen types used in this translation unit.
// The Map* aliases view memory owned elsewhere; the Sp* aliases are sparse
// double-precision containers.
typedef Eigen::Map<Eigen::VectorXd>     MapVecd;
typedef Eigen::Map<Eigen::VectorXi>     MapVeci;
typedef Eigen::Map<Eigen::MatrixXd>     MapMatd;
typedef Eigen::SparseVector<double>     SpVec;
typedef Eigen::SparseMatrix<double>     SpMat;
// we need a separate one for FileBackedBigMatrix objects
//
// oem_fit_fb_big -- fit a sequence of penalized regression models with the
// oemBig solver on a bigmemory FileBackedBigMatrix: one fit per penalty in
// `penalty_` and per value on that penalty's lambda path.
//
// Parameters (R objects passed through .Call):
//   x_              external pointer to a FileBackedBigMatrix; only
//                   matrix_type() == 8 is accepted and its storage is read
//                   as doubles
//   y_              numeric response vector
//   family_         model family; only "gaussian" is fit here
//   penalty_        character vector of penalty names, fit in turn. "ols" is
//                   fit at a single lambda; names containing ".net" get the
//                   generated lambda path divided by alpha_
//   weights_, groups_, unique_groups_, group_weights_
//                   forwarded to the oemBig constructor
//   lambda_         list of numeric vectors, one per penalty. If the list or
//                   its first entry is empty, a path of nlambda_ values is
//                   generated instead
//   nlambda_        length of the generated lambda path
//   lmin_ratio_     smallest generated lambda as a fraction of lmax
//   alpha_, gamma_, tau_   forwarded to solver.init()
//   penalty_factor_ per-coefficient penalty factors; a 0 is prepended for the
//                   intercept when intercept_ is TRUE
//   standardize_, intercept_   forwarded to the oemBig constructor
//   compute_loss_   if TRUE, store solver.get_loss() after each fit;
//                   otherwise loss entries stay at 1e99
//   opts_           list providing "maxit", "tol" and "gigs"
//
// Returns a list with
//   beta    per penalty: (p + 1) x nlambda coefficients (row 0 is the
//           intercept, set to 0 when intercept_ is FALSE); one column for "ols"
//   lambda  the lambda vector used for each penalty
//   niter   per penalty: values returned by solver.solve() for each fit
//   loss    per penalty: losses for each fit (see compute_loss_)
//   d       solver.get_d() after all fits
//
// Errors are thrown (and handed back to R through BEGIN_RCPP/END_RCPP) for a
// non-double matrix type, an unsupported family, or a user lambda vector that
// is shorter than the lambda path.
RcppExport SEXP oem_fit_fb_big(SEXP x_,
                               SEXP y_,
                               SEXP family_,
                               SEXP penalty_,
                               SEXP weights_,
                               SEXP groups_,
                               SEXP unique_groups_,
                               SEXP group_weights_,
                               SEXP lambda_,
                               SEXP nlambda_,
                               SEXP lmin_ratio_,
                               SEXP alpha_,
                               SEXP gamma_,
                               SEXP tau_,
                               SEXP penalty_factor_,
                               SEXP standardize_,
                               SEXP intercept_,
                               SEXP compute_loss_,
                               SEXP opts_)
{
    BEGIN_RCPP

    XPtr<FileBackedBigMatrix> bMPtr(x_);
    const int p = bMPtr->ncol();

    // only type code 8 is supported, because the storage is viewed as doubles
    unsigned int typedata = bMPtr->matrix_type();
    if (typedata != 8)
    {
        throw Rcpp::exception("type for provided big.matrix not available");
    }

    // Eigen::Map views the file-backed storage in place (no copy of X)
    const Map<MatrixXd> X(static_cast<double *>(bMPtr->matrix()),
                          bMPtr->nrow(), bMPtr->ncol());
    const Map<VectorXd> Y(as<Map<VectorXd> >(y_));
    const VectorXi groups(as<VectorXi>(groups_));
    const VectorXi unique_groups(as<VectorXi>(unique_groups_));

    // In glmnet, we minimize
    // 1/(2n) * ||y - X * beta||^2 + lambda * ||beta||_1
    // which is equivalent to minimizing
    // 1/2 * ||y - X * beta||^2 + n * lambda * ||beta||_1
    VectorXd weights(as<VectorXd>(weights_));
    VectorXd group_weights(as<VectorXd>(group_weights_));

    // one lambda vector per penalty; an empty first entry (or an empty list)
    // means the lambda path is generated below
    std::vector<VectorXd> lambda(as< std::vector<VectorXd> >(lambda_));
    VectorXd lambda_tmp;
    if (!lambda.empty())
    {
        lambda_tmp = lambda[0];
    }
    const int nl = as<int>(nlambda_);
    VectorXd lambda_base(nl);
    int nlambda = lambda_tmp.size();

    List opts(opts_);
    const int maxit = as<int>(opts["maxit"]);
    const double tol = as<double>(opts["tol"]);
    const double gigs = as<double>(opts["gigs"]);
    const double alpha = as<double>(alpha_);
    const double gamma = as<double>(gamma_);
    const double tau = as<double>(tau_);
    bool standardize = as<bool>(standardize_);
    bool intercept = as<bool>(intercept_);
    bool compute_loss = as<bool>(compute_loss_);
    CharacterVector family(as<CharacterVector>(family_));
    std::vector<std::string> penalty(as< std::vector<std::string> >(penalty_));
    VectorXd penalty_factor(as<VectorXd>(penalty_factor_));

    // lambda[pp] is read or written once per penalty below, so make sure a
    // slot exists for every penalty (added slots start out empty)
    if (lambda.size() < penalty.size())
    {
        lambda.resize(penalty.size());
    }

    if (intercept)
    {
        // dont penalize the intercept
        VectorXd penalty_factor_tmp(p + 1);
        penalty_factor_tmp << 0, penalty_factor;
        penalty_factor.swap(penalty_factor_tmp);
    }

    // reject unsupported families before building anything
    if (family(0) != "gaussian")
    {
        if (family(0) == "binomial")
        {
            throw std::invalid_argument("binomial not available for oem_fit_fb_big");
        }
        throw std::invalid_argument("family not available for oem_fit_fb_big");
    }

    // The solver is a scoped local, so it is destroyed on every exit path,
    // including exceptions thrown mid-fit (e.g. by Rcpp::checkUserInterrupt).
    // All calls go through the oemBase interface, as with the former pointer.
    oemBig fitter(X, Y, weights, groups, unique_groups,
                  group_weights, penalty_factor,
                  intercept, standardize, tol, gigs);
    oemBase<Eigen::VectorXd> &solver = fitter;

    // compute initial pieces of oem
    solver.init_oem();
    const double lmax = solver.compute_lambda_zero();

    bool provided_lambda = false;
    if (nlambda < 1)
    {
        // no lambdas supplied: nl values evenly spaced on the log scale,
        // from lmax down to lmin_ratio * lmax
        double lmin = as<double>(lmin_ratio_) * lmax;
        lambda_base.setLinSpaced(nl, std::log(lmax), std::log(lmin));
        lambda_base = lambda_base.array().exp();
        nlambda = lambda_base.size();
        lambda_tmp.resize(nlambda);
    } else
    {
        provided_lambda = true;
    }

    // coefficient storage, reused for every penalty
    MatrixXd beta(p + 1, nlambda);
    List beta_list(penalty.size());
    List iter_list(penalty.size());
    List loss_list(penalty.size());
    const int nlambda_store = nlambda;
    const std::string elasticnettxt(".net");

    for (std::size_t pp = 0; pp < penalty.size(); pp++)
    {
        // "ols" is fit at the first lambda only
        if (penalty[pp] == "ols")
        {
            nlambda = 1;
        }
        const bool is_net_pen = penalty[pp].find(elasticnettxt) != std::string::npos;

        if (provided_lambda)
        {
            lambda_tmp = lambda[pp];
            // the loop below reads lambda_tmp(0 .. nlambda-1)
            if (lambda_tmp.size() < nlambda)
            {
                throw std::invalid_argument("not enough lambda values provided for penalty " + penalty[pp]);
            }
        } else if (is_net_pen)
        {
            lambda_tmp = (lambda_base.array() / alpha).matrix();
        } else
        {
            lambda_tmp = lambda_base;
        }

        // Fresh per-penalty buffers. niter in particular must not be shared:
        // an Rcpp vector placed in a List is stored by reference, so a single
        // IntegerVector reused across penalties would make every iter_list
        // entry show the counts of the last penalty fitted.
        IntegerVector niter(nlambda);
        VectorXd loss(nlambda);
        loss.fill(1e99);

        for (int i = 0; i < nlambda; i++)
        {
            if (i % 3 == 0)
            {
                Rcpp::checkUserInterrupt();
            }
            const double ilambda = lambda_tmp(i);

            // the first lambda runs the full init(); later ones only init_warm()
            if (i == 0)
            {
                solver.init(ilambda, penalty[pp], alpha, gamma, tau);
            } else
            {
                solver.init_warm(ilambda);
            }
            niter[i] = solver.solve(maxit);

            VectorXd res = solver.get_beta();
            if (intercept)
            {
                beta.col(i) = res;
            } else
            {
                beta(0, i) = 0.0;
                beta.block(1, i, p, 1) = res;
            }

            if (compute_loss)
            {
                // get associated loss
                loss(i) = solver.get_loss();
            }
        } // end loop over lambda values

        lambda[pp] = lambda_tmp;
        if (penalty[pp] == "ols")
        {
            // reset to old nlambda
            nlambda = nlambda_store;
            beta_list(pp) = beta.col(0);
            iter_list(pp) = niter(0);
            loss_list(pp) = loss(0);
        } else
        {
            beta_list(pp) = beta;
            iter_list(pp) = niter;
            loss_list(pp) = loss;
        }
    } // end loop over penalties

    double d = solver.get_d();
    return List::create(Named("beta") = beta_list,
                        Named("lambda") = lambda,
                        Named("niter") = iter_list,
                        Named("loss") = loss_list,
                        Named("d") = d);
    END_RCPP
}