[GraphBolt] add pydantic-based metadata for TVT (dmlc#5942)

Rhett-Ying authored Jul 4, 2023
1 parent 55af15d commit 39890c0
Showing 6 changed files with 136 additions and 2 deletions.
91 changes: 90 additions & 1 deletion python/dgl/graphbolt/dataset.py
@@ -1,9 +1,14 @@
"""GraphBolt Dataset."""

from typing import List, Optional, Union

import pydantic
import pydantic_yaml

from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict

__all__ = ["Dataset"]
__all__ = ["Dataset", "OnDiskDataset"]


class Dataset:
@@ -48,3 +53,87 @@ def graph(self) -> object:
def feature(self) -> FeatureStore:
"""Return the feature."""
raise NotImplementedError


class OnDiskDataFormatEnum(pydantic_yaml.YamlStrEnum):
"""Enum of data format."""

TORCH = "torch"
NUMPY = "numpy"


class OnDiskTVTSet(pydantic.BaseModel):
"""Train-Validation-Test set."""

type_name: str
format: OnDiskDataFormatEnum
path: str


class OnDiskMetaData(pydantic_yaml.YamlModel):
"""Metadata specification in YAML.
As multiple node/edge types and multiple splits are supported, each TVT set
is a list of lists of ``OnDiskTVTSet``.
"""

train_set: Optional[List[List[OnDiskTVTSet]]]
validation_set: Optional[List[List[OnDiskTVTSet]]]
test_set: Optional[List[List[OnDiskTVTSet]]]


class OnDiskDataset(Dataset):
"""An on-disk dataset.
An on-disk dataset reads graph topology, feature data and TVT sets from
disk. Due to limited resources, data that are too large to fit into RAM
stay on disk, while the rest are loaded into RAM once ``OnDiskDataset``
is initialized. This behavior can be controlled by the user via the
``in_memory`` field in the YAML file.
A full example of a YAML file is as follows:
.. code-block:: yaml
train_set:
- - type_name: paper
format: numpy
path: set/paper-train.npy
validation_set:
- - type_name: paper
format: numpy
path: set/paper-validation.npy
test_set:
- - type_name: paper
format: numpy
path: set/paper-test.npy
Parameters
----------
path: str
The YAML file path.
"""

def __init__(self, path: str) -> None:
with open(path, "r") as f:
self._meta = OnDiskMetaData.parse_raw(f.read(), proto="yaml")

def train_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the training set."""
raise NotImplementedError

def validation_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the validation set."""
raise NotImplementedError

def test_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the test set."""
raise NotImplementedError

def graph(self) -> object:
"""Return the graph."""
raise NotImplementedError

def feature(self) -> FeatureStore:
"""Return the feature."""
raise NotImplementedError
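To make the nested ``List[List[OnDiskTVTSet]]`` shape concrete: the outer list indexes splits and the inner list indexes node/edge types. Below is a minimal sketch, assuming the module path ``dgl.graphbolt.dataset``; the type names and paths are hypothetical, and it exercises only the models above plus core pydantic v1 behavior.

# Sketch: validating a two-split, multi-type YAML payload with the
# OnDiskMetaData model above. Type names and paths are hypothetical.
from dgl.graphbolt.dataset import OnDiskMetaData

yaml_content = """
train_set:
  - - type_name: paper
      format: numpy
      path: set/paper-train-0.npy
    - type_name: 'paper:cites:paper'
      format: numpy
      path: set/cites-train-0.npy
  - - type_name: paper
      format: numpy
      path: set/paper-train-1.npy
"""

# The same call OnDiskDataset.__init__ makes under the hood.
meta = OnDiskMetaData.parse_raw(yaml_content, proto="yaml")
assert len(meta.train_set) == 2                # two splits
assert len(meta.train_set[0]) == 2             # two types in the first split
assert meta.train_set[0][1].format == "numpy"  # YamlStrEnum compares as str
assert meta.validation_set is None             # Optional fields default to None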
2 changes: 2 additions & 0 deletions script/dgl_dev.yml.template
@@ -19,6 +19,7 @@ dependencies:
- psutil
- pyarrow
- pydantic
- pydantic-yaml
- pytest
- pyyaml
- rdflib
@@ -40,5 +41,6 @@ dependencies:
- pillow
- seaborn
- jupyter_http_over_ws
- ufmt
variables:
DGL_HOME: __DGL_HOME__
37 changes: 37 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataset.py
@@ -1,3 +1,7 @@
import os
import tempfile

import pydantic
import pytest
from dgl import graphbolt as gb

@@ -14,3 +18,36 @@ def test_Dataset():
_ = dataset.graph()
with pytest.raises(NotImplementedError):
_ = dataset.feature()


def test_OnDiskDataset_TVTSet():
"""Test OnDiskDataset with TVTSet."""
with tempfile.TemporaryDirectory() as test_dir:
yaml_content = """
train_set:
- - type_name: paper
format: torch
path: set/paper-train.pt
- type_name: 'paper:cites:paper'
format: numpy
path: set/cites-train.npy
"""
yaml_file = os.path.join(test_dir, "test.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
_ = gb.OnDiskDataset(yaml_file)

# Invalid format.
yaml_content = """
train_set:
- - type_name: paper
format: torch_invalid
path: set/paper-train.pt
- type_name: 'paper:cites:paper'
format: numpy_invalid
path: set/cites-train.pt
"""
with open(yaml_file, "w") as f:
f.write(yaml_content)
with pytest.raises(pydantic.ValidationError):
_ = gb.OnDiskDataset(yaml_file)
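The test above covers parsing; the reverse direction — constructing the metadata in Python and emitting YAML — can be sketched with core pydantic v1 plus pyyaml (already in the dev dependencies). The paths here are hypothetical, and the detour through ``.json()`` simply renders the format enum as its string value.

# Sketch: build OnDiskMetaData programmatically and dump it as YAML.
# Paths are hypothetical; .json() and exclude_none are core pydantic v1 APIs.
import json

import yaml
from dgl.graphbolt.dataset import (
    OnDiskDataFormatEnum,
    OnDiskMetaData,
    OnDiskTVTSet,
)

meta = OnDiskMetaData(
    train_set=[
        [
            OnDiskTVTSet(
                type_name="paper",
                format=OnDiskDataFormatEnum.TORCH,
                path="set/paper-train.pt",
            )
        ]
    ]
)

# Round-trip through JSON so the enum serializes as "torch", then let
# pyyaml render the plain dict; exclude_none drops the unset
# validation_set/test_set fields.
print(yaml.safe_dump(json.loads(meta.json(exclude_none=True))))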
3 changes: 3 additions & 0 deletions tests/scripts/task_distributed_test.sh
@@ -34,6 +34,9 @@ export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=1
export DMLC_LOG_DEBUG=1

# Install required dependencies
python3 -m pip install pydantic-yaml

python3 -m pytest -v --capture=tee-sys --junitxml=pytest_distributed.xml --durations=100 tests/distributed/*.py || fail "distributed"

PYTHONPATH=tools:tools/distpartitioning:$PYTHONPATH python3 -m pytest -v --capture=tee-sys --junitxml=pytest_tools.xml --durations=100 tests/tools/*.py || fail "tools"
2 changes: 1 addition & 1 deletion tests/scripts/task_unit_test.bat
@@ -14,7 +14,7 @@ SET DGLBACKEND=!BACKEND!
SET DGL_LIBRARY_PATH=!CD!\build
SET DGL_DOWNLOAD_DIR=!CD!

python -m pip install pytest psutil pandas pyyaml pydantic rdflib torchmetrics || EXIT /B 1
python -m pip install pytest psutil pandas pyyaml pydantic pydantic-yaml rdflib torchmetrics || EXIT /B 1
python -m pytest -v --junitxml=pytest_backend.xml --durations=100 tests\python\!DGLBACKEND! || EXIT /B 1
python -m pytest -v --junitxml=pytest_common.xml --durations=100 tests\python\common || EXIT /B 1
ENDLOCAL
3 changes: 3 additions & 0 deletions tests/scripts/task_unit_test.sh
@@ -33,6 +33,9 @@ fi

conda activate ${DGLBACKEND}-ci

# Install required dependencies
python3 -m pip install pydantic-yaml

if [ $DGLBACKEND == "mxnet" ]
then
python3 -m pytest -v --junitxml=pytest_compute.xml --durations=100 --ignore=tests/python/common/test_ffi.py tests/python/common || fail "common"
