diff --git a/pyearth/earth.py b/pyearth/earth.py index cba90e8..950ee78 100644 --- a/pyearth/earth.py +++ b/pyearth/earth.py @@ -7,7 +7,8 @@ check_X_y) import numpy as np from scipy import sparse - +from ._version import get_versions +__version__ = get_versions()['version'] class Earth(BaseEstimator, RegressorMixin, TransformerMixin): @@ -254,6 +255,11 @@ class Earth(BaseEstimator, RegressorMixin, TransformerMixin): array of shape m. If several feature importance types are specified, then it is dict where each key is a feature importance type name and its corresponding value is an array of shape m. + + `_version`: string + The version of py-earth in which the Earth object was originally + created. This information may be useful when dealing with + serialized Earth objects. References @@ -317,6 +323,7 @@ def __init__(self, max_terms=None, max_degree=None, allow_missing=False, self.enable_pruning = enable_pruning self.feature_importance_type = feature_importance_type self.verbose = verbose + self._version = __version__ def __eq__(self, other): if self.__class__ is not other.__class__: diff --git a/pyearth/test/test_earth.py b/pyearth/test/test_earth.py index 6fc8118..fc41377 100644 --- a/pyearth/test/test_earth.py +++ b/pyearth/test/test_earth.py @@ -18,6 +18,7 @@ from pyearth._basis import (Basis, ConstantBasisFunction, HingeBasisFunction, LinearBasisFunction) from pyearth import Earth +import pyearth numpy.random.seed(0) @@ -306,6 +307,16 @@ def test_pickle_compatibility(): assert_true(model_copy.basis_[0] is model_copy.basis_[1]._get_root()) +def test_pickle_version_storage(): + earth = Earth(**default_params) + model = earth.fit(X, y) + assert_equal(model._version, pyearth.__version__) + model._version = 'hello' + assert_equal(model._version,'hello') + model_copy = pickle.loads(pickle.dumps(model)) + assert_equal(model_copy._version, model._version) + + def test_copy_compatibility(): model = Earth(**default_params).fit(X, y) model_copy = copy.copy(model)