Skip to content

Commit

Permalink
fix clone method in main classes
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigo-arenas committed Mar 2, 2023
1 parent 3afd98b commit 2407115
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 54 deletions.
9 changes: 9 additions & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ Release Notes

Some notes on new features in various releases

What's new in 0.10.1dev
-----------------------

^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^

* Fixed a bug that wouldn't allow to clone the GA classes when used inside a pipeline

What's new in 0.10.0
--------------------

Expand Down
2 changes: 1 addition & 1 deletion sklearn_genetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.10.0"
__version__ = "0.10.1dev0"
66 changes: 13 additions & 53 deletions sklearn_genetic/genetic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ def __init__(
return_train_score=False,
log_config=None,
):
self.estimator = clone(estimator)
self.estimator_ = None
self.toolbox = base.Toolbox()
self.estimator = estimator
self.cv = cv
self.scoring = scoring
self.population_size = population_size
Expand All @@ -260,28 +258,6 @@ def __init__(
self.error_score = error_score
self.return_train_score = return_train_score
self.creator = creator
self.logbook = None
self.history = None
self._n_iterations = self.generations + 1
self.X_ = None
self.y_ = None
self.callbacks = None
self.best_params_ = None
self.best_estimator_ = None
self._pop = None
self._stats = None
self._hof = None
self.hof = None
self.X_predict = None
self.scorer_ = None
self.cv_results_ = None
self.best_index_ = None
self.best_score_ = None
self.n_splits_ = None
self.refit_time_ = None
self.refit_metric = "score"
self.metrics_list = None
self.multimetric_ = False
self.log_config = log_config

# Check that the estimator is compatible with scikit-learn
Expand Down Expand Up @@ -321,6 +297,7 @@ def _register(self):
This function is the responsible for registering the DEAPs necessary methods
and create other objects to hold the hof, logbook and stats.
"""
self.toolbox = base.Toolbox()

self.creator.create("FitnessMax", base.Fitness, weights=[self.criteria_sign])
self.creator.create("Individual", list, fitness=creator.FitnessMax)
Expand Down Expand Up @@ -484,6 +461,9 @@ def fit(self, X, y, callbacks=None):

self.X_ = X
self.y_ = y
self._n_iterations = self.generations + 1
self.refit_metric = "score"
self.multimetric_ = False

# Make sure the callbacks are valid
self.callbacks = check_callback(callbacks)
Expand Down Expand Up @@ -623,7 +603,7 @@ def _fitted(self):
except Exception as e:
is_fitted = False

has_history = bool(self.history)
has_history = hasattr(self, "history") and bool(self.history)
return all([is_fitted, has_history, self.refit])

def __getitem__(self, index):
Expand Down Expand Up @@ -877,9 +857,7 @@ def __init__(
return_train_score=False,
log_config=None,
):
self.estimator = clone(estimator)
self.estimator_ = None
self.toolbox = base.Toolbox()
self.estimator = estimator
self.cv = cv
self.scoring = scoring
self.population_size = population_size
Expand All @@ -901,29 +879,6 @@ def __init__(
self.error_score = error_score
self.return_train_score = return_train_score
self.creator = creator
self.logbook = None
self.history = None
self._n_iterations = self.generations + 1
self.n_features = None
self.X_ = None
self.y_ = None
self.features_proportion = None
self.callbacks = None
self.best_features_ = None
self.support_ = None
self.best_estimator_ = None
self._pop = None
self._stats = None
self._hof = None
self.hof = None
self.X_predict = None
self.scorer_ = None
self.cv_results_ = None
self.n_splits_ = None
self.refit_time_ = None
self.refit_metric = "score"
self.metrics_list = None
self.multimetric_ = False
self.log_config = log_config

# Check that the estimator is compatible with scikit-learn
Expand All @@ -943,6 +898,7 @@ def _register(self):
This function is the responsible for registering the DEAPs necessary methods
and create other objects to hold the hof, logbook and stats.
"""
self.toolbox = base.Toolbox()

# Criteria sign to set max or min problem
# And -1.0 as second weight to minimize number of features
Expand Down Expand Up @@ -1081,7 +1037,11 @@ def fit(self, X, y, callbacks=None):

self.X_, self.y_ = check_X_y(X, y)
self.n_features = X.shape[1]
self._n_iterations = self.generations + 1
self.refit_metric = "score"
self.multimetric_ = False

self.features_proportion = None
if self.max_features:
self.features_proportion = self.max_features / self.n_features

Expand Down Expand Up @@ -1207,7 +1167,7 @@ def _fitted(self):
except Exception as e:
is_fitted = False

has_history = bool(self.history)
has_history = hasattr(self, "history") and bool(self.history)
return all([is_fitted, has_history, self.refit])

def __getitem__(self, index):
Expand Down

0 comments on commit 2407115

Please sign in to comment.