
Commit

Feature/swap out xgboost (#16)
* removed xgboost from TMLE

* updated comments and docs

* updated more docs
ronikobrosly authored Aug 3, 2020
1 parent a91f981 commit ba94fe1
Showing 9 changed files with 20 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -13,7 +13,7 @@ env:

  before_install:
    # Here we download miniconda and install the dependencies
-   - pip install black coverage future joblib numpy numpydoc pandas patsy progressbar2 pygam pytest python-dateutil python-utils pytz scikit-learn scipy six statsmodels xgboost
+   - pip install black coverage future joblib numpy numpydoc pandas patsy progressbar2 pygam pytest python-dateutil python-utils pytz scikit-learn scipy six statsmodels

  install:
    - python setup.py install
35 changes: 8 additions & 27 deletions causal_curve/tmle.py
@@ -9,8 +9,8 @@
  from pandas.api.types import is_float_dtype, is_numeric_dtype
  from scipy.interpolate import interp1d
  from scipy.stats import norm
+ from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
  from statsmodels.genmod.generalized_linear_model import GLM
- from xgboost import XGBClassifier, XGBRegressor

  from causal_curve.core import Core
  from causal_curve.utils import rand_seed_wrapper
@@ -19,7 +19,7 @@
  class TMLE(Core):
      """
      Constructs a causal dose response curve through a series of TMLE comparisons across a grid
-     of the treatment values. XGBoost is used for prediction in Q model and G model.
+     of the treatment values. Gradient boosting is used for prediction in Q model and G model.
      Assumes continuous treatment and outcome variable.
      WARNING:
@@ -47,18 +47,14 @@ class TMLE(Core):
      treatment values between the bin edges will be used to generate the CDRC.
      n_estimators: int, optional (default = 100)
-         Optional argument to set the number of learners to use when XGBoost
+         Optional argument to set the number of learners to use when sklearn
          creates TMLE's Q and G models.
      learning_rate: float, optional (default = 0.1)
-         Optional argument to set the XGBoost's learning rate for TMLE's Q and G models.
+         Optional argument to set the sklearn's learning rate for TMLE's Q and G models.
      max_depth: int, optional (default = 5)
-         Optional argument to set XGBoost's maximum depth when creating TMLE's Q and G models.
-     gamma: float, optional (default = 1.0)
-         Optional argument to set XGBoost's gamma parameter (regularization) when
-         creating TMLE's Q and G models.
+         Optional argument to set sklearn's maximum depth when creating TMLE's Q and G models.
      random_seed: int, optional (default = None)
          Sets the random seed.
@@ -115,7 +111,6 @@ def __init__(
          n_estimators=100,
          learning_rate=0.1,
          max_depth=5,
-         gamma=1.0,
          random_seed=None,
          verbose=False,
      ):
@@ -124,7 +119,6 @@
          self.n_estimators = n_estimators
          self.learning_rate = learning_rate
          self.max_depth = max_depth
-         self.gamma = gamma
          self.random_seed = random_seed
          self.verbose = verbose

@@ -190,16 +184,6 @@ def _validate_init_params(self):
          if self.max_depth <= 0:
              raise TypeError("max_depth parameter must be greater than 0")

-         # Checks for gamma
-         if not isinstance(self.gamma, float):
-             raise TypeError(
-                 f"gamma parameter must be a float, "
-                 f"but found type {type(self.gamma)}"
-             )
-
-         if self.gamma <= 0:
-             raise TypeError("gamma parameter must be greater than 0")
-
          # Checks for random_seed
          if not isinstance(self.random_seed, (int, type(None))):
              raise TypeError(
@@ -263,11 +247,10 @@ def _initial_bucket_mean_prediction(self):
              self.t_data < self.treatment_grid_bins[1]
          ]

-         init_model = XGBRegressor(
+         init_model = GradientBoostingRegressor(
              n_estimators=self.n_estimators,
              max_depth=self.max_depth,
              learning_rate=self.learning_rate,
-             gamma=self.gamma,
              random_state=self.random_seed,
          ).fit(X, y)

@@ -495,11 +478,10 @@ def _q_model(self, temp_y, temp_x, temp_t):
          X = pd.concat([temp_t, temp_x], axis=1).to_numpy()
          y = temp_y.to_numpy()

-         Q_model = XGBRegressor(
+         Q_model = GradientBoostingRegressor(
              n_estimators=self.n_estimators,
              max_depth=self.max_depth,
              learning_rate=self.learning_rate,
-             gamma=self.gamma,
              random_state=self.random_seed,
          ).fit(X, y)

@@ -525,11 +507,10 @@ def _g_model(self, temp_x, temp_t):
          X = temp_x.to_numpy()
          t = temp_t.to_numpy()

-         G_model = XGBClassifier(
+         G_model = GradientBoostingClassifier(
              n_estimators=self.n_estimators,
              max_depth=self.max_depth,
              learning_rate=self.learning_rate,
-             gamma=self.gamma,
              random_state=self.random_seed,
          ).fit(X, t)

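For context, a minimal sketch (not part of the diff) of the estimator swap made above. The shared keyword arguments (n_estimators, max_depth, learning_rate, random_state) carry over one-to-one between the two libraries; XGBoost's gamma (the minimum loss reduction required to make a split, a regularization term) has no direct counterpart in sklearn's gradient boosting, which is why the commit drops it rather than translating it. Data and variable names below are illustrative only:

    import numpy as np
    from sklearn.ensemble import GradientBoostingRegressor

    # Toy data, illustrative only
    rng = np.random.default_rng(42)
    X = rng.normal(size=(200, 3))
    y = 2.0 * X[:, 0] + rng.normal(size=200)

    # Before: XGBRegressor(n_estimators=100, max_depth=5, learning_rate=0.1,
    #                      gamma=1.0, random_state=42)
    # After: same shared kwargs; gamma has no sklearn equivalent and is dropped
    model = GradientBoostingRegressor(
        n_estimators=100,    # number of boosting stages
        max_depth=5,         # depth of each individual regression tree
        learning_rate=0.1,   # shrinkage applied to each tree's contribution
        random_state=42,     # reproducibility, mirrors random_seed above
    ).fit(X, y)

    print(model.predict(X[:5]))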
4 changes: 2 additions & 2 deletions docs/TMLE_example.rst
@@ -13,7 +13,7 @@ tool to estimate the marginal causal curve of some continuous treatment on a con
  accounting for some mild confounding effects.

  Compared with the package's GPS method, this TMLE method is double robust against model
- misspecification, incorporates more powerful machine learning techniques internally,
+ misspecification, incorporates more powerful machine learning techniques internally (gradient boosting),
  produces significantly smaller confidence intervals, however it is not computationally efficient
  and will take longer to run.

@@ -39,7 +39,7 @@ References
  ----------

  van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: U.C. Berkeley Division of
- Biostatistics Working Paper Series, 2006.
+ Biostatistics Working Paper Series, 2006.

  van der Laan MJ and Gruber S. Collaborative double robust penalized targeted
  maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010.
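The docs above describe fitting the TMLE tool to a continuous treatment and outcome. A hypothetical usage sketch follows; the constructor arguments come from the diff above, but the import path, fit signature, and calculate_CDRC method name are assumptions, not confirmed by this page:

    import numpy as np
    import pandas as pd
    from causal_curve import TMLE  # assumed import path

    # Toy data: one confounder (x1) drives both treatment and outcome
    rng = np.random.default_rng(1)
    n = 500
    x1 = rng.normal(size=n)
    x2 = rng.normal(size=n)
    treatment = x1 + rng.normal(size=n)
    outcome = 0.5 * treatment + x2 + rng.normal(size=n)
    df = pd.DataFrame({"x1": x1, "x2": x2, "treatment": treatment, "outcome": outcome})

    tmle = TMLE(n_estimators=100, learning_rate=0.1, max_depth=5, random_seed=42)
    tmle.fit(T=df["treatment"], X=df[["x1", "x2"]], y=df["outcome"])  # assumed signature
    cdrc = tmle.calculate_CDRC(0.95)  # assumed method; returns the estimated dose-response curve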
6 changes: 6 additions & 0 deletions docs/changelog.rst
@@ -4,6 +4,12 @@
  Change Log
  ==========

+ Version 0.3.4
+ -------------
+ - Removed XGBoost as dependency.
+ - Now using sklearn's gradient boosting implementation.
+
+
  Version 0.3.3
  -------------
  - Misc edits to paper and bibliography
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -22,7 +22,7 @@
  author = 'Roni Kobrosly'

  # The full version, including alpha/beta/rc tags
- release = '0.3.3'
+ release = '0.3.4'

  # -- General configuration ---------------------------------------------------

1 change: 0 additions & 1 deletion docs/install.rst
@@ -27,7 +27,6 @@ causal-curve requires:
  - scipy
  - six
  - statsmodels
- - xgboost



1 change: 0 additions & 1 deletion docs/requirements.txt
@@ -16,4 +16,3 @@ scikit-learn
  scipy
  six
  statsmodels
- xgboost
1 change: 0 additions & 1 deletion requirements.txt
@@ -16,4 +16,3 @@ scikit-learn
  scipy
  six
  statsmodels
- xgboost
5 changes: 2 additions & 3 deletions setup.py
@@ -5,7 +5,7 @@

  setuptools.setup(
      name="causal-curve",
-     version="0.3.3",
+     version="0.3.4",
      author="Roni Kobrosly",
      author_email="[email protected]",
      description="A python library with tools to perform causal inference using \
@@ -38,7 +38,6 @@
          'scikit-learn',
          'scipy',
          'six',
-         'statsmodels',
-         'xgboost'
+         'statsmodels'
      ]
  )
