Skip to content

Commit

Permalink
[DistGB] return eids together with etype_ids in sampling (dmlc#7084)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Feb 5, 2024
1 parent 346197c commit f3af2a9
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 77 deletions.
158 changes: 107 additions & 51 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from .. import backend as F, graphbolt as gb
from ..base import EID, NID
from ..base import EID, ETYPE, NID
from ..convert import graph, heterograph
from ..sampling import (
sample_etype_neighbors as local_sample_etype_neighbors,
Expand Down Expand Up @@ -40,16 +40,29 @@
class SubgraphResponse(Response):
"""The response for sampling and in_subgraph"""

def __init__(self, global_src, global_dst, global_eids):
def __init__(
self, global_src, global_dst, *, global_eids=None, etype_ids=None
):
self.global_src = global_src
self.global_dst = global_dst
self.global_eids = global_eids
self.etype_ids = etype_ids

def __setstate__(self, state):
self.global_src, self.global_dst, self.global_eids = state
(
self.global_src,
self.global_dst,
self.global_eids,
self.etype_ids,
) = state

def __getstate__(self):
return self.global_src, self.global_dst, self.global_eids
return (
self.global_src,
self.global_dst,
self.global_eids,
self.etype_ids,
)


class FindEdgeResponse(Response):
Expand All @@ -68,7 +81,7 @@ def __getstate__(self):


def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, prob=None, replace=False
g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False
):
"""Sample from local partition via graphbolt.
Expand All @@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt(
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
Expand All @@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt(
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
Expand All @@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt(
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
assert (
edge_dir == "in"
), f"GraphBolt only supports inbound edge sampling but got {edge_dir}."

# 1. Map global node IDs to local node IDs.
nodes = gpb.nid2localnid(nodes, gpb.partid)

Expand Down Expand Up @@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt(
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]

return global_src, global_dst, subgraph.type_per_edge
# [Rui][TODO] edge IDs are not supported yet.
return LocalSampledGraph(
global_src, global_dst, None, subgraph.type_per_edge
)


def _sample_neighbors(
local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
def _sample_neighbors_dgl(
local_g,
partition_book,
seed_nodes,
fan_out,
edge_dir="in",
prob=None,
replace=False,
):
"""Sample from local partition.
Expand All @@ -170,7 +196,38 @@ def _sample_neighbors(
global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
return LocalSampledGraph(global_src, global_dst, global_eids)


def _sample_neighbors(use_graphbolt, *args, **kwargs):
"""Wrapper for sampling neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func = (
_sample_neighbors_graphbolt if use_graphbolt else _sample_neighbors_dgl
)
return func(*args, **kwargs)


def _sample_etype_neighbors(
Expand Down Expand Up @@ -211,7 +268,7 @@ def _sample_etype_neighbors(
global_nid_mapping, src
), F.gather_row(global_nid_mapping, dst)
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
return LocalSampledGraph(global_src, global_dst, global_eids)


def _find_edges(local_g, partition_book, seed_edges):
Expand Down Expand Up @@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
src, dst = sampled_graph.edges()
global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
return global_src, global_dst, global_eids
return LocalSampledGraph(global_src, global_dst, global_eids)


# --- NOTE 1 ---
Expand Down Expand Up @@ -333,26 +390,22 @@ def process_request(self, server_state):
prob = [kv_store.data_store[self.prob]]
else:
prob = None
if self.use_graphbolt:
global_src, global_dst, etype_ids = _sample_neighbors_graphbolt(
local_g,
partition_book,
self.seed_nodes,
self.fan_out,
prob,
self.replace,
)
return SubgraphResponse(global_src, global_dst, etype_ids)
global_src, global_dst, global_eids = _sample_neighbors(
res = _sample_neighbors(
self.use_graphbolt,
local_g,
partition_book,
self.seed_nodes,
self.fan_out,
self.edge_dir,
prob,
self.replace,
edge_dir=self.edge_dir,
prob=prob,
replace=self.replace,
)
return SubgraphResponse(
res.global_src,
res.global_dst,
global_eids=res.global_eids,
etype_ids=res.etype_ids,
)
return SubgraphResponse(global_src, global_dst, global_eids)


class SamplingRequestEtype(Request):
Expand Down Expand Up @@ -407,7 +460,7 @@ def process_request(self, server_state):
]
else:
probs = None
global_src, global_dst, global_eids = _sample_etype_neighbors(
res = _sample_etype_neighbors(
local_g,
partition_book,
self.seed_nodes,
Expand All @@ -418,7 +471,12 @@ def process_request(self, server_state):
self.replace,
self.etype_sorted,
)
return SubgraphResponse(global_src, global_dst, global_eids)
return SubgraphResponse(
res.global_src,
res.global_dst,
global_eids=res.global_eids,
etype_ids=res.etype_ids,
)


class EdgesRequest(Request):
Expand Down Expand Up @@ -532,7 +590,7 @@ def process_request(self, server_state):
global_src, global_dst, global_eids = _in_subgraph(
local_g, partition_book, self.seed_nodes
)
return SubgraphResponse(global_src, global_dst, global_eids)
return SubgraphResponse(global_src, global_dst, global_eids=global_eids)


def merge_graphs(res_list, num_nodes):
Expand All @@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes):
srcs = []
dsts = []
eids = []
etype_ids = []
for res in res_list:
srcs.append(res.global_src)
dsts.append(res.global_dst)
eids.append(res.global_eids)
etype_ids.append(res.etype_ids)
src_tensor = F.cat(srcs, 0)
dst_tensor = F.cat(dsts, 0)
eid_tensor = None if eids[0] is None else F.cat(eids, 0)
etype_id_tensor = None if etype_ids[0] is None else F.cat(etype_ids, 0)
else:
src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
etype_id_tensor = res_list[0].etype_ids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
if eid_tensor is not None:
g.edata[EID] = eid_tensor
if etype_id_tensor is not None:
g.edata[ETYPE] = etype_id_tensor
return g


LocalSampledGraph = namedtuple(
"LocalSampledGraph", "global_src global_dst global_eids"
LocalSampledGraph = namedtuple( # pylint: disable=unexpected-keyword-arg
"LocalSampledGraph",
"global_src global_dst global_eids etype_ids",
defaults=(None, None, None, None),
)


Expand Down Expand Up @@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
src, dst, eids = local_access(
g.local_partition, partition_book, local_nids
)
res_list.append(LocalSampledGraph(src, dst, eids))
res = local_access(g.local_partition, partition_book, local_nids)
res_list.append(res)

# receive responses from remote machines.
if msgseq2pos is not None:
Expand Down Expand Up @@ -916,23 +980,15 @@ def issue_remote_req(node_ids):
def local_access(local_g, partition_book, local_nids):
# See NOTE 1
_prob = [g.edata[prob].local_partition] if prob is not None else None
if use_graphbolt:
return _sample_neighbors_graphbolt(
local_g,
partition_book,
local_nids,
fanout,
prob=_prob,
replace=replace,
)
return _sample_neighbors(
use_graphbolt,
local_g,
partition_book,
local_nids,
fanout,
edge_dir,
_prob,
replace,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
Expand Down
Loading

0 comments on commit f3af2a9

Please sign in to comment.