[Distributed] Distributed heterograph training (dmlc#3069)
* support hetero RGCN.

* fix.

* simplify code.

* sample_neighbors return heterograph directly.

* avoid using to_heterogeneous.

* compute canonical etypes in advance.

* fix tests.

* fix.

* fix distributed data loader for heterograph.

* use NodeDataLoader.

* fix bugs in partitioning on heterogeneous graphs.

* fix lint.

* fix tests.

* fix.

* fix.

* fix bugs.

* fix tests.

* fix.

* enable coo for distributed.

* fix.

* fix.

* fix.

* fix.

* fix.

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Zheng <[email protected]>
3 people authored Jul 17, 2021
1 parent 905c0aa commit 34426a9
Showing 10 changed files with 313 additions and 194 deletions.
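The headline change is end-to-end sampling-based training on partitioned heterogeneous graphs; the bulk of it lives in the RGCN example below (diff not rendered). A hedged sketch of the client-side flow this commit enables — the graph name, partition config path, and seed-node counts are placeholders, not taken from the patch:

import dgl
import torch as th

dgl.distributed.initialize('ip_config.txt')
g = dgl.distributed.DistGraph('ogbn-mag', part_config='data/ogbn-mag.json')
train_nid = {'paper': th.arange(100)}   # per-node-type seed nodes

# NodeDataLoader drives dgl.distributed.sample_neighbors under the hood.
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.dataloading.NodeDataLoader(
    g, train_nid, sampler, batch_size=32, shuffle=True)

for input_nodes, seeds, blocks in dataloader:
    pass   # an RGCN forward/backward pass would go here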
286 changes: 162 additions & 124 deletions examples/pytorch/rgcn/experimental/entity_classify_dist.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions python/dgl/backend/mxnet/tensor.py
@@ -390,6 +390,7 @@ def zerocopy_to_numpy(arr):
return arr.asnumpy()

def zerocopy_from_numpy(np_data):
np_data = np.asarray(np_data, order='C')
return mx.nd.from_numpy(np_data, zero_copy=True)

def zerocopy_to_dgl_ndarray(arr):
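The one-line guard above matters because `mx.nd.from_numpy(..., zero_copy=True)` can only borrow a C-contiguous buffer, and arrays produced by transposes or slicing often are not. A minimal NumPy-only sketch of the failure mode the guard avoids (my example, not part of the patch):

import numpy as np

arr = np.arange(6).reshape(2, 3).T      # a transposed view: not C-contiguous
assert not arr.flags['C_CONTIGUOUS']

fixed = np.asarray(arr, order='C')      # copies only when necessary
assert fixed.flags['C_CONTIGUOUS']      # now safe to hand off zero-copy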
3 changes: 2 additions & 1 deletion python/dgl/dataloading/dataloader.py
@@ -361,7 +361,8 @@ def collate(self, items):
def _prepare_tensor_dict(g, data, name, is_distributed):
if is_distributed:
x = F.tensor(next(iter(data.values())))
return {k: F.copy_to(F.astype(v, F.dtype(x)), F.context(x)) for k, v in data.items()}
return {k: F.copy_to(F.astype(F.tensor(v), F.dtype(x)), F.context(x)) \
for k, v in data.items()}
else:
return utils.prepare_tensor_dict(g, data, name)

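For context, the fix wraps every dictionary value in `F.tensor` instead of assuming the values are already tensors, so per-type ID lists coming back from a distributed sampler get coerced too. A rough illustration of the same pattern in plain PyTorch (the function name and inputs are mine):

import torch

def prepare_tensor_dict(data):
    # Use the first value to pin down a reference dtype/device, then
    # coerce every value (tensor or plain list) to match it.
    ref = torch.as_tensor(next(iter(data.values())))
    return {k: torch.as_tensor(v, dtype=ref.dtype, device=ref.device)
            for k, v in data.items()}

print(prepare_tensor_dict({'user': [0, 1, 2], 'game': torch.tensor([5, 7])}))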
10 changes: 6 additions & 4 deletions python/dgl/distributed/dist_dataloader.py
@@ -64,7 +64,7 @@ class DistDataLoader:
Parameters
----------
dataset: a tensor
A tensor of node IDs or edge IDs.
Tensors of node IDs or edge IDs.
batch_size: int
The number of samples per batch to load.
shuffle: bool, optional
@@ -127,7 +127,8 @@ def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_las
self.shuffle = shuffle
self.is_closed = False

self.dataset = F.tensor(dataset)
self.dataset = dataset
self.data_idx = F.arange(0, len(dataset))
self.expected_idxs = len(dataset) // self.batch_size
if not self.drop_last and len(dataset) % self.batch_size != 0:
self.expected_idxs += 1
@@ -176,7 +177,7 @@ def __next__(self):

def __iter__(self):
if self.shuffle:
self.dataset = F.rand_shuffle(self.dataset)
self.data_idx = F.rand_shuffle(self.data_idx)
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
@@ -205,6 +206,7 @@ def _next_data(self):
end_pos = len(self.dataset)
else:
end_pos = self.current_pos + self.batch_size
ret = self.dataset[self.current_pos:end_pos]
idx = self.data_idx[self.current_pos:end_pos].tolist()
ret = [self.dataset[i] for i in idx]
self.current_pos = end_pos
return ret
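The refactor above swaps shuffling the dataset tensor for shuffling a separate index tensor, so the dataset only needs `__len__` and `__getitem__` and can hold heterogeneous entries. A toy sketch of the same pattern (the class and data are hypothetical, not DGL's):

import random

class TinyLoader:
    def __init__(self, dataset, batch_size, shuffle=False):
        self.dataset, self.batch_size, self.shuffle = dataset, batch_size, shuffle

    def __iter__(self):
        idx = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(idx)              # permute indices, not the data
        for i in range(0, len(idx), self.batch_size):
            yield [self.dataset[j] for j in idx[i:i + self.batch_size]]

for batch in TinyLoader([('user', 0), ('user', 1), ('game', 0)], 2, shuffle=True):
    print(batch)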
61 changes: 58 additions & 3 deletions python/dgl/distributed/dist_graph.py
@@ -296,7 +296,7 @@ class DistGraphServer(KVServer):
'''
def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False,
graph_format='csc'):
graph_format=('csc', 'coo')):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
@@ -482,6 +482,25 @@ def __init__(self, graph_name, gpb=None, part_config=None):
self._ntype_map = {ntype:i for i, ntype in enumerate(self.ntypes)}
self._etype_map = {etype:i for i, etype in enumerate(self.etypes)}

# Get canonical edge types.
# TODO(zhengda) this requires the server to store the graph with coo format.
eid = []
for etype in self.etypes:
type_eid = F.zeros((1,), F.int64, F.cpu())
eid.append(self._gpb.map_to_homo_eid(type_eid, etype))
eid = F.cat(eid, 0)
src, dst = dist_find_edges(self, eid)
src_tids, _ = self._gpb.map_to_per_ntype(src)
dst_tids, _ = self._gpb.map_to_per_ntype(dst)
self._canonical_etypes = []
etype_ids = F.arange(0, len(self.etypes))
for src_tid, etype_id, dst_tid in zip(src_tids, etype_ids, dst_tids):
src_tid = F.as_scalar(src_tid)
etype_id = F.as_scalar(etype_id)
dst_tid = F.as_scalar(dst_tid)
self._canonical_etypes.append((self.ntypes[src_tid], self.etypes[etype_id],
self.ntypes[dst_tid]))

def _init(self):
self._client = get_kvstore()
assert self._client is not None, \
@@ -576,7 +595,7 @@ def idtype(self):
int
"""
# TODO(da?): describe when self._g is None and idtype shouldn't be called.
return self._g.idtype
return F.int64

@property
def device(self):
@@ -598,7 +617,7 @@ def device(self):
Device context object
"""
# TODO(da?): describe when self._g is None and device shouldn't be called.
return self._g.device
return F.cpu()

@property
def ntypes(self):
@@ -635,6 +654,42 @@ def etypes(self):
# Currently, we only support a graph with one edge type.
return self._gpb.etypes

@property
def canonical_etypes(self):
"""Return all the canonical edge types in the graph.
A canonical edge type is a string triplet ``(str, str, str)``
for source node type, edge type and destination node type.
Returns
-------
list[(str, str, str)]
All the canonical edge type triplets in a list.
Notes
-----
DGL internally assigns an integer ID for each edge type. The returned
edge type names are sorted according to their IDs.
See Also
--------
etypes
Examples
--------
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
>>> g = DistGraph("test")
>>> g.canonical_etypes
[('user', 'follows', 'user'),
('user', 'follows', 'game'),
('user', 'plays', 'game')]
"""
return self._canonical_etypes

def get_ntype_id(self, ntype):
"""Return the ID of the given node type.
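The new `__init__` logic above infers canonical edge types with one representative edge per edge type: it maps per-type edge ID 0 to a homogeneous ID, fetches that edge's endpoints with `dist_find_edges`, and maps the endpoints back to node-type IDs. A pure-Python toy of that last step (the type IDs below are invented):

ntypes = ['user', 'game']
etypes = ['follows', 'plays']
# Pretend dist_find_edges + map_to_per_ntype yielded these endpoint type IDs,
# one per edge type:
src_tids, dst_tids = [0, 0], [0, 1]

canonical = [(ntypes[s], e, ntypes[d])
             for s, e, d in zip(src_tids, etypes, dst_tids)]
print(canonical)   # [('user', 'follows', 'user'), ('user', 'plays', 'game')]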
12 changes: 8 additions & 4 deletions python/dgl/distributed/graph_partition_book.py
@@ -770,16 +770,20 @@ def map_to_homo_nid(self, ids, ntype):
"""
ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype)
end_diff = F.tensor(self._typed_max_node_ids[ntype])[partids] - ids
return F.tensor(self._typed_nid_range[ntype][:, 1])[partids] - end_diff
typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
end_diff = F.gather_row(typed_max_nids, partids) - ids
typed_nid_range = F.zerocopy_from_numpy(self._typed_nid_range[ntype][:, 1])
return F.gather_row(typed_nid_range, partids) - end_diff

def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format.
"""
ids = utils.toindex(ids).tousertensor()
partids = self.eid2partid(ids, etype)
end_diff = F.tensor(self._typed_max_edge_ids[etype][partids]) - ids
return F.tensor(self._typed_eid_range[etype][:, 1])[partids] - end_diff
typed_max_eids = F.zerocopy_from_numpy(self._typed_max_edge_ids[etype])
end_diff = F.gather_row(typed_max_eids, partids) - ids
typed_eid_range = F.zerocopy_from_numpy(self._typed_eid_range[etype][:, 1])
return F.gather_row(typed_eid_range, partids) - end_diff

def nid2partid(self, nids, ntype='_N'):
"""From global node IDs to partition IDs
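Both mappings rely on the same range arithmetic: within each partition, a type's IDs occupy a contiguous slice of the homogeneous ID space, so a per-type ID is recovered by measuring its distance from the slice's end. A worked toy example of the `map_to_homo_eid` arithmetic with invented numbers:

import numpy as np

typed_eid_range_end = np.array([10, 25])  # end (exclusive) of this etype's
                                          # homogeneous-ID slice per partition
typed_max_eids = np.array([4, 9])         # cumulative per-type edge count at
                                          # the end of each partition

ids = np.array([2, 6])                    # per-type edge IDs to convert
partids = np.searchsorted(typed_max_eids, ids, side='right')   # -> [0, 1]
end_diff = typed_max_eids[partids] - ids  # distance from the slice end
homo_ids = typed_eid_range_end[partids] - end_diff
print(homo_ids)                           # [ 8 22]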
54 changes: 38 additions & 16 deletions python/dgl/distributed/graph_services.py
@@ -5,7 +5,7 @@
from ..sampling import sample_neighbors as local_sample_neighbors
from ..subgraph import in_subgraph as local_in_subgraph
from .rpc import register_service
from ..convert import graph
from ..convert import graph, heterograph
from ..base import NID, EID
from ..utils import toindex
from .. import backend as F
@@ -337,19 +337,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph.
This version provides experimental support for heterogeneous graphs.
When the input graph is heterogeneous, the sampled subgraph is still stored in
the homogeneous graph format. That is, all nodes and edges are assigned with
unique IDs (in contrast, we typically use a type name and a node/edge ID to
identify a node or an edge in ``DGLGraph``). We refer to this type of IDs
as *homogeneous ID*.
Users can use :func:`dgl.distributed.GraphPartitionBook.map_to_per_ntype`
and :func:`dgl.distributed.GraphPartitionBook.map_to_per_etype`
to identify their node/edge types and node/edge IDs of that type.
For heterogeneous graphs, ``nodes`` can be a dictionary whose key is node type
and the value is type-specific node IDs; ``nodes`` can also be a tensor of
*homogeneous ID*.
For heterogeneous graphs, ``nodes`` is a dictionary whose key is node type
and the value is type-specific node IDs.
Parameters
----------
@@ -388,7 +377,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
A sampled subgraph containing only the sampled neighboring edges. It is on CPU.
"""
gpb = g.get_partition_book()
if isinstance(nodes, dict):
if len(gpb.etypes) > 1:
assert isinstance(nodes, dict)
homo_nids = []
for ntype in nodes:
assert ntype in g.ntypes, 'The sampled node type does not exist in the input graph'
Expand All @@ -398,13 +388,45 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
nodes = F.cat(homo_nids, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
nodes = list(nodes.values())[0]

def issue_remote_req(node_ids):
return SamplingRequest(node_ids, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
def local_access(local_g, partition_book, local_nids):
return _sample_neighbors(local_g, partition_book, local_nids,
fanout, edge_dir, prob, replace)
return _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1:
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)

data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)

for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
else:
return frontier

def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""A routine that fetches local edges from distributed graph.
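With the branch above in place, sampling a heterogeneous DistGraph returns a DGLHeteroGraph directly: the homogeneous frontier is split per edge type by sorting on etype IDs and boolean-masking the endpoints. A hedged usage sketch (the graph name and seed IDs are placeholders, and servers are assumed to be running):

import dgl
from dgl.distributed import DistGraph, sample_neighbors

g = DistGraph('my_hetero_graph')
frontier = sample_neighbors(g, {'user': [0, 1, 2]}, fanout=3)
for etype in frontier.canonical_etypes:
    # per-etype original edge IDs survive as the dgl.EID feature
    print(etype, frontier.edges[etype].data[dgl.EID])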
11 changes: 9 additions & 2 deletions tests/distributed/test_dist_graph_store.py
@@ -55,7 +55,8 @@ def create_random_graph(n):
def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients,
'/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem)
disable_shared_mem=not shared_mem,
graph_format=['csc', 'coo'])
print('start server', server_id)
g.start()

@@ -469,6 +470,13 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
for etype in num_edges:
assert etype in g.etypes
assert num_edges[etype] == g.number_of_edges(etype)
etypes = [('n1', 'r1', 'n2'),
('n1', 'r2', 'n3'),
('n2', 'r3', 'n3')]
for i, etype in enumerate(g.canonical_etypes):
assert etype[0] == etypes[i][0]
assert etype[1] == etypes[i][1]
assert etype[2] == etypes[i][2]
assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

@@ -584,7 +592,6 @@ def test_server_client():
check_server_client(True, 1, 1)
check_server_client(False, 1, 1)
check_server_client(True, 2, 2)
check_server_client(False, 2, 2)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
66 changes: 27 additions & 39 deletions tests/distributed/test_distributed_sampling.py
@@ -16,7 +16,7 @@
from dgl.distributed import DistGraphServer, DistGraph


def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format='csc'):
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format=['csc', 'coo']):
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
graph_format=graph_format)
@@ -284,7 +284,6 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
try:
nodes = {'n3': [0, 10, 99, 66, 124, 208]}
sampled_graph = sample_neighbors(dist_graph, nodes, 3)
nodes = gpb.map_to_homo_nid(nodes['n3'], 'n3')
block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e:
@@ -320,47 +319,36 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
for p in pserver_list:
p.join()

orig_nid_map = F.zeros((g.number_of_nodes(),), dtype=F.int64)
orig_eid_map = F.zeros((g.number_of_edges(),), dtype=F.int64)
orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
F.scatter_row_inplace(orig_nid_map, part.ndata[dgl.NID], part.ndata['orig_id'])
F.scatter_row_inplace(orig_eid_map, part.edata[dgl.EID], part.edata['orig_id'])

src, dst = block.edges()
# These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcdata[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstdata[dgl.NID], dst)
shuffled_eid = block.edata[dgl.EID]
# Get node/edge types.
etype, _ = gpb.map_to_per_etype(shuffled_eid)
src_type, _ = gpb.map_to_per_ntype(shuffled_src)
dst_type, _ = gpb.map_to_per_ntype(shuffled_dst)
etype = F.asnumpy(etype)
src_type = F.asnumpy(src_type)
dst_type = F.asnumpy(dst_type)
# These are global Ids in the original graph.
orig_src = F.asnumpy(F.gather_row(orig_nid_map, shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map, shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map, shuffled_eid))

etype_map = {g.get_etype_id(etype):etype for etype in g.etypes}
etype_to_eptype = {g.get_etype_id(etype):(src_ntype, dst_ntype) for src_ntype, etype, dst_ntype in g.canonical_etypes}
for e in np.unique(etype):
src_t = src_type[etype == e]
dst_t = dst_type[etype == e]
assert np.all(src_t == src_t[0])
assert np.all(dst_t == dst_t[0])
ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
for ntype_id, ntype in enumerate(g.ntypes):
idx = ntype_ids == ntype_id
F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
F.boolean_mask(part.ndata['orig_id'], idx))
etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
for etype_id, etype in enumerate(g.etypes):
idx = etype_ids == etype_id
F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
F.boolean_mask(part.edata['orig_id'], idx))

for src_type, etype, dst_type in block.canonical_etypes:
src, dst = block.edges(etype=etype)
# These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]

orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))

# Check the node Ids and edge Ids.
orig_src1, orig_dst1 = g.find_edges(orig_eid[etype == e], etype=etype_map[e])
assert np.all(F.asnumpy(orig_src1) == orig_src[etype == e])
assert np.all(F.asnumpy(orig_dst1) == orig_dst[etype == e])

# Check the node types.
src_ntype, dst_ntype = etype_to_eptype[e]
assert np.all(src_t == g.get_ntype_id(src_ntype))
assert np.all(dst_t == g.get_ntype_id(dst_ntype))
orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst)

# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
Expand Down
3 changes: 2 additions & 1 deletion tests/distributed/test_mp_dataloader.py
@@ -41,7 +41,8 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
import dgl
print('server: #clients=' + str(num_clients))
g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem,
graph_format=['csc', 'coo'])
g.start()


