Fix tgn example (dmlc#3543)
VoVAllen authored Nov 29, 2021
1 parent 03c2c6d commit da53275
Showing 1 changed file with 21 additions and 20 deletions.
examples/pytorch/tgn/dataloading.py: 41 changes (21 additions & 20 deletions)
@@ -3,7 +3,7 @@

from dgl.dataloading.dataloader import EdgeCollator
from dgl.dataloading import BlockSampler
- from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_blocks_storage
+ from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_storages
from dgl.base import DGLError

from functools import partial
@@ -113,7 +113,7 @@ class TemporalEdgeCollator(EdgeCollator):
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
- block_sampler : dgl.dataloading.BlockSampler
+ graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
@@ -203,7 +203,7 @@ def _collate_with_negative_sampling(self, items):
for i, edge in enumerate(zip(self.g.edges()[0][items], self.g.edges()[1][items])):
ts = pair_graph.edata['timestamp'][i]
timestamps.append(ts)
- subg = self.block_sampler.sample_blocks(self.g_sampling,
+ subg = self.graph_sampler.sample_blocks(self.g_sampling,
list(edge),
timestamp=ts)[0]
subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
@@ -213,7 +213,7 @@ def _collate_with_negative_sampling(self, items):
self.negative_sampler.k)
for i, neg_edge in enumerate(zip(neg_srcdst_raw[0].tolist(), neg_srcdst_raw[1].tolist())):
ts = timestamps[i]
- subg = self.block_sampler.sample_blocks(self.g_sampling,
+ subg = self.graph_sampler.sample_blocks(self.g_sampling,
[neg_edge[1]],
timestamp=ts)[0]
subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
@@ -230,7 +230,7 @@ def collator(self, items):
# Copy the features from the parent graph
_pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g)
- _pop_blocks_storage(result[-1], self.g_sampling)
+ _pop_storages(result[-1], self.g_sampling)
return result


@@ -248,7 +248,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
eids : torch.Tensor or numpy array
Range of edge IDs to batch; useful for splitting into training/validation/test sets.
- block_sampler : dgl.dataloading.BlockSampler
+ graph_sampler : dgl.dataloading.BlockSampler
Temporal neighbor sampler which samples the temporally and computationally dependent blocks for computation.
device : str
@@ -264,15 +264,16 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
"""

- def __init__(self, g, eids, block_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs):
+ def __init__(self, g, eids, graph_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs):
+ super().__init__(g, eids, graph_sampler, device, **kwargs)
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
if k in self.collator_arglist:
collator_kwargs[k] = v
else:
dataloader_kwargs[k] = v
- self.collator = collator(g, eids, block_sampler, **collator_kwargs)
+ self.collator = collator(g, eids, graph_sampler, **collator_kwargs)

assert not isinstance(g, dgl.distributed.DistGraph), \
'EdgeDataLoader does not support DistGraph for now. ' \
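
For context, a minimal usage sketch of this loader with the renamed graph_sampler argument. The toy graph and the TemporalSampler name and signature are assumptions for illustration, not code from this commit:

    import dgl
    import torch

    # toy temporal graph: 100 nodes, 500 chronologically timestamped edges
    src = torch.randint(0, 100, (500,))
    dst = torch.randint(0, 100, (500,))
    g = dgl.graph((src, dst))
    g.edata['timestamp'] = torch.arange(500).double()

    sampler = TemporalSampler(k=2)     # assumed temporal BlockSampler from this example
    train_eids = torch.arange(0, 400)  # first 400 interactions as the training split

    loader = TemporalEdgeDataLoader(
        g, train_eids, sampler,        # the sampler is now passed as `graph_sampler`
        device='cpu',
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
        batch_size=32,
        shuffle=False)                 # keep chronological order for temporal training
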
@@ -485,7 +486,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
- block_sampler : dgl.dataloading.BlockSampler
+ graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
@@ -570,23 +571,23 @@ def _collate_with_negative_sampling(self, items):
pair_graph.edata[dgl.EID] = induced_edges

seed_nodes = pair_graph.ndata[dgl.NID]
- blocks = self.block_sampler.sample_blocks(self.g_sampling, seed_nodes)
+ blocks = self.graph_sampler.sample_blocks(self.g_sampling, seed_nodes)
blocks[0].ndata['timestamp'] = torch.zeros(
blocks[0].num_nodes()).double()
input_nodes = blocks[0].edges()[1]

# update sampler
_src = self.g.nodes()[self.g.edges()[0][items]]
_dst = self.g.nodes()[self.g.edges()[1][items]]
- self.block_sampler.add_edges(_src, _dst)
+ self.graph_sampler.add_edges(_src, _dst)
return input_nodes, pair_graph, neg_pair_graph, blocks

def collator(self, items):
result = super().collate(items)
# Copy the features from the parent graph
_pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g)
- _pop_blocks_storage(result[-1], self.g_sampling)
+ _pop_storages(result[-1], self.g_sampling)
return result
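
Note: add_edges above is a method of this example's fast temporal sampler, which grows its internal graph as new interactions are observed. A minimal sketch of the underlying DGL operation (assumed behavior, not the sampler's actual code):

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=4)
    # incrementally insert newly observed interactions, mirroring
    # self.graph_sampler.add_edges(_src, _dst) in the collator above
    _src, _dst = torch.tensor([1, 2]), torch.tensor([2, 3])
    g = dgl.add_edges(g, _src, _dst)   # returns a new graph with the extra edges
    print(g.num_edges())               # 3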


@@ -649,7 +650,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
- block_sampler : dgl.dataloading.BlockSampler
+ graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
@@ -701,19 +702,19 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
'''
- def __init__(self, g, eids, block_sampler, g_sampling=None, exclude=None,
+ def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
reverse_eids=None, reverse_etypes=None, negative_sampler=None):
- super(SimpleTemporalEdgeCollator,self).__init__(g,eids,block_sampler,
- g_sampling,exclude,reverse_eids,reverse_etypes,negative_sampler)
- self.n_layer = len(self.block_sampler.fanouts)
+ super(SimpleTemporalEdgeCollator, self).__init__(g, eids, graph_sampler,
+ g_sampling, exclude, reverse_eids, reverse_etypes, negative_sampler)
+ self.n_layer = len(self.graph_sampler.fanouts)

def collate(self,items):
'''
items: edge IDs in graph g.
We sample iteratively k times and batch the results into a single subgraph.
'''
current_ts = self.g.edata['timestamp'][items[0]] #only sample edges before current timestamp
- self.block_sampler.ts = current_ts # restore the current timestamp to the graph sampler.
+ self.graph_sampler.ts = current_ts # restore the current timestamp to the graph sampler.

# if link prediction, we use a negative_sampler to generate a neg-graph for loss computation.
if self.negative_sampler is None:
@@ -724,8 +725,8 @@ def collate(self,items):

# we sample the k-hop subgraph and batch it into one graph
for i in range(self.n_layer-1):
- self.block_sampler.frontiers[0].add_edges(*self.block_sampler.frontiers[i+1].edges())
- frontier = self.block_sampler.frontiers[0]
+ self.graph_sampler.frontiers[0].add_edges(*self.graph_sampler.frontiers[i+1].edges())
+ frontier = self.graph_sampler.frontiers[0]
# computing node last-update timestamp
frontier.update_all(fn.copy_e('timestamp','ts'), fn.max('ts','timestamp'))
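
Note: the hunk above sits in the loop that merges sampled frontiers; the copy_e/max message passing then computes each node's last-update timestamp. A standalone sketch of that pattern on a toy graph (illustration only):

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 2, 2])))
    g.edata['timestamp'] = torch.tensor([1.0, 5.0, 3.0])

    # each node's 'timestamp' becomes the max over its incoming edge
    # timestamps, i.e. the time of its most recent interaction
    g.update_all(fn.copy_e('timestamp', 'ts'), fn.max('ts', 'timestamp'))
    print(g.ndata['timestamp'])  # nodes 1 and 2 now hold 1.0 and 5.0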
