forked from CamDavidsonPilon/lifetimes
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request CamDavidsonPilon#130 from CamDavidsonPilon/model_r…
…outines Model routines
- Loading branch information
Showing
5 changed files
with
156 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,2 @@ | ||
"""Init for fitters.""" | ||
|
||
import dill | ||
|
||
|
||
class BaseFitter(object): | ||
"""Base class for fitters.""" | ||
|
||
def __repr__(self): | ||
"""Representation of fitter.""" | ||
classname = self.__class__.__name__ | ||
try: | ||
param_str = ", ".join("%s: %.2f" % (param, value) for param, value | ||
in sorted(self.params_.items())) | ||
return "<lifetimes.%s: fitted with %d subjects, %s>" % ( | ||
classname, self.data.shape[0], param_str) | ||
except AttributeError: | ||
return "<lifetimes.%s>" % classname | ||
|
||
def _unload_params(self, *args): | ||
if not hasattr(self, 'params_'): | ||
raise ValueError("Model has not been fit yet. Please call the .fit" | ||
" method first.") | ||
return [self.params_[x] for x in args] | ||
|
||
def save_model(self, path, save_data=True): | ||
""" | ||
Save model with dill package. | ||
Parameters: | ||
path: Path where to save model. | ||
save_date: Whether to save data from fitter.data to pickle object | ||
""" | ||
with open(path, 'wb') as out_file: | ||
if save_data: | ||
dill.dump(self, out_file) | ||
else: | ||
self_data = self.data.copy() | ||
self.data = [] | ||
dill.dump(self, out_file) | ||
self.data = self_data | ||
|
||
def load_model(self, path): | ||
""" | ||
Save model with dill package. | ||
Parameters: | ||
path: From what path load model. | ||
""" | ||
with open(path, 'rb') as in_file: | ||
self.__dict__.update(dill.load(in_file).__dict__) | ||
from .base_fitter import BaseFitter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Base fitter for other classes.""" | ||
import dill | ||
from ..utils import _save_obj_without_attr | ||
|
||
|
||
class BaseFitter(object): | ||
"""Base class for fitters.""" | ||
|
||
def __repr__(self): | ||
"""Representation of fitter.""" | ||
classname = self.__class__.__name__ | ||
try: | ||
subj_str = " fitted with {:d} subjects,".format(self.data.shape[0]) | ||
except AttributeError: | ||
subj_str = "" | ||
|
||
try: | ||
param_str = ", ".join("{}: {:.2f}".format(par, val) for par, val | ||
in sorted(self.params_.items())) | ||
return "<lifetimes.{classname}:{subj_str} {param_str}>".format( | ||
classname=classname, subj_str=subj_str, param_str=param_str) | ||
except AttributeError: | ||
return "<lifetimes.{classname}>".format(classname=classname) | ||
|
||
def _unload_params(self, *args): | ||
if not hasattr(self, 'params_'): | ||
raise ValueError("Model has not been fit yet. Please call the .fit" | ||
" method first.") | ||
return [self.params_[x] for x in args] | ||
|
||
def save_model(self, path, save_data=True, save_generate_data_method=True, | ||
values_to_save=None): | ||
""" | ||
Save model with dill package. | ||
Parameters | ||
---------- | ||
path: str | ||
Path where to save model. | ||
save_date: bool, optional | ||
Whether to save data from fitter.data to pickle object | ||
save_generate_data_method: bool, optional | ||
Whether to save generate_new_data method (if it exists) from | ||
fitter.generate_new_data to pickle object. | ||
values_to_save: list, optional | ||
Placeholders for original attributes for saving object. If None | ||
will be extended to attr_list length like [None] * len(attr_list) | ||
""" | ||
attr_list = ['data' * (not save_data), | ||
'generate_new_data' * (not save_generate_data_method)] | ||
_save_obj_without_attr(self, attr_list, path, | ||
values_to_save=values_to_save) | ||
|
||
def load_model(self, path): | ||
""" | ||
Load model with dill package. | ||
Parameters | ||
---------- | ||
path: str | ||
From what path load model. | ||
""" | ||
with open(path, 'rb') as in_file: | ||
self.__dict__.update(dill.load(in_file).__dict__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters