Skip to content

Commit

Permalink
[DistGB] exclude edges right after sampling (dmlc#7442)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jun 3, 2024
1 parent 70e6312 commit 0fee9be
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 28 deletions.
12 changes: 7 additions & 5 deletions examples/distributed/rgcn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@ Compared to `DGL`, `GraphBolt` partitions are reduced to **72%** for `ogbn-mag`.

#### ogbn-mag

Compared to `DGL`, sampling with `GraphBolt` is reduced to **22%** for `ogbn-mag`.
Compared to `DGL`, sampling with `GraphBolt` is reduced to **15%**. As for the overhead of `exclude`, it's about **5%** in this test. This number could be higher if larger `fanout` or `batch size` is applied.

| Data Formats | Mean Sampling Time Per Iteration(50 iters in total, slowest rank)(seconds) |
| ------------ | ----------------------------------------------------------- |
| DGL | 7.49 |
| GraphBolt | 1.63 |
The time shown below is the mean sampling time per iteration(60 iters in total, slowest rank). Unit: seconds

| Data Formats | No Exclude | Exclude |
| ------------ | ---------- | ------- |
| DGL | 6.50 | 6.86 |
| GraphBolt | 0.95 | 1.00 |
25 changes: 12 additions & 13 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import os
from collections import namedtuple
from collections.abc import MutableMapping
from collections.abc import Mapping, MutableMapping

import numpy as np
import torch

from .. import backend as F, graphbolt as gb, heterograph_index
from .._ffi.ndarray import empty_shared_mem
Expand All @@ -16,7 +17,6 @@

from ..heterograph import DGLGraph
from ..ndarray import exist_shared_mem_array
from ..sampling.utils import EidExcluder
from ..transforms import compact_graphs
from . import graph_services, role, rpc
from .dist_tensor import DistTensor
Expand Down Expand Up @@ -1419,6 +1419,14 @@ def sample_neighbors(
):
# pylint: disable=unused-argument
"""Sample neighbors from a distributed graph."""
if exclude_edges is not None:
# Convert exclude edge IDs to homogeneous edge IDs.
gpb = self.get_partition_book()
if isinstance(exclude_edges, Mapping):
exclude_eids = []
for c_etype, eids in exclude_edges.items():
exclude_eids.append(gpb.map_to_homo_eid(eids, c_etype))
exclude_edges = torch.cat(exclude_eids)
if len(self.etypes) > 1:
frontier = graph_services.sample_etype_neighbors(
self,
Expand All @@ -1427,7 +1435,7 @@ def sample_neighbors(
replace=replace,
etype_sorted=etype_sorted,
prob=prob,
exclude_edges=None,
exclude_edges=exclude_edges,
use_graphbolt=self._use_graphbolt,
)
else:
Expand All @@ -1437,18 +1445,9 @@ def sample_neighbors(
fanout,
replace=replace,
prob=prob,
exclude_edges=None,
exclude_edges=exclude_edges,
use_graphbolt=self._use_graphbolt,
)
# [TODO][Rui]
# For now, exclude_edges is applied after sampling. Namely, we first sample
# the neighbors and then exclude the edges before returning frontier. This
# is probably not efficient. We could try to exclude the edges during
# sampling. Or we pass exclude_edges IDs to local and remote sampling
# functions and let them handle the exclusion.
if exclude_edges is not None:
eid_excluder = EidExcluder(exclude_edges)
frontier = eid_excluder(frontier)
return frontier

def _get_ndata_names(self, ntype=None):
Expand Down
39 changes: 29 additions & 10 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def process_request(self, server_state):
return SubgraphResponse(global_src, global_dst, global_eids=global_eids)


def merge_graphs(res_list, num_nodes):
def merge_graphs(res_list, num_nodes, exclude_edges=None):
"""Merge request from multiple servers"""
if len(res_list) > 1:
srcs = []
Expand All @@ -709,6 +709,15 @@ def merge_graphs(res_list, num_nodes):
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
etype_id_tensor = res_list[0].etype_ids
if exclude_edges is not None:
mask = torch.isin(
eid_tensor, exclude_edges, assume_unique=True, invert=True
)
src_tensor = src_tensor[mask]
dst_tensor = dst_tensor[mask]
eid_tensor = eid_tensor[mask]
if etype_id_tensor is not None:
etype_id_tensor = etype_id_tensor[mask]
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
if eid_tensor is not None:
g.edata[EID] = eid_tensor
Expand All @@ -724,7 +733,9 @@ def merge_graphs(res_list, num_nodes):
)


def _distributed_access(g, nodes, issue_remote_req, local_access):
def _distributed_access(
g, nodes, issue_remote_req, local_access, exclude_edges=None
):
"""A routine that fetches local neighborhood of nodes from the distributed graph.
The local neighborhood of some nodes are stored in the local machine and the other
Expand All @@ -743,6 +754,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
The function that issues requests to access remote data.
local_access : callable
The function that reads data on the local machine.
exclude_edges : tensor
The edges to exclude after sampling.
Returns
-------
Expand Down Expand Up @@ -784,7 +797,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
results = recv_responses(msgseq2pos)
res_list.extend(results)

sampled_graph = merge_graphs(res_list, g.num_nodes())
sampled_graph = merge_graphs(
res_list, g.num_nodes(), exclude_edges=exclude_edges
)
return sampled_graph


Expand Down Expand Up @@ -906,7 +921,7 @@ def sample_etype_neighbors(
inbound/outbound edges for every node must be positive (though they don't have
to sum up to one). Otherwise, the result will be undefined.
exclude_edges : tensor, optional
The edges to exclude when sampling.
The edges to exclude when sampling. Homogeneous edge IDs are used.
replace : bool, optional
If True, sample with replacement.
Expand Down Expand Up @@ -975,7 +990,7 @@ def issue_remote_req(node_ids):
fanout,
edge_dir=edge_dir,
prob=_prob,
exclude_edges=exclude_edges,
exclude_edges=None,
replace=replace,
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
Expand Down Expand Up @@ -1003,13 +1018,15 @@ def local_access(local_g, partition_book, local_nids):
fanout,
edge_dir=edge_dir,
prob=_prob,
exclude_edges=exclude_edges,
exclude_edges=None,
replace=replace,
etype_offset=etype_offset,
etype_sorted=etype_sorted,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(
g, nodes, issue_remote_req, local_access, exclude_edges=exclude_edges
)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
Expand Down Expand Up @@ -1111,7 +1128,7 @@ def issue_remote_req(node_ids):
fanout,
edge_dir=edge_dir,
prob=_prob,
exclude_edges=exclude_edges,
exclude_edges=None,
replace=replace,
use_graphbolt=use_graphbolt,
)
Expand All @@ -1127,11 +1144,13 @@ def local_access(local_g, partition_book, local_nids):
fanout,
edge_dir=edge_dir,
prob=_prob,
exclude_edges=exclude_edges,
exclude_edges=None,
replace=replace,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(
g, nodes, issue_remote_req, local_access, exclude_edges=exclude_edges
)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
Expand Down

0 comments on commit 0fee9be

Please sign in to comment.