Skip to content

Commit

Permalink
[DistGB] enable sample etype neighbors on heterograph (dmlc#7095)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Feb 9, 2024
1 parent 3ebdee7 commit 6735a3a
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 50 deletions.
79 changes: 63 additions & 16 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt(
if isinstance(fanout, int):
fanout = torch.LongTensor([fanout])
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."

return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)
Expand Down Expand Up @@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs):
return func(*args, **kwargs)


def _sample_etype_neighbors(
def _sample_etype_neighbors_dgl(
local_g,
partition_book,
seed_nodes,
etype_offset,
fan_out,
edge_dir,
prob,
replace,
edge_dir="in",
prob=None,
replace=False,
etype_offset=None,
etype_sorted=False,
):
"""Sample from local partition.
Expand All @@ -255,6 +253,8 @@ def _sample_etype_neighbors(
The sampled results are stored in three vectors that store source nodes, destination nodes
and edge IDs.
"""
assert etype_offset is not None, "The etype offset is not provided."

local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)

Expand All @@ -278,6 +278,43 @@ def _sample_etype_neighbors(
return LocalSampledGraph(global_src, global_dst, global_eids)


def _sample_etype_neighbors(use_graphbolt, *args, **kwargs):
    """Dispatch per-edge-type neighbor sampling to the selected backend.

    Parameters
    ----------
    use_graphbolt : bool
        If true, ``_sample_neighbors_graphbolt`` is called; otherwise
        ``_sample_etype_neighbors_dgl`` is called.
    args : list
        Positional arguments, forwarded unchanged to the backend.
    kwargs : dict
        Keyword arguments forwarded to the backend. On the GraphBolt path
        ``etype_offset`` and ``etype_sorted`` are removed before forwarding.

    Returns
    -------
    object
        Whatever the selected backend returns, passed through unchanged.
    """
    if not use_graphbolt:
        return _sample_etype_neighbors_dgl(*args, **kwargs)

    # GraphBolt does not require `etype_offset` and `etype_sorted`; callers
    # may still pass them, so discard them here (a missing key is fine).
    for unused_key in ("etype_offset", "etype_sorted"):
        kwargs.pop(unused_key, None)
    return _sample_neighbors_graphbolt(*args, **kwargs)


def _find_edges(local_g, partition_book, seed_edges):
"""Given an edge ID array, return the source
and destination node ID array ``s`` and ``d`` in the local partition.
Expand Down Expand Up @@ -426,13 +463,15 @@ def __init__(
prob=None,
replace=False,
etype_sorted=True,
use_graphbolt=False,
):
self.seed_nodes = nodes
self.edge_dir = edge_dir
self.prob = prob
self.replace = replace
self.fan_out = fan_out
self.etype_sorted = etype_sorted
self.use_graphbolt = use_graphbolt

def __setstate__(self, state):
(
Expand All @@ -442,6 +481,7 @@ def __setstate__(self, state):
self.replace,
self.fan_out,
self.etype_sorted,
self.use_graphbolt,
) = state

def __getstate__(self):
Expand All @@ -452,6 +492,7 @@ def __getstate__(self):
self.replace,
self.fan_out,
self.etype_sorted,
self.use_graphbolt,
)

def process_request(self, server_state):
Expand All @@ -468,15 +509,16 @@ def process_request(self, server_state):
else:
probs = None
res = _sample_etype_neighbors(
self.use_graphbolt,
local_g,
partition_book,
self.seed_nodes,
etype_offset,
self.fan_out,
self.edge_dir,
probs,
self.replace,
self.etype_sorted,
edge_dir=self.edge_dir,
prob=probs,
replace=self.replace,
etype_offset=etype_offset,
etype_sorted=self.etype_sorted,
)
return SubgraphResponse(
res.global_src,
Expand Down Expand Up @@ -772,6 +814,7 @@ def sample_etype_neighbors(
prob=None,
replace=False,
etype_sorted=True,
use_graphbolt=False,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
Expand Down Expand Up @@ -825,6 +868,8 @@ def sample_etype_neighbors(
neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
Indicates whether etypes are sorted.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
-------
Expand Down Expand Up @@ -882,6 +927,7 @@ def issue_remote_req(node_ids):
prob=_prob,
replace=replace,
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
)

def local_access(local_g, partition_book, local_nids):
Expand All @@ -897,14 +943,15 @@ def local_access(local_g, partition_book, local_nids):
for etype in g.canonical_etypes
]
return _sample_etype_neighbors(
use_graphbolt,
local_g,
partition_book,
local_nids,
etype_offset,
fanout,
edge_dir,
_prob,
replace,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
etype_offset=etype_offset,
etype_sorted=etype_sorted,
)

Expand Down
Loading

0 comments on commit 6735a3a

Please sign in to comment.