Skip to content

Commit

Permalink
Merge pull request CamDavidsonPilon#130 from CamDavidsonPilon/model_r…
Browse files Browse the repository at this point in the history
…outines

Model routines
  • Loading branch information
aprotopopov authored Aug 20, 2017
2 parents fea8416 + 6b60cb9 commit e8d5750
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 56 deletions.
10 changes: 9 additions & 1 deletion docs/Saving and loading model.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@ bgf

### Saving model

Model will be saved with [dill](https://github.com/uqfoundation/dill) to pickle object. Optional parameter `save_data` is used for saving data from model or not (default: `True`).
Model will be saved with [dill](https://github.com/uqfoundation/dill) to pickle object. Optional parameters `save_data` and `save_generate_data_method` are present to reduce final pickle object size for big dataframes.
Optional parameters:
- `save_data` is used for saving data from model or not (default: `True`).
- `save_generate_data_method` is used for saving `generate_new_data` method from model or not (default: `True`)

```python
bgf.save_model('bgf.pkl')
```

or to save only model with minumum size without `data` and `generate_new_data`:
```python
bgf.save_model('bgf_small_size.pkl', save_data=False, save_generate_data_method=False)
```

### Loading model

Before loading you should initialize the model first and then use method `load_model`
Expand Down
53 changes: 1 addition & 52 deletions lifetimes/fitters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,2 @@
"""Init for fitters."""

import dill


class BaseFitter(object):
"""Base class for fitters."""

def __repr__(self):
"""Representation of fitter."""
classname = self.__class__.__name__
try:
param_str = ", ".join("%s: %.2f" % (param, value) for param, value
in sorted(self.params_.items()))
return "<lifetimes.%s: fitted with %d subjects, %s>" % (
classname, self.data.shape[0], param_str)
except AttributeError:
return "<lifetimes.%s>" % classname

def _unload_params(self, *args):
if not hasattr(self, 'params_'):
raise ValueError("Model has not been fit yet. Please call the .fit"
" method first.")
return [self.params_[x] for x in args]

def save_model(self, path, save_data=True):
"""
Save model with dill package.
Parameters:
path: Path where to save model.
save_date: Whether to save data from fitter.data to pickle object
"""
with open(path, 'wb') as out_file:
if save_data:
dill.dump(self, out_file)
else:
self_data = self.data.copy()
self.data = []
dill.dump(self, out_file)
self.data = self_data

def load_model(self, path):
"""
Save model with dill package.
Parameters:
path: From what path load model.
"""
with open(path, 'rb') as in_file:
self.__dict__.update(dill.load(in_file).__dict__)
from .base_fitter import BaseFitter
66 changes: 66 additions & 0 deletions lifetimes/fitters/base_fitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Base fitter for other classes."""
import dill
from ..utils import _save_obj_without_attr


class BaseFitter(object):
"""Base class for fitters."""

def __repr__(self):
"""Representation of fitter."""
classname = self.__class__.__name__
try:
subj_str = " fitted with {:d} subjects,".format(self.data.shape[0])
except AttributeError:
subj_str = ""

try:
param_str = ", ".join("{}: {:.2f}".format(par, val) for par, val
in sorted(self.params_.items()))
return "<lifetimes.{classname}:{subj_str} {param_str}>".format(
classname=classname, subj_str=subj_str, param_str=param_str)
except AttributeError:
return "<lifetimes.{classname}>".format(classname=classname)

def _unload_params(self, *args):
if not hasattr(self, 'params_'):
raise ValueError("Model has not been fit yet. Please call the .fit"
" method first.")
return [self.params_[x] for x in args]

def save_model(self, path, save_data=True, save_generate_data_method=True,
values_to_save=None):
"""
Save model with dill package.
Parameters
----------
path: str
Path where to save model.
save_date: bool, optional
Whether to save data from fitter.data to pickle object
save_generate_data_method: bool, optional
Whether to save generate_new_data method (if it exists) from
fitter.generate_new_data to pickle object.
values_to_save: list, optional
Placeholders for original attributes for saving object. If None
will be extended to attr_list length like [None] * len(attr_list)
"""
attr_list = ['data' * (not save_data),
'generate_new_data' * (not save_generate_data_method)]
_save_obj_without_attr(self, attr_list, path,
values_to_save=values_to_save)

def load_model(self, path):
"""
Load model with dill package.
Parameters
----------
path: str
From what path load model.
"""
with open(path, 'rb') as in_file:
self.__dict__.update(dill.load(in_file).__dict__)
35 changes: 35 additions & 0 deletions lifetimes/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Lifetimes utils and helpers."""
from datetime import datetime

import numpy as np
import pandas as pd
import dill
from scipy.optimize import minimize

pd.options.mode.chained_assignment = None
Expand Down Expand Up @@ -466,3 +468,36 @@ def expected_cumulative_transactions(model, transactions, datetime_col, customer
'predicted': pred_cum_transactions}, index=date_index)

return df_cum_transactions


def _save_obj_without_attr(obj, attr_list, path, values_to_save=None):
"""Helper to save object with attributes from attr_list.
Parameters
----------
obj: obj
Object of class with __dict__ attribute.
attr_list: list
List with attributes to exclude from saving to dill object. If empty
list all attributes will be saved.
path: str
Where to save dill object.
values_to_save: list, optional
Placeholders for original attributes for saving object. If None will be
extended to attr_list length like [None] * len(attr_list)
"""
if values_to_save is None:
values_to_save = [None] * len(attr_list)

saved_attr_dict = {}
for attr, val_save in zip(attr_list, values_to_save):
if attr in obj.__dict__:
item = obj.__dict__.pop(attr)
saved_attr_dict[attr] = item
setattr(obj, attr, val_save)

with open(path, 'wb') as out_file:
dill.dump(obj, out_file)

for attr, item in saved_attr_dict.items():
setattr(obj, attr, item)
48 changes: 45 additions & 3 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def cdnow_customers():
return load_cdnow_summary()


cdnow_customers_with_monetary_value = load_cdnow_summary_data_with_monetary_value()
donations = load_donations()
PATH_SAVE_MODEL = './base_fitter.pkl'
Expand All @@ -31,6 +32,8 @@ def test_repr(self):
base_fitter.params_ = dict(x=12.3, y=42)
base_fitter.data = np.array([1, 2, 3])
assert repr(base_fitter) == '<lifetimes.BaseFitter: fitted with 3 subjects, x: 12.30, y: 42.00>'
base_fitter.data = None
assert repr(base_fitter) == '<lifetimes.BaseFitter: x: 12.30, y: 42.00>'

def test_unload_params(self):
base_fitter = estimation.BaseFitter()
Expand All @@ -39,7 +42,6 @@ def test_unload_params(self):
base_fitter.params_ = dict(x=12.3, y=42)
npt.assert_array_almost_equal([12.3, 42], base_fitter._unload_params('x', 'y'))


def test_save_load_model(self):
base_fitter = estimation.BaseFitter()
base_fitter.save_model(PATH_SAVE_MODEL)
Expand All @@ -52,7 +54,6 @@ def test_save_load_model(self):
os.remove(PATH_SAVE_MODEL)



class TestBetaGeoBetaBinomFitter():

def test_params_out_is_close_to_Hardie_paper(self):
Expand Down Expand Up @@ -145,6 +146,7 @@ def test_fit_with_index(self):
)
assert (bbtf.data.index == index).all() == False


