Skip to content

Commit

Permalink
Merge pull request scikit-learn-contrib#153 from scikit-learn-contrib…
Browse files Browse the repository at this point in the history
…/versioneering

Store pyearth version information in the Earth object.
  • Loading branch information
jcrudy authored Apr 17, 2017
2 parents 17a388d + a06c0da commit 6e989a9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pyearth/earth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
check_X_y)
import numpy as np
from scipy import sparse

from ._version import get_versions
__version__ = get_versions()['version']

class Earth(BaseEstimator, RegressorMixin, TransformerMixin):

Expand Down Expand Up @@ -254,6 +255,11 @@ class Earth(BaseEstimator, RegressorMixin, TransformerMixin):
array of shape m. If several feature importance types are
specified, then it is dict where each key is a feature importance type
name and its corresponding value is an array of shape m.
`_version`: string
The version of py-earth in which the Earth object was originally
created. This information may be useful when dealing with
serialized Earth objects.
References
Expand Down Expand Up @@ -317,6 +323,7 @@ def __init__(self, max_terms=None, max_degree=None, allow_missing=False,
self.enable_pruning = enable_pruning
self.feature_importance_type = feature_importance_type
self.verbose = verbose
self._version = __version__

def __eq__(self, other):
if self.__class__ is not other.__class__:
Expand Down
11 changes: 11 additions & 0 deletions pyearth/test/test_earth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pyearth._basis import (Basis, ConstantBasisFunction,
HingeBasisFunction, LinearBasisFunction)
from pyearth import Earth
import pyearth

numpy.random.seed(0)

Expand Down Expand Up @@ -306,6 +307,16 @@ def test_pickle_compatibility():
assert_true(model_copy.basis_[0] is model_copy.basis_[1]._get_root())


def test_pickle_version_storage():
earth = Earth(**default_params)
model = earth.fit(X, y)
assert_equal(model._version, pyearth.__version__)
model._version = 'hello'
assert_equal(model._version,'hello')
model_copy = pickle.loads(pickle.dumps(model))
assert_equal(model_copy._version, model._version)


def test_copy_compatibility():
model = Earth(**default_params).fit(X, y)
model_copy = copy.copy(model)
Expand Down

0 comments on commit 6e989a9

Please sign in to comment.