forked from NVIDIA-Merlin/Transformers4Rec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Introducing data module to improve our examples * Introducing data module to improve our examples * Making use of test-data in trainer tests * Update quick-start + add test for it * Adding type-hint for model * Default to json as serialization for Schema * Updating ci * Updating from_schema to also accept path * Fixing test_schema * Make json the default in Dataset * Adding preprocessing * Putting back code from Sara that somehow was gone after rebasing * Updating generate_item_interactions * Adding basic tests * Adding test_generate_item_interactions * Adding some tests for test_preprocessing * Put synthetic_ecommerce_data_schema in lib * Put synthetic_ecommerce_data_schema in lib * Fixing failing test * Update transformers4rec/data/synthetic.py Co-authored-by: sararb <[email protected]> * Update transformers4rec/data/preprocessing.py Co-authored-by: sararb <[email protected]> * Fixing process_clicks like suggested in PR comments * Running black Co-authored-by: sararb <[email protected]>
- Loading branch information
1 parent
b417e63
commit b05eb57
Showing
36 changed files
with
1,574 additions
and
266 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
include *.md | ||
include *.yaml | ||
|
||
recursive-include _images *.png | ||
recursive-exclude docs * | ||
recursive-exclude docs *.ipynb | ||
exclude docs/source/_images | ||
|
||
recursive-include requirements *.txt | ||
|
||
recursive-exclude tests *.coveragerc | ||
recursive-include tests *.parquet | ||
recursive-include tests *.pbtxt | ||
|
||
recursive-include transformers4rec *.parquet *.json *.py | ||
|
||
|
||
# Ignore notebooks & examples | ||
recursive-exclude examples * | ||
exclude examples | ||
exclude tutorial | ||
recursive-exclude tutorial * | ||
|
||
# Ignore build related things | ||
recursive-exclude conda * | ||
exclude conda | ||
recursive-exclude .github * | ||
exclude .github | ||
recursive-exclude ci * | ||
exclude ci | ||
exclude .pylintrc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
check-manifest | ||
pytest>=5 | ||
pytest-cov>=2 | ||
black==20.8b1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from merlin_standard_lib import ColumnSchema, Schema, Tag | ||
from transformers4rec.data.preprocessing import ( | ||
add_item_first_seen_col_to_df, | ||
remove_consecutive_interactions, | ||
) | ||
from transformers4rec.data.synthetic import ( | ||
generate_item_interactions, | ||
synthetic_ecommerce_data_schema, | ||
) | ||
|
||
pd = pytest.importorskip("pandas") | ||
|
||
|
||
def test_remove_consecutive_interactions(): | ||
np.random.seed(0) | ||
|
||
schema = synthetic_ecommerce_data_schema.remove_by_name("item_recency") | ||
schema += Schema([ColumnSchema.create_continuous("timestamp", tags=[Tag.SESSION])]) | ||
|
||
interactions_df = generate_item_interactions(500, schema) | ||
filtered_df = remove_consecutive_interactions(interactions_df.copy()) | ||
|
||
assert len(filtered_df) < len(interactions_df) | ||
assert len(filtered_df) == 499 | ||
assert len(list(filtered_df.columns)) == len(list(interactions_df.columns)) | ||
|
||
|
||
def test_add_item_first_seen_col_to_df(): | ||
schema = synthetic_ecommerce_data_schema.remove_by_name("item_recency") | ||
schema += Schema([ColumnSchema.create_continuous("timestamp", tags=[Tag.SESSION])]) | ||
|
||
df = add_item_first_seen_col_to_df(generate_item_interactions(500, schema)) | ||
|
||
assert len(list(df.columns)) == len(schema) + 1 | ||
assert isinstance(df["item_ts_first"], pd.Series) | ||
|
||
|
||
# TODO: Add test for session_aggregator when nvtabular 21.09 is released |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
|
||
from transformers4rec.data.synthetic import ( | ||
generate_item_interactions, | ||
synthetic_ecommerce_data_schema, | ||
) | ||
|
||
pd = pytest.importorskip("pandas") | ||
|
||
|
||
def test_generate_item_interactions(): | ||
data = generate_item_interactions(500, synthetic_ecommerce_data_schema) | ||
|
||
assert isinstance(data, pd.DataFrame) | ||
assert len(data) == 500 | ||
assert list(data.columns) == [ | ||
"session_id", | ||
"item_id", | ||
"day", | ||
"purchase", | ||
"price", | ||
"category", | ||
"item_recency", | ||
] | ||
expected_dtypes = { | ||
"session_id": "int64", | ||
"item_id": "int64", | ||
"day": "int64", | ||
"purchase": "int64", | ||
"price": "float64", | ||
"category": "int64", | ||
"item_recency": "float64", | ||
} | ||
|
||
assert all(val == expected_dtypes[key] for key, val in dict(data.dtypes).items()) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from transformers4rec.data.dataset import ParquetDataset | ||
from transformers4rec.data.testing.tabular_data.dataset import tabular_testing_data | ||
|
||
|
||
def test_tabular_testing_data(): | ||
assert isinstance(tabular_testing_data, ParquetDataset) | ||
assert tabular_testing_data.path.endswith( | ||
"transformers4rec/data/testing/tabular_data/data.parquet" | ||
) | ||
assert tabular_testing_data.schema_path.endswith( | ||
"transformers4rec/data/testing/tabular_data/schema.json" | ||
) | ||
assert len(tabular_testing_data.schema) == 11 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from transformers4rec.data.dataset import ParquetDataset | ||
from transformers4rec.data.testing.dataset import tabular_sequence_testing_data | ||
|
||
|
||
def test_tabular_sequence_testing_data(): | ||
assert isinstance(tabular_sequence_testing_data, ParquetDataset) | ||
assert tabular_sequence_testing_data.path.endswith("transformers4rec/data/testing/data.parquet") | ||
assert tabular_sequence_testing_data.schema_path.endswith( | ||
"transformers4rec/data/testing/schema.json" | ||
) | ||
assert len(tabular_sequence_testing_data.schema) == 22 |
Oops, something went wrong.