Skip to content

Commit

Permalink
[Dataset] Add hash function to distinguish instances from the same da…
Browse files Browse the repository at this point in the history
…taset class (dmlc#1894)

* add hash function in base data classes

* update doc
  • Loading branch information
hetong007 authored Jul 31, 2020
1 parent 71c7ee0 commit fb02aa2
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions python/dgl/data/dgl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,24 @@ class DGLDataset(object):
save_dir : str
Directory to save the processed dataset.
Default: same as raw_dir
hash_key : tuple
A tuple of values as the input for the hash function.
Users can distinguish instances (and their caches on the disk)
from the same dataset class by comparing the hash values.
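Default: (). The corresponding hash value is ``abs(hash(()))``, e.g. 3527539 on some CPython versions; the exact number depends on the Python interpreter.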
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information
"""
def __init__(self, name, url=None, raw_dir=None, save_dir=None, force_reload=False, verbose=False):
def __init__(self, name, url=None, raw_dir=None, save_dir=None,
hash_key=(), force_reload=False, verbose=False):
self._name = name
self._url = url
self._force_reload = force_reload
self._verbose = verbose
self._hash_key = hash_key
self._hash = self._get_hash()

# if no dir is provided, the default dgl download dir is used.
if raw_dir is None:
Expand Down Expand Up @@ -148,6 +156,17 @@ def _load(self):
if self.verbose:
print('Done saving data into cached files.')

def _get_hash(self):
"""Compute the hash of the input tuple
Example
-------
>>> hash_value = self._get_hash((10, False, True))
>>> hash_value
6299899980521991026
"""
return abs(hash(self._hash_key))

@property
def url(self):
r"""Get url to download the raw dataset.
Expand Down Expand Up @@ -191,6 +210,12 @@ def verbose(self):
"""
return self._verbose

@property
def hash(self):
    r"""Integer hash identifying this dataset instance.

    Read-only view of the value stored in ``self._hash``.
    """
    return self._hash

@abc.abstractmethod
def __getitem__(self, idx):
r"""Gets the data object at index.
Expand All @@ -215,16 +240,21 @@ class DGLBuiltinDataset(DGLDataset):
downloaded data or the directory that
already stores the input data.
Default: ~/.dgl/
hash_key : tuple
A tuple of values as the input for the hash function.
Users can distinguish instances (and their caches on the disk)
from the same dataset class by comparing the hash values.
Default: ()
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: False
"""
def __init__(self, name, url, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, name, url, raw_dir=None, hash_key=(), force_reload=False, verbose=False):
    # Forward everything to the base dataset constructor. save_dir is
    # always passed as None here, so built-in datasets never receive a
    # caller-supplied save directory.
    base_kwargs = dict(
        url=url,
        raw_dir=raw_dir,
        save_dir=None,
        hash_key=hash_key,
        force_reload=force_reload,
        verbose=verbose,
    )
    super(DGLBuiltinDataset, self).__init__(name, **base_kwargs)

Expand Down

0 comments on commit fb02aa2

Please sign in to comment.