Skip to content

Commit

Permalink
updates to run.sh
Browse files Browse the repository at this point in the history
  • Loading branch information
avati committed Sep 14, 2019
1 parent fde5745 commit befed87
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
22 changes: 15 additions & 7 deletions examples/empirical/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from sklearn.model_selection import KFold

np.random.seed(123)
np.random.seed(1)

dataset_name_to_loader = {
"housing": lambda: pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data', header=None, delim_whitespace=True),
Expand Down Expand Up @@ -46,7 +46,7 @@
argparser = ArgumentParser()
argparser.add_argument("--dataset", type=str, default="concrete")
argparser.add_argument("--reps", type=int, default=5)
argparser.add_argument("--n-est", type=int, default=300)
argparser.add_argument("--n-est", type=int, default=200)
argparser.add_argument("--n-splits", type=int, default=20)
argparser.add_argument("--distn", type=str, default="Normal")
argparser.add_argument("--lr", type=float, default=0.1)
Expand Down Expand Up @@ -78,13 +78,15 @@
kf = KFold(n_splits=args.n_splits)
folds = kf.split(X)

breakpoint()
#breakpoint()

for itr, (train_index, test_index) in enumerate(folds):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]

y_true += list(y_test.flatten())

X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2)

ngb = NGBoost(Base=base_name_to_learner[args.base],
Dist=eval(args.distn),
Expand All @@ -96,11 +98,17 @@
verbose=args.verbose)

train_loss, val_loss = ngb.fit(X_train, y_train) #, X_val, y_val)
forecast = ngb.pred_dist(X_test)

y_preds = ngb.staged_predict(X_val)
val_rmse = [mean_squared_error(y_pred, y_val) for y_pred in y_preds]
best_itr = np.argmin(val_rmse) + 1
print('[%d] Best itr: %d (%.4f)' % (itr+1, best_itr, np.sqrt(val_rmse[best_itr-1])))

forecast = ngb.pred_dist(X_test, max_iter=best_itr)

y_ngb += list(forecast.loc)

if args.verbose or True:
if args.verbose:
print("[%d/%d] %s/%s RMSE=%.4f" % (itr+1, args.n_splits, args.score, args.distn,
np.sqrt(mean_squared_error(forecast.loc, y_test))))

Expand All @@ -116,8 +124,8 @@

y_gbm += list(y_pred.flatten())

if args.verbose or True:
print("[%d/%d] GBR RMSE=%.4f" % (itr+1, args.n_splits,
if args.verbose:
print("[%d/%d] GBM RMSE=%.4f" % (itr+1, args.n_splits,
np.sqrt(mean_squared_error(y_pred.flatten(), y_test.flatten()))))
gbrlog.tick(forecast, y_test)

Expand Down
24 changes: 22 additions & 2 deletions ngboost/ngboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def fit_base(self, X, grads):
self.base_models.append(models)
return fitted

def line_search(self, resids, start, Y):
def line_search(self, resids, start, Y, scale_init=1):
loss_init = self.loss_fn(start, Y).mean()
scale = 1
scale = scale_init
while True:
scaled_resids = resids * scale
loss = self.loss_fn(start - scaled_resids, Y).mean()
Expand Down Expand Up @@ -110,6 +110,9 @@ def fit(self, X, Y, X_val = None, Y_val = None):
if self.natural_gradient:
grads = self.Score.naturalize(P_batch, grads)

#scale = self.line_search(grads, P_batch, Y_batch, scale_init=1)
#grads = grads * scale

if np.any(np.isnan(grads)) or np.any(np.isinf(grads)):
print(grads)
grads = self.grad_fn(P_batch, Y_batch)
Expand Down Expand Up @@ -163,3 +166,20 @@ def pred_dist(self, X, max_iter=None):
params = onp.asarray(self.pred_param(X, max_iter))
dist = self.Dist(params.T)
return dist

def predict(self, X):
dists = self.pred_dist(X)
return list(dist.loc.flatten())

def staged_predict(self, X, max_iter=None):
predictions = []
m, n = X.shape
params = np.ones((m, self.Dist.n_params)) * self.init_params
for i, (models, s) in enumerate(zip(self.base_models, self.scalings)):
if max_iter and i == max_iter:
break
resids = np.array([model.predict(X) for model in models]).T
params -= self.learning_rate * resids * s
dists = self.Dist(onp.asarray(params).T)
predictions.append(dists.loc.flatten())
return predictions
15 changes: 12 additions & 3 deletions scripts/run_empirical_regression.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ python3 -m examples.empirical.regression --dataset=naval --score=MLE --natural
python3 -m examples.empirical.regression --dataset=power --score=MLE --natural
python3 -m examples.empirical.regression --dataset=energy --score=MLE --natural
python3 -m examples.empirical.regression --dataset=yacht --score=MLE --natural
python3 -m examples.empirical.regression --dataset=protein --score=MLE --natural --n_splits=5
python3 -m examples.empirical.regression --dataset=protein --score=MLE --natural --n-splits=5
python3 -m examples.empirical.regression --dataset=msd --score=MLE --natural

exit

python3 -m examples.empirical.regression --dataset=concrete --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=wine --score=CRPS --natural
Expand All @@ -20,7 +20,10 @@ python3 -m examples.empirical.regression --dataset=power --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=energy --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=yacht --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=housing --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=protein --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=protein --score=CRPS --natural --n-splits=5
python3 -m examples.empirical.regression --dataset=msd --score=CRPS --natural

exit

python3 -m examples.empirical.regression --dataset=concrete --score=MLE --distn=HomoskedasticNormal
python3 -m examples.empirical.regression --dataset=wine --score=MLE --distn=HomoskedasticNormal
Expand All @@ -30,6 +33,8 @@ python3 -m examples.empirical.regression --dataset=power --score=MLE --distn=Ho
python3 -m examples.empirical.regression --dataset=energy --score=MLE --distn=HomoskedasticNormal
python3 -m examples.empirical.regression --dataset=yacht --score=MLE --distn=HomoskedasticNormal
python3 -m examples.empirical.regression --dataset=housing --score=MLE --distn=HomoskedasticNormal
python3 -m examples.empirical.regression --dataset=msd --score=MLE --distn=HomoskedasticNormal
python3 -m examples.empirical.regression --dataset=protein --score=MLE --distn=HomoskedasticNormal --n-splits=5


python3 -m examples.empirical.regression --dataset=concrete --score=CRPS --natural --distn=Laplace
Expand All @@ -40,6 +45,8 @@ python3 -m examples.empirical.regression --dataset=power --score=CRPS --natural
python3 -m examples.empirical.regression --dataset=energy --score=CRPS --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=yacht --score=CRPS --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=housing --score=CRPS --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=protein --score=CRPS --natural --distn=Laplace --n-splits=5
python3 -m examples.empirical.regression --dataset=msd --score=CRPS --natural --distn=Laplace

python3 -m examples.empirical.regression --dataset=concrete --score=MLE --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=wine --score=MLE --natural --distn=Laplace
Expand All @@ -49,3 +56,5 @@ python3 -m examples.empirical.regression --dataset=power --score=MLE --natural
python3 -m examples.empirical.regression --dataset=energy --score=MLE --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=yacht --score=MLE --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=housing --score=MLE --natural --distn=Laplace
python3 -m examples.empirical.regression --dataset=protein --score=MLE --natural --distn=Laplace --n-splits=5
python3 -m examples.empirical.regression --dataset=msd --score=MLE --natural --distn=Laplace

0 comments on commit befed87

Please sign in to comment.