diff --git a/examples/pytorch/tgn/dataloading.py b/examples/pytorch/tgn/dataloading.py index 09af54795757..ee37c5f60b6c 100644 --- a/examples/pytorch/tgn/dataloading.py +++ b/examples/pytorch/tgn/dataloading.py @@ -3,7 +3,7 @@ from dgl.dataloading.dataloader import EdgeCollator from dgl.dataloading import BlockSampler -from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_blocks_storage +from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_storages from dgl.base import DGLError from functools import partial @@ -113,7 +113,7 @@ class TemporalEdgeCollator(EdgeCollator): eids : Tensor or dict[etype, Tensor] The edge set in graph :attr:`g` to compute outputs. - block_sampler : dgl.dataloading.BlockSampler + graph_sampler : dgl.dataloading.BlockSampler The neighborhood sampler. g_sampling : DGLGraph, optional @@ -203,7 +203,7 @@ def _collate_with_negative_sampling(self, items): for i, edge in enumerate(zip(self.g.edges()[0][items], self.g.edges()[1][items])): ts = pair_graph.edata['timestamp'][i] timestamps.append(ts) - subg = self.block_sampler.sample_blocks(self.g_sampling, + subg = self.graph_sampler.sample_blocks(self.g_sampling, list(edge), timestamp=ts)[0] subg.ndata['timestamp'] = ts.repeat(subg.num_nodes()) @@ -213,7 +213,7 @@ def _collate_with_negative_sampling(self, items): self.negative_sampler.k) for i, neg_edge in enumerate(zip(neg_srcdst_raw[0].tolist(), neg_srcdst_raw[1].tolist())): ts = timestamps[i] - subg = self.block_sampler.sample_blocks(self.g_sampling, + subg = self.graph_sampler.sample_blocks(self.g_sampling, [neg_edge[1]], timestamp=ts)[0] subg.ndata['timestamp'] = ts.repeat(subg.num_nodes()) @@ -230,7 +230,7 @@ def collator(self, items): # Copy the feature from parent graph _pop_subgraph_storage(result[1], self.g) _pop_subgraph_storage(result[2], self.g) - _pop_blocks_storage(result[-1], self.g_sampling) + _pop_storages(result[-1], self.g_sampling) return result @@ -248,7 +248,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader): eids : torch.tensor() or numpy array eids range which to be batched, it is useful to split training validation test dataset - block_sampler : dgl.dataloading.BlockSampler + graph_sampler : dgl.dataloading.BlockSampler temporal neighbor sampler which sample temporal and computationally depend blocks for computation device : str @@ -264,7 +264,8 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader): """ - def __init__(self, g, eids, block_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs): + def __init__(self, g, eids, graph_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs): + super().__init__(g, eids, graph_sampler, device, **kwargs) collator_kwargs = {} dataloader_kwargs = {} for k, v in kwargs.items(): @@ -272,7 +273,7 @@ def __init__(self, g, eids, block_sampler, device='cpu', collator=TemporalEdgeCo collator_kwargs[k] = v else: dataloader_kwargs[k] = v - self.collator = collator(g, eids, block_sampler, **collator_kwargs) + self.collator = collator(g, eids, graph_sampler, **collator_kwargs) assert not isinstance(g, dgl.distributed.DistGraph), \ 'EdgeDataLoader does not support DistGraph for now. ' \ @@ -485,7 +486,7 @@ class FastTemporalEdgeCollator(EdgeCollator): eids : Tensor or dict[etype, Tensor] The edge set in graph :attr:`g` to compute outputs. - block_sampler : dgl.dataloading.BlockSampler + graph_sampler : dgl.dataloading.BlockSampler The neighborhood sampler. g_sampling : DGLGraph, optional @@ -570,7 +571,7 @@ def _collate_with_negative_sampling(self, items): pair_graph.edata[dgl.EID] = induced_edges seed_nodes = pair_graph.ndata[dgl.NID] - blocks = self.block_sampler.sample_blocks(self.g_sampling, seed_nodes) + blocks = self.graph_sampler.sample_blocks(self.g_sampling, seed_nodes) blocks[0].ndata['timestamp'] = torch.zeros( blocks[0].num_nodes()).double() input_nodes = blocks[0].edges()[1] @@ -578,7 +579,7 @@ def _collate_with_negative_sampling(self, items): # update sampler _src = self.g.nodes()[self.g.edges()[0][items]] _dst = self.g.nodes()[self.g.edges()[1][items]] - self.block_sampler.add_edges(_src, _dst) + self.graph_sampler.add_edges(_src, _dst) return input_nodes, pair_graph, neg_pair_graph, blocks def collator(self, items): @@ -586,7 +587,7 @@ def collator(self, items): # Copy the feature from parent graph _pop_subgraph_storage(result[1], self.g) _pop_subgraph_storage(result[2], self.g) - _pop_blocks_storage(result[-1], self.g_sampling) + _pop_storages(result[-1], self.g_sampling) return result @@ -649,7 +650,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): eids : Tensor or dict[etype, Tensor] The edge set in graph :attr:`g` to compute outputs. - block_sampler : dgl.dataloading.BlockSampler + graph_sampler : dgl.dataloading.BlockSampler The neighborhood sampler. g_sampling : DGLGraph, optional @@ -701,11 +702,11 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): A set of builtin negative samplers are provided in :ref:`the negative sampling module `. ''' - def __init__(self, g, eids, block_sampler, g_sampling=None, exclude=None, + def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None): - super(SimpleTemporalEdgeCollator,self).__init__(g,eids,block_sampler, - g_sampling,exclude,reverse_eids,reverse_etypes,negative_sampler) - self.n_layer = len(self.block_sampler.fanouts) + super(SimpleTemporalEdgeCollator, self).__init__(g, eids, graph_sampler, + g_sampling, exclude, reverse_eids, reverse_etypes, negative_sampler) + self.n_layer = len(self.graph_sampler.fanouts) def collate(self,items): ''' @@ -713,7 +714,7 @@ def collate(self,items): We sample iteratively k-times and batch them into one single subgraph. ''' current_ts = self.g.edata['timestamp'][items[0]] #only sample edges before current timestamp - self.block_sampler.ts = current_ts # restore the current timestamp to the graph sampler. + self.graph_sampler.ts = current_ts # restore the current timestamp to the graph sampler. # if link prefiction, we use a negative_sampler to generate neg-graph for loss computing. if self.negative_sampler is None: @@ -724,8 +725,8 @@ def collate(self,items): # we sampling k-hop subgraph and batch them into one graph for i in range(self.n_layer-1): - self.block_sampler.frontiers[0].add_edges(*self.block_sampler.frontiers[i+1].edges()) - frontier = self.block_sampler.frontiers[0] + self.graph_sampler.frontiers[0].add_edges(*self.graph_sampler.frontiers[i+1].edges()) + frontier = self.graph_sampler.frontiers[0] # computing node last-update timestamp frontier.update_all(fn.copy_e('timestamp','ts'), fn.max('ts','timestamp'))