Skip to content

Commit

Permalink
add exclude self option for EdgeDataLoader (dmlc#3122)
Browse files Browse the repository at this point in the history
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
BarclayII and jermainewang authored Jul 13, 2021
1 parent b576e61 commit 2e19ba8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
7 changes: 7 additions & 0 deletions python/dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
None (default)
Does not exclude any edge.
'self'
Exclude the given edges themselves but nothing else.
'reverse_id'
Exclude all edges specified in ``eids``, as well as their reverse edges
of the same edge type.
Expand Down Expand Up @@ -105,6 +108,8 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""
if exclude_mode is None:
return None
elif exclude_mode == 'self':
return eids
elif exclude_mode == 'reverse_id':
return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map'])
elif exclude_mode == 'reverse_types':
Expand Down Expand Up @@ -493,6 +498,8 @@ class EdgeCollator(Collator):
* None, which excludes nothing.
* ``'self'``, which excludes the sampled edges themselves but nothing else.
* ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said
reverse edges have the same edge type as the sampled edges. Only works
on edge types whose source node type is the same as its destination node type.
Expand Down
1 change: 1 addition & 0 deletions python/dgl/dataloading/pytorch/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ class EdgeDataLoader:
minibatch. Possible values are
* None,
* ``self``,
* ``reverse_id``,
* ``reverse_types``
Expand Down
12 changes: 12 additions & 0 deletions tests/pytorch/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ def test_neighbor_sampler_dataloader():
nids.append({'follow': seeds})
modes.append('edge')

collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='self'))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('edge')

collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids))
graphs.append(g)
Expand All @@ -133,6 +139,12 @@ def test_neighbor_sampler_dataloader():
nids.append({'follow': seeds})
modes.append('link')

collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='self', negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('link')

collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
Expand Down

0 comments on commit 2e19ba8

Please sign in to comment.