-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* casewhen * casewhen * added-tests * final * dependencies * add-py3.8 * style
- Loading branch information
Showing
8 changed files
with
234 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# `CaseWhenRuler` | ||
|
||
::: hulearn.experimental.CaseWhenRuler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .interactive import InteractiveCharts, parallel_coordinates | ||
from .ruler import CaseWhenRuler | ||
|
||
__all__ = ["InteractiveCharts", "parallel_coordinates"] | ||
__all__ = ["InteractiveCharts", "parallel_coordinates", "CaseWhenRuler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import pandas as pd | ||
|
||
|
||
class CaseWhenRuler: | ||
""" | ||
Helper class to construct "case when"-style FunctionClassifiers. | ||
This class allows you to write a system of rules using lambda functions. | ||
These functions cannot be pickled by scikit-learn however, so if you'd like | ||
to use this class in a GridSearch you will need to wrap it around a | ||
FunctionClassifier. | ||
Arguments: | ||
default: the default value to predict if no rules apply | ||
Usage: | ||
```python | ||
from hulearn.datasets import load_titanic | ||
from hulearn.experimental import CaseWhenRuler | ||
from hulearn.classification import FunctionClassifier | ||
def make_prediction(dataf, age=15): | ||
ruler = CaseWhenRuler(default=0) | ||
(ruler | ||
.add_rule(lambda d: (d['pclass'] < 3.0) & (d['sex'] == "female"), 1, name="gender-rule") | ||
.add_rule(lambda d: (d['pclass'] < 3.0) & (d['age'] <= age), 1, name="child-rule")) | ||
return ruler.predict(dataf) | ||
clf = FunctionClassifier(make_prediction) | ||
``` | ||
""" | ||
|
||
def __init__(self, default=None): | ||
self.default = default | ||
self.rules = [] | ||
|
||
def add_rule(self, when, then, name=None): | ||
""" | ||
Adds a rule to the system. | ||
Arguments: | ||
when: a (lambda) function that tells us when the rule applies | ||
then: the value to output if the rule applies | ||
name: an optional name for the rule | ||
""" | ||
if not name: | ||
name = f"rule-{len(self.rules) + 1}" | ||
self.rules.append((when, then, name)) | ||
return self | ||
|
||
def predict(self, X): | ||
""" | ||
Makes a prediction based on the rules sofar. | ||
Usage: | ||
```python | ||
from hulearn.classification import FunctionClassifier | ||
from hulearn.experimental import CaseWhenRuler | ||
def make_prediction(dataf, gender_rule=True, child_rule=True, fare_rule=True): | ||
ruler = CaseWhenRuler(default=0) | ||
if gender_rule: | ||
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['sex'] == "female"), | ||
then=1, | ||
name="gender-rule") | ||
if child_rule: | ||
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['age'] <= 15), | ||
then=1, | ||
name="child-rule") | ||
if fare_rule: | ||
ruler.add_rule(when=lambda d: (d['fare'] > 100), | ||
then=1, | ||
name="fare-rule") | ||
return ruler.transform(dataf) | ||
clf = FunctionClassifier(make_prediction) | ||
``` | ||
""" | ||
results = [self.default for x in range(len(X))] | ||
for rule in self.rules: | ||
when, then, name = rule | ||
for idx, predicate in enumerate(when(X)): | ||
if predicate and (results[idx] == self.default): | ||
results[idx] = then | ||
return results | ||
|
||
def transform(self, X): | ||
""" | ||
Produces a dataframe that indicates the state of all rules. | ||
Usage: | ||
```python | ||
from hulearn.preprocessing import PipeTransformer | ||
from hulearn.experimental import CaseWhenRuler | ||
def make_prediction(dataf, gender_rule=True, child_rule=True, fare_rule=True): | ||
ruler = CaseWhenRuler(default=0) | ||
if gender_rule: | ||
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['sex'] == "female"), | ||
then=1, | ||
name="gender-rule") | ||
if child_rule: | ||
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['age'] <= 15), | ||
then=1, | ||
name="child-rule") | ||
if fare_rule: | ||
ruler.add_rule(when=lambda d: (d['fare'] > 100), | ||
then=1, | ||
name="fare-rule") | ||
return ruler.transform(dataf) | ||
clf = PipeTransformer(make_prediction) | ||
``` | ||
""" | ||
result = pd.DataFrame() | ||
for rule in self.rules: | ||
when, then, name = rule | ||
result[name] = when(X) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,20 @@ | ||
import pytest | ||
from mktestdocs import check_docstring, get_codeblock_members | ||
|
||
from hulearn.datasets import load_titanic | ||
from hulearn.experimental import CaseWhenRuler | ||
from hulearn.common import flatten, df_to_dictlist | ||
|
||
import pytest | ||
members = get_codeblock_members(CaseWhenRuler) | ||
|
||
|
||
def handle_docstring(doc): | ||
""" | ||
This function will read through the docstring and grab | ||
the first python code block. It will try to execute it. | ||
If it fails, the calling test should raise a flag. | ||
""" | ||
if not doc: | ||
return | ||
start = doc.find("```python\n") | ||
end = doc.find("```\n") | ||
if start != -1: | ||
if end != -1: | ||
code_part = doc[(start + 10) : end] | ||
code = "\n".join([c[4:] for c in code_part.split("\n")]) | ||
print(code) | ||
exec(code) | ||
@pytest.mark.parametrize( | ||
"func", [load_titanic, flatten, df_to_dictlist], ids=lambda d: d.__name__ | ||
) | ||
def test_docstring(func): | ||
check_docstring(obj=func) | ||
|
||
|
||
@pytest.mark.parametrize("m", [load_titanic, flatten, df_to_dictlist]) | ||
def test_mappers_docstrings(m): | ||
""" | ||
Take the docstring of every method on the `Clumper` class. | ||
The test passes if the usage examples causes no errors. | ||
""" | ||
handle_docstring(m.__doc__) | ||
@pytest.mark.parametrize("obj", members, ids=lambda d: d.__qualname__) | ||
def test_members(obj): | ||
check_docstring(obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import pandas as pd | ||
|
||
from hulearn.datasets import load_titanic | ||
from hulearn.classification import FunctionClassifier | ||
from hulearn.experimental import CaseWhenRuler | ||
|
||
|
||
def test_smoke_casewhen(): | ||
from sklearn.model_selection import GridSearchCV | ||
from sklearn.metrics import ( | ||
make_scorer, | ||
accuracy_score, | ||
precision_score, | ||
recall_score, | ||
) | ||
|
||
def make_prediction(dataf, gender_rule=True, child_rule=True, fare_rule=True): | ||
ruler = CaseWhenRuler(default=0) | ||
|
||
if gender_rule: | ||
ruler.add_rule( | ||
when=lambda d: (d["pclass"] < 3.0) & (d["sex"] == "female"), | ||
then=1, | ||
name="gender-rule", | ||
) | ||
|
||
if child_rule: | ||
ruler.add_rule( | ||
when=lambda d: (d["pclass"] < 3.0) & (d["age"] <= 15), | ||
then=1, | ||
name="child-rule", | ||
) | ||
|
||
if fare_rule: | ||
ruler.add_rule(when=lambda d: (d["fare"] > 100), then=1, name="fare-rule") | ||
|
||
return ruler.predict(dataf) | ||
|
||
df = load_titanic(as_frame=True) | ||
X, y = df.drop(columns=["survived"]), df["survived"] | ||
|
||
clf = FunctionClassifier(make_prediction) | ||
|
||
cv = GridSearchCV( | ||
clf, | ||
cv=10, | ||
param_grid={ | ||
"gender_rule": [True, False], | ||
"child_rule": [True, False], | ||
"fare_rule": [True, False], | ||
}, | ||
scoring={ | ||
"accuracy": make_scorer(accuracy_score), | ||
"precision": make_scorer(precision_score), | ||
"recall": make_scorer(recall_score), | ||
}, | ||
refit="accuracy", | ||
) | ||
|
||
res = pd.DataFrame(cv.fit(X, y).cv_results_)[ | ||
[ | ||
"param_child_rule", | ||
"param_fare_rule", | ||
"param_gender_rule", | ||
"mean_test_accuracy", | ||
"mean_test_precision", | ||
"mean_test_recall", | ||
] | ||
] | ||
|
||
assert res.shape[0] == 8 |