Skip to content

Commit

Permalink
Add BoostAODE initial model
Browse files Browse the repository at this point in the history
  • Loading branch information
rmontanana committed Jun 15, 2023
1 parent 923a06b commit 3812d27
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions bayesclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
"AODE",
"KDBNew",
"AODENew",
"BoostAODE",
]
34 changes: 33 additions & 1 deletion bayesclass/clfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,13 @@ def build_spodes(features, class_name):

class SPODE(BayesBase):
def _check_params(self, X, y, kwargs):
expected_args = ["class_name", "features", "state_names"]
expected_args = [
"class_name",
"features",
"state_names",
"sample_weight",
"weighted",
]
return self._check_params_fit(X, y, expected_args, kwargs)


Expand Down Expand Up @@ -775,3 +781,29 @@ def _local_discretization(self):
# np.array(self.state_names_[self.features_[i]]),
# )
# raise ValueError("Discretization error")


class BoostAODE(AODE):
def fit(self, X, y, **kwargs):
self.n_features_in_ = X.shape[1]
self.feature_names_in_ = kwargs.get(
"features", default_feature_names(self.n_features_in_)
)
self.class_name_ = kwargs.get("class_name", "class")
# build estimator
self._validate_estimator()
self.X_ = X
self.y_ = y
self.estimators_ = []
self._train(kwargs)
# To keep compatiblity with the benchmark platform
self.fitted_ = True
self.nodes_leaves = self.nodes_edges
return self

def _train(self, kwargs):
for dag in build_spodes(self.feature_names_in_, self.class_name_):
estimator = clone(self.estimator_)
estimator.dag_ = estimator.model_ = dag
estimator.fit(self.X_, self.y_, **kwargs)
self.estimators_.append(estimator)

0 comments on commit 3812d27

Please sign in to comment.