class TestGammaGammaFitter():

def test_params_out_is_close_to_Hardie_paper(self):
Expand Down Expand Up @@ -482,6 +484,7 @@ def test_scaling_inputs_gives_same_or_similar_results(self, cdnow_customers):
assert abs(bgf_with_large_inputs.conditional_probability_alive(1, scale * 2, scale * 10) - bgf.conditional_probability_alive(1, 2, 10)) < 10e-5

def test_save_load_bgnbd(self, cdnow_customers):
"""Test saving and loading model for BG/NBD."""
bgf = estimation.BetaGeoFitter(penalizer_coef=0.0)
bgf.fit(cdnow_customers['frequency'], cdnow_customers['recency'], cdnow_customers['T'])
bgf.save_model(PATH_SAVE_BGNBD_MODEL)
Expand All @@ -499,6 +502,7 @@ def test_save_load_bgnbd(self, cdnow_customers):
os.remove(PATH_SAVE_BGNBD_MODEL)

def test_save_load_bgnbd_no_data(self, cdnow_customers):
"""Test saving and loading model for BG/NBD without data."""
bgf = estimation.BetaGeoFitter(penalizer_coef=0.0)
bgf.fit(cdnow_customers['frequency'], cdnow_customers['recency'], cdnow_customers['T'])
bgf.save_model(PATH_SAVE_BGNBD_MODEL, save_data=False)
Expand All @@ -512,7 +516,45 @@ def test_save_load_bgnbd_no_data(self, cdnow_customers):
assert bgf_new.__dict__['predict'](1, 1, 2, 5) == bgf.__dict__['predict'](1, 1, 2, 5)
assert bgf_new.expected_number_of_purchases_up_to_time(1) == bgf.expected_number_of_purchases_up_to_time(1)

