Skip to content

Commit

Permalink
[Distributed] add in_subgraph on DistGraph. (dmlc#1755)
Browse files Browse the repository at this point in the history
* add in_subgraph on DistGraph.

* check in more.

* fix test.

* add comments.

* fix test.

* update test.

* update.

* rename.

* update comment

* fix test
  • Loading branch information
zheng-da authored Jul 8, 2020
1 parent bdc1e64 commit 167216a
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 60 deletions.
2 changes: 1 addition & 1 deletion python/dgl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .sampling import sample_neighbors
from .graph_services import sample_neighbors, in_subgraph
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""Sampling module"""
"""A set of graph services of getting subgraphs from DistGraph"""
from collections import namedtuple

from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors
from ..transform import in_subgraph as local_in_subgraph
from . import register_service
from ..convert import graph
from ..base import NID, EID
from ..utils import toindex
from .. import backend as F

__all__ = ['sample_neighbors']
__all__ = ['sample_neighbors', 'in_subgraph']

# Service ids; each one is bound to a request/response class pair by a
# register_service() call in this module.
SAMPLING_SERVICE_ID = 6657
INSUBGRAPH_SERVICE_ID = 6658


class SamplingResponse(Response):
"""Sampling Response"""
class SubgraphResponse(Response):
"""The response for sampling and in_subgraph"""

def __init__(self, global_src, global_dst, global_eids):
self.global_src = global_src
Expand Down Expand Up @@ -49,6 +50,25 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
return global_src, global_dst, global_eids


def _in_subgraph(local_g, partition_book, seed_nodes):
    """Extract the in-subgraph of the seed nodes from the local partition.

    ``seed_nodes`` are given as global node Ids. They are translated to the
    Ids used inside this partition, the in-subgraph is extracted from
    ``local_g``, and the extracted edges are translated back to global Ids.

    Parameters
    ----------
    local_g : graph
        The graph of the local partition.
    partition_book : partition book
        Used to translate global node Ids to partition-local node Ids.
    seed_nodes : tensor
        Global Ids of the nodes whose in-edges are extracted.

    Returns
    -------
    tuple of three tensors
        Global source node Ids, global destination node Ids and global
        edge Ids of the extracted edges.
    """
    part_nids = F.astype(
        partition_book.nid2localnid(seed_nodes, partition_book.partid),
        local_g.idtype)
    subg = local_in_subgraph(local_g, part_nids)

    # The NID node data of the partition holds the global Id of each local node.
    to_global_nid = local_g.ndata[NID]
    local_src, local_dst = subg.edges()
    global_src = to_global_nid[local_src]
    global_dst = to_global_nid[local_dst]

    # subg.edata[EID] indexes edges of local_g; look up their global edge Ids.
    global_eids = F.gather_row(local_g.edata[EID], subg.edata[EID])
    return global_src, global_dst, global_eids


class SamplingRequest(Request):
"""Sampling Request"""

Expand All @@ -72,7 +92,27 @@ def process_request(self, server_state):
self.seed_nodes,
self.fan_out, self.edge_dir,
self.prob, self.replace)
return SamplingResponse(global_src, global_dst, global_eids)
return SubgraphResponse(global_src, global_dst, global_eids)


class InSubgraphRequest(Request):
    """Request asking a server for the in-subgraph of a set of nodes.

    The seed node Ids are the only state the request carries, and they are
    the whole of its serialized state.
    """

    def __init__(self, nodes):
        self.seed_nodes = nodes

    def __setstate__(self, state):
        # The serialized state is just the seed node Ids.
        self.seed_nodes = state

    def __getstate__(self):
        return self.seed_nodes

    def process_request(self, server_state):
        """Extract the in-subgraph on the server's partition and wrap it in a response."""
        src, dst, eids = _in_subgraph(server_state.graph,
                                      server_state.partition_book,
                                      self.seed_nodes)
        return SubgraphResponse(src, dst, eids)


def merge_graphs(res_list, num_nodes):
Expand All @@ -99,47 +139,33 @@ def merge_graphs(res_list, num_nodes):

# Holds the result of an access served from the local partition. It carries the
# same three values that SubgraphResponse is constructed from, so local results
# and remote responses can sit in one list that is handed to merge_graphs().
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')

def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph.
def _distributed_access(g, nodes, issue_remote_req, local_access):
'''A routine that fetches local neighborhood of nodes from the distributed graph.
When sampling with replacement, the sampled subgraph could have parallel edges.
For sampling without replace, if fanout > the number of neighbors, all the
neighbors are sampled.
Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph.
The local neighborhood of some nodes are stored in the local machine and the other
nodes have their neighborhood on remote machines. This code will issue remote
access requests first before fetching data from the local machine. In the end,
we combine the data from the local machine and remote machines.
In this way, we can hide the latency of accessing data on remote machines.
Parameters
----------
g : DistGraph
The distributed graph.
nodes : tensor or dict
Node ids to sample neighbors from. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes.
fanout : int or dict[etype, int]
The number of sampled neighbors for each node on each edge type. Provide a dict
to specify different fanout values for each edge type.
edge_dir : str, optional
Edge direction ('in' or 'out'). If is 'in', sample from in edges. Otherwise,
sample from out edges.
prob : str, optional
Feature name used as the probabilities associated with each neighbor of a node.
Its shape should be compatible with a scalar edge feature tensor.
replace : bool, optional
If True, sample with replacement.
The distributed graph
nodes : tensor
The nodes whose neighborhood are to be fetched.
issue_remote_req : callable
The function that issues requests to access remote data.
local_access : callable
The function that reads data on the local machine.
Returns
-------
DGLHeteroGraph
A sampled subgraph containing only the sampled neighbor edges from
``nodes``. The sampled subgraph has the same metagraph as the original
one.
"""
assert edge_dir == 'in'
The subgraph that contains the neighborhoods of all input nodes.
'''
req_list = []
partition_book = dist_graph.get_partition_book()
partition_book = g.get_partition_book()
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes)
local_nids = None
Expand All @@ -149,12 +175,11 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
# run on the same machine. With a good partitioning, most of the seed nodes
# should reside in the local partition. If the server and the client
# are not co-located, the client doesn't have a local partition.
if pid == partition_book.partid and dist_graph.local_partition is not None:
if pid == partition_book.partid and g.local_partition is not None:
assert local_nids is None
local_nids = node_id
elif len(node_id) != 0:
req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
req = issue_remote_req(node_id)
req_list.append((pid, req))

