Example data (NVIDIA-Merlin#222)
* Introducing data module to improve our examples

* Making use of test-data in trainer tests

* Update quick-start + add test for it

* Adding type-hint for model

* Default to json as serialization for Schema

* Updating ci

* Updating from_schema to also accept path

* Fixing test_schema

* Make json the default in Dataset

* Adding preprocessing

* Putting back code from Sara that was somehow lost after rebasing

* Updating generate_item_interactions

* Adding basic tests

* Adding test_generate_item_interactions

* Adding some tests for test_preprocessing

* Put synthetic_ecommerce_data_schema in lib

* Fixing failing test

* Update transformers4rec/data/synthetic.py

Co-authored-by: sararb <[email protected]>

* Update transformers4rec/data/preprocessing.py

Co-authored-by: sararb <[email protected]>

* Fixing process_clicks as suggested in PR comments

* Running black

Co-authored-by: sararb <[email protected]>
marcromeyn and sararb authored Sep 16, 2021
1 parent b417e63 commit b05eb57
Showing 36 changed files with 1,574 additions and 266 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
@@ -34,6 +34,9 @@ jobs:
       - name: Lint with black
         run: |
           black --check .
+      - name: Lint with check-manifest
+        run: |
+          check-manifest .
       - name: Lint with isort
         run: |
           isort -c .
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -12,6 +12,11 @@ repos:
     rev: 3.8.4
     hooks:
       - id: flake8
+  # - repo: https://github.com/mgedmin/check-manifest
+  #   rev: "0.46"
+  #   hooks:
+  #     - id: check-manifest
+  #       args: [--ignore, "*source*"]
   # - repo: https://github.com/pycqa/pylint
   #   rev: pylint-2.7.4
   #   hooks:
31 changes: 31 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,31 @@
include *.md
include *.yaml

recursive-include _images *.png
recursive-exclude docs *
recursive-exclude docs *.ipynb
exclude docs/source/_images

recursive-include requirements *.txt

recursive-exclude tests *.coveragerc
recursive-include tests *.parquet
recursive-include tests *.pbtxt

recursive-include transformers4rec *.parquet *.json *.py


# Ignore notebooks & examples
recursive-exclude examples *
exclude examples
exclude tutorial
recursive-exclude tutorial *

# Ignore build related things
recursive-exclude conda *
exclude conda
recursive-exclude .github *
exclude .github
recursive-exclude ci *
exclude ci
exclude .pylintrc
26 changes: 14 additions & 12 deletions README.md
@@ -42,50 +42,52 @@ Here is the PyTorch version:
 ```python
 from transformers4rec import torch as tr
 
-SCHEMA_PATH = "..."
+schema: tr.Schema = tr.data.tabular_sequence_testing_data.schema
+# Or read schema from disk: tr.Schema().from_json(SCHEMA_PATH)
 max_sequence_length, d_model = 20, 64
 
 # Define input module to process tabular input-features
 input_module = tr.TabularSequenceFeatures.from_schema(
-    tr.Schema().from_proto_text(SCHEMA_PATH),
+    schema,
     max_sequence_length=max_sequence_length,
     continuous_projection=d_model,
     aggregation="concat",
     masking="causal",
 )
 # Define one or multiple prediction-tasks
-prediction_tasks = [tr.NextItemPredictionTask()]
+prediction_tasks = tr.NextItemPredictionTask()
 
-# Define the config of the XLNet Transformer architecture
+# Define a transformer-config, like the XLNet architecture
 transformer_config = tr.XLNetConfig.build(
-    d_model=64, n_head=4, n_layer=2, total_seq_length=max_sequence_length
+    d_model=d_model, n_head=4, n_layer=2, total_seq_length=max_sequence_length
 )
-model = transformer_config.to_torch_model(input_module, *prediction_tasks)
+model: tr.Model = transformer_config.to_torch_model(input_module, prediction_tasks)
 ```
 
 And here is the equivalent code for TensorFlow:
 ```python
 from transformers4rec import tf as tr
 
-SCHEMA_PATH = "..."
+schema: tr.Schema = tr.data.tabular_sequence_testing_data.schema
+# Or read schema from disk: tr.Schema().from_json(SCHEMA_PATH)
 max_sequence_length, d_model = 20, 64
 
 # Define input module to process tabular input-features
 input_module = tr.TabularSequenceFeatures.from_schema(
-    tr.Schema().from_proto_text(SCHEMA_PATH),
+    schema,
     max_sequence_length=max_sequence_length,
     continuous_projection=d_model,
     aggregation="concat",
     masking="causal",
 )
 # Define one or multiple prediction-tasks
-prediction_tasks = [tr.NextItemPredictionTask()]
+prediction_tasks = tr.NextItemPredictionTask()
 
-# Define the config of the XLNet Transformer architecture
+# Define a transformer-config, like the XLNet architecture
 transformer_config = tr.XLNetConfig.build(
-    d_model=64, n_head=4, n_layer=2, total_seq_length=max_sequence_length
+    d_model=d_model, n_head=4, n_layer=2, total_seq_length=max_sequence_length
 )
-model = transformer_config.to_tf_model(input_module, *prediction_tasks)
+model: tr.Model = transformer_config.to_tf_model(input_module, prediction_tasks)
 ```
 
 ## When to use it?
8 changes: 8 additions & 0 deletions merlin_standard_lib/schema/schema.py
@@ -15,6 +15,7 @@
 
 
 import collections
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 
 from ..utils import proto_utils
@@ -365,6 +366,13 @@ def item_id_column_name(self):
 
         return item_id_col.column_names[0]
 
+    def from_json(self, value: Union[str, bytes]) -> "Schema":
+        if os.path.isfile(value):
+            with open(value, "rb") as f:
+                value = f.read()
+
+        return super().from_json(value)
+
     def to_proto_text(self) -> str:
         from tensorflow_metadata.proto.v0 import schema_pb2
 
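The override above means `Schema.from_json` now accepts either raw JSON (as `str`/`bytes`) or a path to a JSON file on disk. A minimal usage sketch, assuming the packaged test schema shipped in this PR:

```python
from merlin_standard_lib import Schema

# Pass a path: from_json detects the file on disk and reads its bytes.
schema = Schema().from_json("transformers4rec/data/testing/schema.json")

# Passing raw JSON still works as before.
with open("transformers4rec/data/testing/schema.json", "rb") as f:
    schema = Schema().from_json(f.read())
```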
3 changes: 3 additions & 0 deletions merlin_standard_lib/schema/tag.py
@@ -28,8 +28,11 @@ class Tag(Enum):
 
     # Feature context
     USER = "user"
+    USER_ID = "user_id"
     ITEM = "item"
+    ITEM_ID = "item_id"
     SESSION = "session"
+    SESSION_ID = "session_id"
     CONTEXT = "context"
 
     # Target related
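With the new `USER_ID`, `ITEM_ID`, and `SESSION_ID` tags, key columns can be selected by tag instead of by name. A short sketch, assuming the schema's id columns actually carry these tags:

```python
from merlin_standard_lib import Tag

# Hypothetical selection of the item-id column(s) via the new tag;
# assumes the schema tagged its item-id column with Tag.ITEM_ID.
item_id_columns = schema.select_by_tag(Tag.ITEM_ID).column_names
```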
1 change: 1 addition & 0 deletions requirements/dev.txt
@@ -1,3 +1,4 @@
+check-manifest
 pytest>=5
 pytest-cov>=2
 black==20.8b1
32 changes: 11 additions & 21 deletions tests/config/test_schema.py
@@ -15,27 +15,18 @@
 #
 import pytest
 
-from merlin_standard_lib import Schema, Tag
+from merlin_standard_lib import Tag
 from merlin_standard_lib.utils.embedding_utils import get_embedding_sizes_from_schema
 
 
-def test_schema_from_schema(schema_file):
-    schema = Schema().from_proto_text(str(schema_file))
-
-    assert len(schema.column_names) == 18
-
-
-def test_schema_from_yoochoose_schema(yoochoose_schema_file):
-    schema = Schema().from_proto_text(str(yoochoose_schema_file))
-
-    assert len(schema.column_names) == 22
-    assert len(schema.select_by_tag(Tag.CONTINUOUS).column_schemas) == 7
-    assert len(schema.select_by_tag(Tag.CATEGORICAL).column_schemas) == 3
-
-
-def test_schema_cardinalities(yoochoose_schema_file):
-    schema = Schema().from_proto_text(str(yoochoose_schema_file))
-
+def test_schema_from_yoochoose_schema(yoochoose_schema):
+    assert len(yoochoose_schema.column_names) == 22
+    assert len(yoochoose_schema.select_by_tag(Tag.CONTINUOUS).column_schemas) == 7
+    assert len(yoochoose_schema.select_by_tag(Tag.CATEGORICAL).column_schemas) == 3
+
+
+def test_schema_cardinalities(yoochoose_schema):
+    schema = yoochoose_schema
     assert schema.categorical_cardinalities() == {
         "item_id/list": schema.select_by_name("item_id/list").feature[0].int_domain.max + 1,
         "category/list": schema.select_by_name("category/list").feature[0].int_domain.max + 1,
@@ -44,17 +35,16 @@ def test_schema_cardinalities(yoochoose_schema_file):
 
 
 @pytest.mark.skip(reason="broken")
-def test_schema_embedding_sizes_nvt(yoochoose_schema_file):
+def test_schema_embedding_sizes_nvt(yoochoose_schema):
     pytest.importorskip("nvtabular")
-    schema = Schema().from_proto_text(str(yoochoose_schema_file))
-
+    schema = yoochoose_schema
     assert schema.categorical_cardinalities() == {"item_id/list": 51996, "category/list": 332}
     embedding_sizes = schema.embedding_sizes_nvt(minimum_size=16, maximum_size=512)
     assert embedding_sizes == {"item_id/list": 512, "category/list": 41, "user_country": 16}
 
 
-def test_schema_embedding_sizes(yoochoose_schema_file):
-    schema = Schema().from_proto_text(str(yoochoose_schema_file)).remove_by_name("session_id")
+def test_schema_embedding_sizes(yoochoose_schema):
+    schema = yoochoose_schema.remove_by_name("session_id")
 
     assert schema.categorical_cardinalities() == {
         "category/list": 333,
50 changes: 13 additions & 37 deletions tests/conftest.py
@@ -14,64 +14,40 @@
 # limitations under the License.
 #
 
-import pathlib
-
 import pytest
 
 from merlin_standard_lib import Schema
-
-ASSETS_DIR = pathlib.Path(__file__).parent / "assets"
+from transformers4rec.data import tabular_sequence_testing_data, tabular_testing_data
 
 
 @pytest.fixture
-def assets():
-    return ASSETS_DIR
+def yoochoose_path_file() -> str:
+    return tabular_sequence_testing_data.path
 
 
 @pytest.fixture
-def schema_file():
-    return ASSETS_DIR / "schema.pbtxt"
-
-
-YOOCHOOSE_SCHEMA = ASSETS_DIR / "data_schema" / "data_seq_schema.pbtxt"
-YOOCHOOSE_PATH = ASSETS_DIR / "data_schema" / "data_seq.parquet"
+def yoochoose_schema_file() -> str:
+    return tabular_sequence_testing_data.schema_path
 
 
 @pytest.fixture
-def yoochoose_path_file():
-    return YOOCHOOSE_PATH
+def yoochoose_schema() -> Schema:
+    return tabular_sequence_testing_data.schema
 
 
 @pytest.fixture
-def yoochoose_schema_file():
-    return YOOCHOOSE_SCHEMA
+def tabular_data_file() -> str:
+    return tabular_testing_data.path
 
 
 @pytest.fixture
-def yoochoose_schema():
-    schema = Schema().from_proto_text(str(YOOCHOOSE_SCHEMA))
-    return schema
-
-
-TABULAR_DATA_SCHEMA = ASSETS_DIR / "data_schema" / "data_schema.pbtxt"
-TABULAR_DATA_PATH = ASSETS_DIR / "data_schema" / "data.parquet"
+def tabular_schema_file() -> str:
+    return tabular_testing_data.schema_path
 
 
 @pytest.fixture
-def tabular_data_file():
-    return TABULAR_DATA_PATH
-
-
-@pytest.fixture
-def tabular_schema_file():
-    return TABULAR_DATA_SCHEMA
-
-
-@pytest.fixture
-def tabular_schema():
-    schema = Schema().from_proto_text(str(TABULAR_DATA_SCHEMA))
-
-    return schema.remove_by_name(["session_id", "session_start", "day_idx"])
+def tabular_schema() -> Schema:
+    return tabular_testing_data.schema.remove_by_name(["session_id", "session_start", "day_idx"])
 
 
 from tests.tf.conftest import *  # noqa
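Since the fixtures now resolve to the datasets packaged with `transformers4rec.data`, a test only needs to declare the fixture name as an argument. A minimal sketch (the asserted column name is taken from the schema tests above):

```python
def test_uses_packaged_schema(yoochoose_schema):
    # pytest injects the fixture from tests/conftest.py; the schema is
    # loaded from the test data packaged inside transformers4rec.
    assert "item_id/list" in yoochoose_schema.column_names
```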
Empty file added tests/data/__init__.py
41 changes: 41 additions & 0 deletions tests/data/test_preprocessing.py
@@ -0,0 +1,41 @@
import numpy as np
import pytest

from merlin_standard_lib import ColumnSchema, Schema, Tag
from transformers4rec.data.preprocessing import (
    add_item_first_seen_col_to_df,
    remove_consecutive_interactions,
)
from transformers4rec.data.synthetic import (
    generate_item_interactions,
    synthetic_ecommerce_data_schema,
)

pd = pytest.importorskip("pandas")


def test_remove_consecutive_interactions():
    np.random.seed(0)

    schema = synthetic_ecommerce_data_schema.remove_by_name("item_recency")
    schema += Schema([ColumnSchema.create_continuous("timestamp", tags=[Tag.SESSION])])

    interactions_df = generate_item_interactions(500, schema)
    filtered_df = remove_consecutive_interactions(interactions_df.copy())

    assert len(filtered_df) < len(interactions_df)
    assert len(filtered_df) == 499
    assert len(list(filtered_df.columns)) == len(list(interactions_df.columns))


def test_add_item_first_seen_col_to_df():
    schema = synthetic_ecommerce_data_schema.remove_by_name("item_recency")
    schema += Schema([ColumnSchema.create_continuous("timestamp", tags=[Tag.SESSION])])

    df = add_item_first_seen_col_to_df(generate_item_interactions(500, schema))

    assert len(list(df.columns)) == len(schema) + 1
    assert isinstance(df["item_ts_first"], pd.Series)


# TODO: Add test for session_aggregator when nvtabular 21.09 is released
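For intuition, here is a minimal pandas sketch of the behavior the first test exercises: dropping a row when it repeats the previous interaction's item within the same session. This is an illustrative assumption, not the actual implementation in `transformers4rec.data.preprocessing`:

```python
import pandas as pd


def drop_repeated_interactions(df: pd.DataFrame) -> pd.DataFrame:
    # Order interactions chronologically within each session.
    df = df.sort_values(["session_id", "timestamp"])
    # A row is a consecutive repeat if both session and item match the row above.
    same_session = df["session_id"].eq(df["session_id"].shift())
    same_item = df["item_id"].eq(df["item_id"].shift())
    return df[~(same_session & same_item)]
```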
35 changes: 35 additions & 0 deletions tests/data/test_synthetic.py
@@ -0,0 +1,35 @@
import pytest

from transformers4rec.data.synthetic import (
    generate_item_interactions,
    synthetic_ecommerce_data_schema,
)

pd = pytest.importorskip("pandas")


def test_generate_item_interactions():
    data = generate_item_interactions(500, synthetic_ecommerce_data_schema)

    assert isinstance(data, pd.DataFrame)
    assert len(data) == 500
    assert list(data.columns) == [
        "session_id",
        "item_id",
        "day",
        "purchase",
        "price",
        "category",
        "item_recency",
    ]
    expected_dtypes = {
        "session_id": "int64",
        "item_id": "int64",
        "day": "int64",
        "purchase": "int64",
        "price": "float64",
        "category": "int64",
        "item_recency": "float64",
    }

    assert all(val == expected_dtypes[key] for key, val in dict(data.dtypes).items())
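The synthetic schema is an ordinary `Schema`, so callers can reshape it before generating interactions; the preprocessing tests above use exactly this pattern to swap the recency column for an explicit timestamp:

```python
from merlin_standard_lib import ColumnSchema, Schema, Tag
from transformers4rec.data.synthetic import (
    generate_item_interactions,
    synthetic_ecommerce_data_schema,
)

# Replace item_recency with a continuous timestamp column, then generate data.
schema = synthetic_ecommerce_data_schema.remove_by_name("item_recency")
schema += Schema([ColumnSchema.create_continuous("timestamp", tags=[Tag.SESSION])])
interactions = generate_item_interactions(500, schema)
```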
Empty file added tests/data/testing/__init__.py
13 changes: 13 additions & 0 deletions tests/data/testing/tabular_data/test_dataset.py
@@ -0,0 +1,13 @@
from transformers4rec.data.dataset import ParquetDataset
from transformers4rec.data.testing.tabular_data.dataset import tabular_testing_data


def test_tabular_testing_data():
    assert isinstance(tabular_testing_data, ParquetDataset)
    assert tabular_testing_data.path.endswith(
        "transformers4rec/data/testing/tabular_data/data.parquet"
    )
    assert tabular_testing_data.schema_path.endswith(
        "transformers4rec/data/testing/tabular_data/schema.json"
    )
    assert len(tabular_testing_data.schema) == 11
11 changes: 11 additions & 0 deletions tests/data/testing/test_dataset.py
@@ -0,0 +1,11 @@
from transformers4rec.data.dataset import ParquetDataset
from transformers4rec.data.testing.dataset import tabular_sequence_testing_data


def test_tabular_sequence_testing_data():
    assert isinstance(tabular_sequence_testing_data, ParquetDataset)
    assert tabular_sequence_testing_data.path.endswith("transformers4rec/data/testing/data.parquet")
    assert tabular_sequence_testing_data.schema_path.endswith(
        "transformers4rec/data/testing/schema.json"
    )
    assert len(tabular_sequence_testing_data.schema) == 22
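Taken together, these tests pin down the public surface of the new data module: each packaged dataset exposes `path`, `schema_path`, and a parsed `schema`. A short usage sketch:

```python
from transformers4rec.data import tabular_sequence_testing_data

parquet_path = tabular_sequence_testing_data.path        # ends with data.parquet
schema_path = tabular_sequence_testing_data.schema_path  # ends with schema.json
schema = tabular_sequence_testing_data.schema            # parsed Schema with 22 columns
```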
