Skip to content

Commit

Permalink
Add CaseWhenRuler (#61)
Browse files Browse the repository at this point in the history
* casewhen

* casewhen

* added-tests

* final

* dependencies

* add-py3.8

* style
  • Loading branch information
koaning authored May 2, 2021
1 parent af4a84a commit c6640a7
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 32 deletions.
3 changes: 3 additions & 0 deletions docs/api/rulers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `CaseWhenRuler`

::: hulearn.experimental.CaseWhenRuler
14 changes: 7 additions & 7 deletions docs/guide/notebooks/01-function-classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,13 @@
"X_subset = X.assign(sex=lambda d: d['sex'] == 'female')[['pclass', 'sex', 'age', 'fare']]\n",
"\n",
"grid = GridSearchCV(RandomForestClassifier(), \n",
" cv=10, \n",
" param_grid={},\n",
" scoring={'accuracy': make_scorer(accuracy_score), \n",
" 'precision': make_scorer(precision_score),\n",
" 'recall': make_scorer(recall_score)},\n",
" refit='precision'\n",
" )\n",
" cv=10,\n",
" param_grid={},\n",
" scoring={'accuracy': make_scorer(accuracy_score),\n",
" 'precision': make_scorer(precision_score),\n",
" 'recall': make_scorer(recall_score)},\n",
" refit='precision'\n",
")\n",
"\n",
"pd.DataFrame(grid.fit(X_subset, y).cv_results_)[['mean_test_accuracy', 'mean_test_precision', 'mean_test_recall']]"
]
Expand Down
3 changes: 2 additions & 1 deletion hulearn/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .interactive import InteractiveCharts, parallel_coordinates
from .ruler import CaseWhenRuler

__all__ = ["InteractiveCharts", "parallel_coordinates"]
__all__ = ["InteractiveCharts", "parallel_coordinates", "CaseWhenRuler"]
133 changes: 133 additions & 0 deletions hulearn/experimental/ruler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import pandas as pd


class CaseWhenRuler:
"""
Helper class to construct "case when"-style FunctionClassifiers.
This class allows you to write a system of rules using lambda functions.
These functions cannot be pickled by scikit-learn however, so if you'd like
to use this class in a GridSearch you will need to wrap it around a
FunctionClassifier.
Arguments:
default: the default value to predict if no rules apply
Usage:
```python
from hulearn.datasets import load_titanic
from hulearn.experimental import CaseWhenRuler
from hulearn.classification import FunctionClassifier
def make_prediction(dataf, age=15):
ruler = CaseWhenRuler(default=0)
(ruler
.add_rule(lambda d: (d['pclass'] < 3.0) & (d['sex'] == "female"), 1, name="gender-rule")
.add_rule(lambda d: (d['pclass'] < 3.0) & (d['age'] <= age), 1, name="child-rule"))
return ruler.predict(dataf)
clf = FunctionClassifier(make_prediction)
```
"""

def __init__(self, default=None):
self.default = default
self.rules = []

def add_rule(self, when, then, name=None):
"""
Adds a rule to the system.
Arguments:
when: a (lambda) function that tells us when the rule applies
then: the value to output if the rule applies
name: an optional name for the rule
"""
if not name:
name = f"rule-{len(self.rules) + 1}"
self.rules.append((when, then, name))
return self

def predict(self, X):
"""
Makes a prediction based on the rules sofar.
Usage:
```python
from hulearn.classification import FunctionClassifier
from hulearn.experimental import CaseWhenRuler
def make_prediction(dataf, gender_rule=True, child_rule=True, fare_rule=True):
ruler = CaseWhenRuler(default=0)
if gender_rule:
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['sex'] == "female"),
then=1,
name="gender-rule")
if child_rule:
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['age'] <= 15),
then=1,
name="child-rule")
if fare_rule:
ruler.add_rule(when=lambda d: (d['fare'] > 100),
then=1,
name="fare-rule")
return ruler.transform(dataf)
clf = FunctionClassifier(make_prediction)
```
"""
results = [self.default for x in range(len(X))]
for rule in self.rules:
when, then, name = rule
for idx, predicate in enumerate(when(X)):
if predicate and (results[idx] == self.default):
results[idx] = then
return results

