Skip to content

Commit

Permalink
Add pyarrow dependency and use pyarrow backed dtypes (#120)
Browse files Browse the repository at this point in the history
* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* try pyarrow dtypes

* cleanup

* cleanup

* cleanup
  • Loading branch information
gsheni authored Jul 31, 2023
1 parent a1f3bcf commit c8a8893
Show file tree
Hide file tree
Showing 18 changed files with 234 additions and 97 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changelog
v0.6.0 (, 2023)
===============
* Enhancements
* Add pyarrow dependency and use pyarrow backed dtypes [#120][#120]
* Fixes
* Rename `_execute_operations_on_df` to `target` in executed prediction problem dataframe [#124][#124]
* Clean up operation description generation [#118][#118]
Expand All @@ -14,6 +15,7 @@ v0.6.0 (, 2023)

[#124]: <https://github.com/trane-dev/Trane/pull/124>
[#118]: <https://github.com/trane-dev/Trane/pull/118>
[#120]: <https://github.com/trane-dev/Trane/pull/120>


v0.5.0 (July 27, 2023)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"tqdm >= 4.65.0",
"ipywidgets >= 8.0.0",
"importlib_resources >= 6.0.0",
"pyarrow >= 12.0.1"
]

[project.urls]
Expand Down
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,18 @@ def make_fake_df():
],
"state": ["MA", "NY", "NY", "NJ", "NJ", "CT"],
"amount": [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
"is_fraud": [False, False, False, True, False, False],
}
df = pd.DataFrame(data)
df["date"] = pd.to_datetime(df["date"])
df = df.astype(
{
"id": "int64[pyarrow]",
"date": "datetime64[ns]",
"state": "category",
"amount": "float64[pyarrow]",
"is_fraud": "boolean[pyarrow]",
},
)
df = df.sort_values(by=["date"])
return df

Expand All @@ -43,6 +52,7 @@ def make_fake_meta():
"date": ("Datetime", {}),
"state": ("Categorical", {"category"}),
"amount": ("Double", {"numeric"}),
"is_fraud": ("Boolean", {}),
}
return meta

Expand Down
47 changes: 9 additions & 38 deletions tests/integration_tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

import pandas as pd
import pytest

import trane
from trane.datasets.load_functions import (
load_covid,
load_covid_metadata,
load_youtube,
load_youtube_metadata,
)

Expand All @@ -17,43 +18,10 @@ def current_dir():
return os.path.dirname(__file__)


@pytest.fixture
def df_youtube(current_dir):
datetime_col = "trending_date"
filename = "USvideos.csv"
df = pd.read_csv(os.path.join(current_dir, filename))
df[datetime_col] = pd.to_datetime(df[datetime_col], format="%y.%d.%m")
df = df.sort_values(by=[datetime_col])
df = df.fillna(0)
return df


@pytest.fixture
def meta_youtube(current_dir):
table_meta = load_youtube_metadata()
return table_meta


@pytest.fixture
def df_covid(current_dir):
datetime_col = "Date"
filename = "covid19.csv"
df = pd.read_csv(os.path.join(current_dir, filename))
df[datetime_col] = pd.to_datetime(df[datetime_col], format="%m/%d/%y")
# to speed up things as covid dataset takes awhile
df = df.sample(frac=0.25, random_state=1)
df = df.sort_values(by=[datetime_col])
df = df.fillna(0)
return df


@pytest.fixture
def meta_covid(current_dir):
table_meta = load_covid_metadata()
return table_meta

def test_youtube(sample):
df_youtube = load_youtube()
meta_youtube = load_youtube_metadata()

def test_youtube(df_youtube, meta_youtube, sample):
entity_col = "category_id"
time_col = "trending_date"
cutoff = "4d"
Expand All @@ -76,7 +44,10 @@ def test_youtube(df_youtube, meta_youtube, sample):
)


def test_covid(df_covid, meta_covid, sample):
def test_covid(sample):
df_covid = load_covid()
meta_covid = load_covid_metadata()

entity_col = "Country/Region"
time_col = "Date"
cutoff = "2d"
Expand Down
22 changes: 20 additions & 2 deletions tests/ops/test_aggregation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_count_agg_op(df):


def test_sum_agg_op(df):
op = CountAggregationOp("col")
op = SumAggregationOp("col")
output = op(df)
assert output == np.sum(df["col"])
Expand Down Expand Up @@ -54,10 +55,27 @@ def test_min_agg_op(df):
assert "the minimum <col> in all related records" in op.generate_description()


def test_majority_agg_op(df):
@pytest.mark.parametrize(
"dtype",
[
("string"),
("string[pyarrow]"),
("int64"),
("int64[pyarrow]"),
("float64"),
("float64[pyarrow]"),
],
)
def test_majority_agg_op(df, dtype):
op = MajorityAggregationOp("col")
df["col"] = df["col"].astype(dtype)
output = op(df)
assert output == str(1)
if dtype in ["string", "string[pyarrow]"]:
assert output == str(1)
elif dtype in ["int64", "int64[pyarrow]"]:
assert output == int(1)
else:
assert output == float(1.0)
assert "the majority <col> in all related records" in op.generate_description()


Expand Down
11 changes: 9 additions & 2 deletions tests/ops/test_threshold_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ def test_get_k_most_frequent(dtype):
assert most_frequent == ["r", "b", "g"]


def test_get_k_most_frequent_raises():
series = pd.Series([1, 2, 3, 4, 5], dtype="int64")
@pytest.mark.parametrize(
"dtype",
[
("int64"),
("int64[pyarrow]"),
],
)
def test_get_k_most_frequent_raises(dtype):
series = pd.Series([1, 2, 3, 4, 5], dtype=dtype)
with pytest.raises(ValueError):
get_k_most_frequent(series)

Expand Down
12 changes: 12 additions & 0 deletions tests/test_load_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def test_load_covid():
assert len(df) >= 17136
assert df["Date"].dtype == "datetime64[ns]"
assert metadata["Date"] == ColumnSchema(logical_type=Datetime)
assert df["Lat"].dtype == "float64[pyarrow]"
assert df["Long"].dtype == "float64[pyarrow]"
assert df["Confirmed"].dtype == "int64[pyarrow]"
assert df["Deaths"].dtype == "int64[pyarrow]"
assert df["Recovered"].dtype == "int64[pyarrow]"
assert df["Country/Region"].dtype == "category"
assert df["Province/State"].dtype == "category"


def test_load_youtube():
Expand All @@ -44,6 +51,11 @@ def test_load_youtube():
check_column_schema(expected_columns, df, metadata)
assert df["trending_date"].dtype == "datetime64[ns]"
assert metadata["trending_date"] == ColumnSchema(logical_type=Datetime)
assert df["views"].dtype == "int64[pyarrow]"
assert df["likes"].dtype == "int64[pyarrow]"
assert df["dislikes"].dtype == "int64[pyarrow]"
assert df["channel_title"].dtype == "category"
assert df["category_id"].dtype == "category"


def check_column_schema(columns, df, metadata):
Expand Down
Empty file removed tests/test_mock_dataset.py
Empty file.
13 changes: 12 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,18 @@ def verify_numeric_op(modified_meta, result):
)


def test_parse_table_cat():
def test_check_operations_boolean():
    """Expect an Eq filter plus a Majority aggregation on a Boolean column to be reported as not valid."""
    raw_meta = {
        "id": ("Categorical", {"primary_key", "category"}),
        "is_fraud": ("Boolean", {}),
    }
    parsed_meta = _parse_table_meta(raw_meta)
    # Both operations target the Boolean column.
    boolean_ops = [EqFilterOp("is_fraud"), MajorityAggregationOp("is_fraud")]
    is_valid, _ = _check_operations_valid(boolean_ops, parsed_meta)
    assert is_valid is False


def test_check_operations_cat():
table_meta = {
"id": ("Categorical", {"primary_key", "category"}),
"state": ("Categorical", {"category"}),
Expand Down
28 changes: 23 additions & 5 deletions tests/typing/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
_infer_series_schema,
infer_table_meta,
)
from trane.typing.inference_functions import (
pandas_modulo,
)
from trane.typing.logical_types import (
Boolean,
Datetime,
Expand All @@ -24,7 +27,7 @@ def pandas_integers():
return [
pd.Series(4 * [-1, 2, 1, 7]),
pd.Series(4 * [-1, 0, 5, 3]),
pd.Series(4 * [sys.maxsize, -sys.maxsize - 1, 0], dtype="str").astype("int64"),
pd.Series(4 * [sys.maxsize, -sys.maxsize - 1, 0]),
]


Expand All @@ -45,7 +48,6 @@ def pandas_doubles():
pd.Series(4 * [-1, 2.5, 1, 7]),
pd.Series(4 * [1.5, np.nan, 1, 3]),
pd.Series(4 * [1.5, np.inf, 1, 3]),
pd.Series(4 * [np.finfo("d").max, np.finfo("d").min, 3, 1]),
]


Expand All @@ -68,7 +70,7 @@ def pandas_datetimes():


def test_boolean_inference(pandas_bools):
dtypes = ["bool"]
dtypes = ["bool", "boolean", "boolean[pyarrow]", "bool_", "bool8", "object"]
for series in pandas_bools:
for dtype in dtypes:
series = series.astype(dtype)
Expand All @@ -83,10 +85,11 @@ def test_unknown_inference():
column_schema = _infer_series_schema(series)
assert isinstance(column_schema, ColumnSchema)
assert column_schema.logical_type == Unknown
assert column_schema.logical_type.dtype == "string[pyarrow]"


def test_double_inference(pandas_doubles):
dtypes = ["float", "float32", "float64", "float_"]
dtypes = ["float32", "float64", "float64[pyarrow]", "float32[pyarrow]"]
for series in pandas_doubles:
for dtype in dtypes:
column_schema = _infer_series_schema(series.astype(dtype))
Expand All @@ -104,7 +107,15 @@ def test_datetime_inference(pandas_datetimes):


def test_integer_inference(pandas_integers):
dtypes = ["int8", "int16", "int32", "int64", "intp", "int", "Int64"]
dtypes = [
"int8",
"int16",
"int32",
"int64",
"int64[pyarrow]",
"intp",
"int",
]

for series in pandas_integers:
for dtype in dtypes:
Expand Down Expand Up @@ -136,3 +147,10 @@ def test_infer_table_meta():
assert table_meta["c"].logical_type == Unknown
assert table_meta["d"].logical_type == Datetime
assert table_meta["d"].semantic_tags == {"time_index"}


def test_pandas_modulo():
    """Expect ``pandas_modulo(series, 1)`` to be all zeros for NumPy- and pyarrow-backed int64 series."""
    expected = [0, 0, 0, 0, 0]
    for backend_dtype in ("int64", "int64[pyarrow]"):
        values = pd.Series([1, 2, 3, 4, 5], dtype=backend_dtype)
        assert pandas_modulo(values, 1).tolist() == expected
66 changes: 50 additions & 16 deletions trane/core/prediction_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def execute(
**kwargs,
):
"""
Executes the problem's operations on a dataframe. Generates the training examples (lable_times).
Executes the problem's operations on a dataframe. Generates the training examples (label_times).
The label_times contains the
"""

assert df.isnull().sum().sum() == 0
# assert df.isnull().sum().sum() == 0

if not self.is_valid(self.table_meta):
raise ValueError(
Expand All @@ -138,25 +138,24 @@ def execute(
df["__identity__"] = 0
target_dataframe_index = "__identity__"

self._label_maker = cp.LabelMaker(
target_dataframe_index=target_dataframe_index,
time_index=self.time_col,
labeling_function=self._execute_operations_on_df,
window_size=self.window_size,
)
minimum_data = minimum_data or self.cutoff_strategy.minimum_data
maximum_data = maximum_data or self.cutoff_strategy.maximum_data
lt = self._label_maker.search(
df=df,
num_examples_per_instance=num_examples_per_instance,
minimum_data=minimum_data,
maximum_data=maximum_data,
gap=gap,
drop_empty=drop_empty,
verbose=verbose,
lt = calculate_label_times(
target_dataframe_index,
df,
minimum_data,
maximum_data,
gap,
drop_empty,
self._execute_operations_on_df,
self.time_col,
self.window_size,
num_examples_per_instance,
verbose,
*args,
**kwargs,
)

if "__identity__" in df.columns:
df.drop(columns=["__identity__"], inplace=True)
lt = lt.rename(columns={"_execute_operations_on_df": "target"})
Expand Down Expand Up @@ -211,3 +210,38 @@ def __str__(self):
self.cutoff_strategy.window_size,
)
return description


def calculate_label_times(
    target_dataframe_index,
    df,
    minimum_data,
    maximum_data,
    gap,
    drop_empty,
    _execute_operations_on_df,
    time_col,
    window_size,
    num_examples_per_instance,
    verbose,
    *args,
    **kwargs,
):
    """
    Build a ``cp.LabelMaker`` and run its ``search`` over ``df``.

    The label maker is constructed from ``target_dataframe_index``,
    ``time_col`` (passed as ``time_index``), ``_execute_operations_on_df``
    (passed as ``labeling_function``) and ``window_size``. Every remaining
    named argument is forwarded to ``search`` as a keyword argument of the
    same name; ``*args`` and ``**kwargs`` are forwarded unchanged.

    Returns whatever ``LabelMaker.search`` returns.
    """
    maker = cp.LabelMaker(
        target_dataframe_index=target_dataframe_index,
        time_index=time_col,
        labeling_function=_execute_operations_on_df,
        window_size=window_size,
    )
    # NOTE(review): any extra positional arguments are bound to the leading
    # positional parameters of ``search``; confirm against its signature that
    # they cannot collide with ``df`` (which is passed by keyword).
    return maker.search(
        *args,
        df=df,
        num_examples_per_instance=num_examples_per_instance,
        minimum_data=minimum_data,
        maximum_data=maximum_data,
        gap=gap,
        drop_empty=drop_empty,
        verbose=verbose,
        **kwargs,
    )
Loading

0 comments on commit c8a8893

Please sign in to comment.