Skip to content

Commit

Permalink
Break initializers out into their own section to keep code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
JakeColtman committed Jun 12, 2019
1 parent bbdd435 commit 046dce1
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 74 deletions.
Empty file added bartpy/initializers/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions bartpy/initializers/initializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Generator

from bartpy.tree import Tree


class Initializer(object):
"""
The abstract interface for the tree initializers.
Initializers are responsible for setting the starting values of the model, in particular:
- structure of decision and leaf nodes
- variables and values used in splits
- values of leaf nodes
Good initialization of trees helps speed up convergence of sampling
Default behaviour is to leave trees uninitialized
"""

def initialize_tree(self, tree: Tree) -> None:
pass

def initialize_trees(self, trees: Generator[Tree, None, None]) -> None:
for tree in trees:
self.initialize_tree(tree)
93 changes: 93 additions & 0 deletions bartpy/initializers/sklearntreeinitializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Tuple
from operator import gt, le

from sklearn.ensemble import GradientBoostingRegressor

from bartpy.initializers.initializer import Initializer
from bartpy.mutation import GrowMutation
from bartpy.node import split_node, LeafNode
from bartpy.splitcondition import SplitCondition
from bartpy.tree import Tree, mutate


class SklearnTreeInitializer(Initializer):
"""
Initialize tree structure and leaf node values by fitting a single Sklearn GBR tree
"""

def __init__(self,
max_depth: int=4,
min_samples_split: int=2,
loss: str='ls'):
self.max_depth = max_depth
self.min_samples_split = min_samples_split,
self.loss = loss

def initialize_tree(self,
tree: Tree) -> None:

params = {
'n_estimators': 1,
'max_depth': self.max_depth,
'min_samples_split': self.min_samples_split,
'learning_rate': 0.8,
'loss': self.loss
}

clf = GradientBoostingRegressor(**params)
fit = clf.fit(tree.nodes[0].data.X.data, tree.nodes[0].data.y.data)
sklearn_tree = fit.estimators_[0][0].tree_
map_sklearn_tree_into_bartpy(tree, sklearn_tree)


def map_sklearn_split_into_bartpy_split_conditions(sklearn_tree, index: int) -> Tuple[SplitCondition, SplitCondition]:
"""
Convert how a split is stored in sklearn's gradient boosted trees library to the bartpy representation
Parameters
----------
sklearn_tree: The full tree object
index: The index of the node in the tree object
Returns
-------
"""
return (
SplitCondition(sklearn_tree.feature[index], sklearn_tree.threshold[index], le),
SplitCondition(sklearn_tree.feature[index], sklearn_tree.threshold[index], gt)
)


def map_sklearn_tree_into_bartpy(bartpy_tree: Tree, sklearn_tree):
nodes = [None for x in sklearn_tree.children_left]
nodes[0] = bartpy_tree.nodes[0]

def search(index: int=0):

left_child_index, right_child_index = sklearn_tree.children_left[index], sklearn_tree.children_right[index]

if left_child_index == -1: # Trees are binary splits, so only need to check left tree
return

searched_node: LeafNode = nodes[index]

split_conditions = map_sklearn_split_into_bartpy_split_conditions(sklearn_tree, index)
decision_node = split_node(searched_node, split_conditions)

left_child: LeafNode = decision_node.left_child
right_child: LeafNode = decision_node.right_child
left_child.set_value(sklearn_tree.value[left_child_index][0][0])
right_child.set_value(sklearn_tree.value[right_child_index][0][0])

mutation = GrowMutation(searched_node, decision_node)
mutate(bartpy_tree, mutation)

nodes[index] = decision_node
nodes[left_child_index] = decision_node.left_child
nodes[right_child_index] = decision_node.right_child

search(left_child_index)
search(right_child_index)

search()
81 changes: 7 additions & 74 deletions bartpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import pandas as pd

from bartpy.data import Data
from bartpy.mutation import GrowMutation
from bartpy.node import split_node
from bartpy.initializers.initializer import Initializer
from bartpy.initializers.sklearntreeinitializer import SklearnTreeInitializer
from bartpy.sigma import Sigma
from bartpy.split import Split
from bartpy.splitcondition import SplitCondition
from bartpy.tree import Tree, LeafNode, deep_copy_tree, mutate
from bartpy.tree import Tree, LeafNode, deep_copy_tree


class Model:
Expand All @@ -22,19 +21,21 @@ def __init__(self,
n_trees: int = 50,
alpha: float=0.95,
beta: float=2.,
k: int=2.):
k: int=2.,
initializer: Initializer=SklearnTreeInitializer()):

self.data = data
self.alpha = float(alpha)
self.beta = float(beta)
self.k = k
self._sigma = sigma
self._prediction = None
self._initializer=initializer

if trees is None:
self.n_trees = n_trees
self._trees = self.initialize_trees()
self.initialize_tree_values()
self._initializer.initialize_trees(self.refreshed_trees())
else:
self.n_trees = len(trees)
self._trees = trees
Expand Down Expand Up @@ -83,74 +84,6 @@ def sigma_m(self):
def sigma(self):
return self._sigma

def initialize_tree_values(self) -> None:
"""
Generate a set of initial values to start sampling from. Helpful for speeding up convergence
Works by using sklearn's GBT package to generate a single estimator for each tree.
Returns
-------
None
"""

from sklearn.ensemble import GradientBoostingRegressor
for tree in self.refreshed_trees():
params = {'n_estimators': 1, 'max_depth': 4, 'min_samples_split': 2,
'learning_rate': 0.8, 'loss': 'ls'}
clf = GradientBoostingRegressor(**params)
fit = clf.fit(tree.nodes[0].data.X.data, tree.nodes[0].data.y.data)
sklearn_tree = fit.estimators_[0][0].tree_
map_sklearn_tree_into_bartpy(tree, sklearn_tree)


def map_sklearn_split_into_bartpy_split_conditions(sklearn_tree, index: int) -> List[SplitCondition]:
"""
Convert how a split is stored in sklearn's gradient boosted trees library to the bartpy representation
Parameters
----------
sklearn_tree: The full tree object
index: The index of the node in the tree object
Returns
-------
"""
return [
SplitCondition(sklearn_tree.feature[index], sklearn_tree.threshold[index], le),
SplitCondition(sklearn_tree.feature[index], sklearn_tree.threshold[index], gt)
]


def map_sklearn_tree_into_bartpy(bartpy_tree: Tree, sklearn_tree):
nodes = [None for x in sklearn_tree.children_left]
nodes[0] = bartpy_tree.nodes[0]

def search(index: int=0):

left_child_index, right_child_index = sklearn_tree.children_left[index], sklearn_tree.children_right[index]

if left_child_index == -1: # Trees are binary splits, so only need to check left tree
return

split_conditions = map_sklearn_split_into_bartpy_split_conditions(sklearn_tree, index)
decision_node = split_node(nodes[index], split_conditions)
decision_node.left_child.set_value(sklearn_tree.value[left_child_index][0][0])
decision_node.right_child.set_value(sklearn_tree.value[right_child_index][0][0])

mutation = GrowMutation(nodes[index], decision_node)
mutate(bartpy_tree, mutation)

nodes[index] = decision_node
nodes[left_child_index] = decision_node.left_child
nodes[right_child_index] = decision_node.right_child

search(left_child_index)
search(right_child_index)

search()


def deep_copy_model(model: Model) -> Model:
copied_model = Model(None, deepcopy(model.sigma), [deep_copy_tree(tree) for tree in model.trees])
Expand Down

0 comments on commit 046dce1

Please sign in to comment.