def transform(self, X):
"""
Produces a dataframe that indicates the state of all rules.
Usage:
```python
from hulearn.preprocessing import PipeTransformer
from hulearn.experimental import CaseWhenRuler
def make_prediction(dataf, gender_rule=True, child_rule=True, fare_rule=True):
ruler = CaseWhenRuler(default=0)
if gender_rule:
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['sex'] == "female"),
then=1,
name="gender-rule")
if child_rule:
ruler.add_rule(when=lambda d: (d['pclass'] < 3.0) & (d['age'] <= 15),
then=1,
name="child-rule")
if fare_rule:
ruler.add_rule(when=lambda d: (d['fare'] > 100),
then=1,
name="fare-rule")
return ruler.transform(dataf)
clf = PipeTransformer(make_prediction)
```
"""
result = pd.DataFrame()
for rule in self.rules:
when, then, name = rule
result[name] = when(X)
return result
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ nav:
- Utility:
- Common: api/common.md
- Datasets: api/datasets.md
- Rulers: api/rulers.md
- Examples:
- Examples: examples/examples.md
- FAQ: examples/faq.md
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
"nbval>=0.9.6",
"scikit-lego>=0.6.0",
"matplotlib>=3.0.2",
"mktestdocs==0.1.1",
]

util_packages = [
"jupyter>=1.0.0",
"jupyterlab>=0.35.4",
]

dev_packages = docs_packages + test_packages + util_packages


Expand Down Expand Up @@ -64,6 +67,7 @@ def read(fname):
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: MIT License",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand Down
37 changes: 13 additions & 24 deletions tests/test_docstrings.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
import pytest
from mktestdocs import check_docstring, get_codeblock_members

from hulearn.datasets import load_titanic
from hulearn.experimental import CaseWhenRuler
from hulearn.common import flatten, df_to_dictlist

import pytest
members = get_codeblock_members(CaseWhenRuler)


def handle_docstring(doc):
"""
This function will read through the docstring and grab
the first python code block. It will try to execute it.
If it fails, the calling test should raise a flag.
"""
if not doc:
return
start = doc.find("```python\n")
end = doc.find("```\n")
if start != -1:
if end != -1:
code_part = doc[(start + 10) : end]
code = "\n".join([c[4:] for c in code_part.split("\n")])
print(code)
exec(code)
@pytest.mark.parametrize(
"func", [load_titanic, flatten, df_to_dictlist], ids=lambda d: d.__name__
)
def test_docstring(func):
check_docstring(obj=func)


@pytest.mark.parametrize("m", [load_titanic, flatten, df_to_dictlist])
def test_mappers_docstrings(m):
"""
Take the docstring of every method on the `Clumper` class.
The test passes if the usage examples causes no errors.
"""
handle_docstring(m.__doc__)
@pytest.mark.parametrize("obj", members, ids=lambda d: d.__qualname__)
def test_members(obj):
check_docstring(obj)
71 changes: 71 additions & 0 deletions tests/test_interactive/test_casewhen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pandas as pd

from hulearn.datasets import load_titanic
from hulearn.classification import FunctionClassifier
from hulearn.experimental import CaseWhenRuler


def test_smoke_casewhen():
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import (
make_scorer,
accuracy_score,
precision_score,
recall_score,
)

def make_prediction(dataf, gender_rule=True, child_rule=True, fare_rule=True):
ruler = CaseWhenRuler(default=0)

if gender_rule:
ruler.add_rule(
when=lambda d: (d["pclass"] < 3.0) & (d["sex"] == "female"),
then=1,
name="gender-rule",
)

if child_rule:
ruler.add_rule(
when=lambda d: (d["pclass"] < 3.0) & (d["age"] <= 15),
then=1,
name="child-rule",
)

if fare_rule:
ruler.add_rule(when=lambda d: (d["fare"] > 100), then=1, name="fare-rule")

return ruler.predict(dataf)

df = load_titanic(as_frame=True)
X, y = df.drop(columns=["survived"]), df["survived"]

clf = FunctionClassifier(make_prediction)

cv = GridSearchCV(
clf,
cv=10,
param_grid={
"gender_rule": [True, False],
"child_rule": [True, False],
"fare_rule": [True, False],
},
scoring={
"accuracy": make_scorer(accuracy_score),
"precision": make_scorer(precision_score),
"recall": make_scorer(recall_score),
},
refit="accuracy",
)

res = pd.DataFrame(cv.fit(X, y).cv_results_)[
[
"param_child_rule",
"param_fare_rule",
"param_gender_rule",
"mean_test_accuracy",
"mean_test_precision",
"mean_test_recall",
]
]

assert res.shape[0] == 8

0 comments on commit c6640a7

Please sign in to comment.