assert isinstance(bgf_new.__dict__['data'], list)
assert bgf_new.__dict__['data'] is None
# remove saved model
os.remove(PATH_SAVE_BGNBD_MODEL)

def test_save_load_bgnbd_no_generate_data(self, cdnow_customers):
"""Test saving and loading model for BG/NBD without generate_new_data method."""
bgf = estimation.BetaGeoFitter(penalizer_coef=0.0)
bgf.fit(cdnow_customers['frequency'], cdnow_customers['recency'], cdnow_customers['T'])
bgf.save_model(PATH_SAVE_BGNBD_MODEL, save_generate_data_method=False)

bgf_new = estimation.BetaGeoFitter()
bgf_new.load_model(PATH_SAVE_BGNBD_MODEL)
assert bgf_new.__dict__['penalizer_coef'] == bgf.__dict__['penalizer_coef']
assert bgf_new.__dict__['_scale'] == bgf.__dict__['_scale']
assert bgf_new.__dict__['params_'] == bgf.__dict__['params_']
assert bgf_new.__dict__['_negative_log_likelihood_'] == bgf.__dict__['_negative_log_likelihood_']
assert bgf_new.__dict__['predict'](1, 1, 2, 5) == bgf.__dict__['predict'](1, 1, 2, 5)
assert bgf_new.expected_number_of_purchases_up_to_time(1) == bgf.expected_number_of_purchases_up_to_time(1)

assert bgf_new.__dict__['generate_new_data'] is None
# remove saved model
os.remove(PATH_SAVE_BGNBD_MODEL)

def test_save_load_bgnbd_no_data_replace_with_empty_str(self, cdnow_customers):
"""Test saving and loading model for BG/NBD without data with replaced value empty str."""
bgf = estimation.BetaGeoFitter(penalizer_coef=0.0)
bgf.fit(cdnow_customers['frequency'], cdnow_customers['recency'], cdnow_customers['T'])
bgf.save_model(PATH_SAVE_BGNBD_MODEL, save_data=False, values_to_save=[''])

bgf_new = estimation.BetaGeoFitter()
bgf_new.load_model(PATH_SAVE_BGNBD_MODEL)
assert bgf_new.__dict__['penalizer_coef'] == bgf.__dict__['penalizer_coef']
assert bgf_new.__dict__['_scale'] == bgf.__dict__['_scale']
assert bgf_new.__dict__['params_'] == bgf.__dict__['params_']
assert bgf_new.__dict__['_negative_log_likelihood_'] == bgf.__dict__['_negative_log_likelihood_']
assert bgf_new.__dict__['predict'](1, 1, 2, 5) == bgf.__dict__['predict'](1, 1, 2, 5)
assert bgf_new.expected_number_of_purchases_up_to_time(1) == bgf.expected_number_of_purchases_up_to_time(1)

assert bgf_new.__dict__['data'] is ''
# remove saved model
os.remove(PATH_SAVE_BGNBD_MODEL)

Expand Down

0 comments on commit e8d5750

Please sign in to comment.