# send requests to the remote machine.
Expand All @@ -165,17 +190,88 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
src, dst, eids = _sample_neighbors(dist_graph.local_partition, partition_book,
local_nids, fanout, edge_dir, prob, replace)
src, dst, eids = local_access(g.local_partition, partition_book, local_nids)
res_list.append(LocalSampledGraph(src, dst, eids))

# receive responses from remote machines.
if msgseq2pos is not None:
results = recv_responses(msgseq2pos)
res_list.extend(results)

sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
sampled_graph = merge_graphs(res_list, g.number_of_nodes())
return sampled_graph

def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
    """Sample neighbors of the given nodes from a distributed graph.

    With ``replace=True`` the sampled subgraph may contain parallel edges.
    With ``replace=False``, a node that has fewer than ``fanout`` neighbors
    keeps all of its neighbors.

    Node/edge features are not preserved. The original IDs of the sampled
    edges are stored as the `dgl.EID` feature in the returned graph.

    For now, only input graphs with one node type and one edge type are
    supported.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor
        Node ids to sample neighbors from.
    fanout : int
        The number of sampled neighbors for each node.
    edge_dir : str, optional
        Edge direction ('in' or 'out'). If 'in', sample from in edges.
        Otherwise, sample from out edges.
    prob : str, optional
        Name of the feature used as the probabilities associated with each
        neighbor of a node. Its shape should be compatible with a scalar
        edge feature tensor.
    replace : bool, optional
        If True, sample with replacement.

    Returns
    -------
    DGLHeteroGraph
        A sampled subgraph containing only the sampled neighbor edges from
        ``nodes``. The sampled subgraph has the same metagraph as the
        original one.
    """
    def make_request(remote_nids):
        # Request sent to the machine that owns ``remote_nids``.
        return SamplingRequest(remote_nids, fanout, edge_dir=edge_dir,
                               prob=prob, replace=replace)

    def sample_locally(partition, book, nids):
        # Same sampling, run directly on the local partition.
        return _sample_neighbors(partition, book, nids,
                                 fanout, edge_dir, prob, replace)

    return _distributed_access(g, nodes, make_request, sample_locally)

def in_subgraph(g, nodes):
    """Extract the subgraph containing only the in edges of the given nodes.

    The subgraph keeps the same type schema and the cardinality of the
    original one. Node/edge features are not preserved. The original IDs of
    the extracted edges are stored as the `dgl.EID` feature in the returned
    graph.

    For now, only input graphs with one node type and one edge type are
    supported.

    Parameters
    ----------
    g : DistGraph
        The distributed graph structure.
    nodes : tensor
        Ids of the nodes whose in edges are extracted.

    Returns
    -------
    DGLHeteroGraph
        The subgraph.
    """
    # InSubgraphRequest(node_ids) builds the remote request and
    # _in_subgraph(local_g, partition_book, local_nids) serves the local part;
    # both already have the call shapes _distributed_access expects.
    return _distributed_access(g, nodes, InSubgraphRequest, _in_subgraph)

# Bind each service id to its request class. Both services answer with a
# SubgraphResponse.
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
Loading

0 comments on commit 167216a

Please sign in to comment.