Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/development' into development
Browse files Browse the repository at this point in the history
  • Loading branch information
rhiever committed Apr 7, 2017
2 parents 23f2791 + cd96558 commit 075d259
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
7 changes: 2 additions & 5 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,12 @@
from ._version import __version__
from .operator_utils import TPOTOperatorClassFactory, Operator, ARGType
from .export_utils import export_pipeline, expr_to_tree, generate_pipeline_code
#from .decorators import _timeout, _pre_test, TimedOutExc
from .decorators import _pre_test
from .built_in_operators import CombineDFs

from .metrics import SCORERS
from .gp_types import Output_Array
from .gp_deap import eaMuPlusLambda, mutNodeReplacement, _wrapped_cross_val_score
from .gp_deap import eaMuPlusLambda, mutNodeReplacement, _wrapped_cross_val_score, cxOnePoint

# hot patch for Windows: solve the problem of crashing python after Ctrl + C in Windows OS
if sys.platform.startswith('win'):
Expand Down Expand Up @@ -747,7 +746,7 @@ def _evaluate_individuals(self, individuals, features, classes, sample_weight =

@_pre_test
def _mate_operator(self, ind1, ind2):
return gp.cxOnePoint(ind1, ind2)
return cxOnePoint(ind1, ind2)

@_pre_test
def _random_mutation_operator(self, individual):
Expand All @@ -765,8 +764,6 @@ def _random_mutation_operator(self, individual):
Returns the individual with one of the mutations applied to it
"""
# debug usage
#print(str(individual))
mutation_techniques = [
partial(gp.mutInsert, pset=self._pset),
partial(mutNodeReplacement, pset=self._pset),
Expand Down
45 changes: 44 additions & 1 deletion tpot/gp_deap.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .operator_utils import set_sample_weight
from sklearn.model_selection import cross_val_score
from sklearn.base import clone
from collections import defaultdict
import warnings
import threading

Expand Down Expand Up @@ -200,6 +201,49 @@ def eaMuPlusLambda(population, toolbox, mu, lambda_, cxpb, mutpb, ngen, pbar,
return population, logbook


def cxOnePoint(ind1, ind2):
"""Randomly select in each individual and exchange each subtree with the
point as root between each individual.
:param ind1: First tree participating in the crossover.
:param ind2: Second tree participating in the crossover.
:returns: A tuple of two trees.
"""
# Define the name of type for any types.
__type__ = object

if len(ind1) < 2 or len(ind2) < 2:
# No crossover on single node tree
return ind1, ind2

# List all available primitive types in each individual
types1 = defaultdict(list)
types2 = defaultdict(list)
if ind1.root.ret == __type__:
# Not STGP optimization
types1[__type__] = range(1, len(ind1))
types2[__type__] = range(1, len(ind2))
common_types = [__type__]
else:
for idx, node in enumerate(ind1[1:], 1):
types1[node.ret].append(idx)
for idx, node in enumerate(ind2[1:], 1):
types2[node.ret].append(idx)

common_types = [x for x in types1 if x in types2]

if len(common_types) > 0:
type_ = np.random.choice(common_types)

index1 = np.random.choice(types1[type_])
index2 = np.random.choice(types2[type_])

slice1 = ind1.searchSubtree(index1)
slice2 = ind2.searchSubtree(index2)
ind1[slice1], ind2[slice2] = ind2[slice2], ind1[slice1]

return ind1, ind2


# point mutation function
def mutNodeReplacement(individual, pset):
"""Replaces a randomly chosen primitive from *individual* by a randomly
Expand Down Expand Up @@ -283,7 +327,6 @@ def run(self):
warnings.simplefilter('ignore')
self.result = cross_val_score(*self.args, **self.kwargs)
except Exception as e:
#print(e) # for debug use
pass

def _wrapped_cross_val_score(sklearn_pipeline, features, classes,
Expand Down

0 comments on commit 075d259

Please sign in to comment.