[Distributed] Refactor distributed training to new DGLGraph. (dmlc#1874)
* fix tests in partition.

* fix DistGraph.

* fix without shared memory.

* fix sampling.

* enable distributed test.

* fix tests.

* fix a bug in shared-mem heterograph.

* print better error messages.

* fix.

* don't specify formats.

* fix.

* fix

* small fix.
zheng-da authored Jul 29, 2020
1 parent 05a4337 commit 3e2f94e
Showing 11 changed files with 31 additions and 51 deletions.
28 changes: 9 additions & 19 deletions python/dgl/distributed/dist_graph.py
@@ -4,12 +4,12 @@
 import os
 import numpy as np
 
-from ..graph import DGLGraph
+from ..heterograph import DGLHeteroGraph
+from .. import heterograph_index
 from .. import backend as F
 from ..base import NID, EID
 from .kvstore import KVServer, KVClient
 from .standalone_kvstore import KVClient as SA_KVClient
-from ..graph_index import from_shared_mem_graph_index
 from .._ffi.ndarray import empty_shared_mem
 from ..frame import infer_scheme
 from .partition import load_partition
@@ -21,14 +21,9 @@
 from .server_state import ServerState
 from .rpc_server import start_server
 from .dist_tensor import DistTensor, _get_data_name
-from ..transform import as_heterograph
-
-def _get_graph_path(graph_name):
-    return "/" + graph_name
 
 def _copy_graph_to_shared_mem(g, graph_name):
-    gidx = g._graph.copyto_shared_mem(_get_graph_path(graph_name))
-    new_g = DGLGraph(gidx)
+    new_g = g.shared_memory(graph_name, formats='csc')
     # We should share the node/edge data to the client explicitly instead of putting them
     # in the KVStore because some of the node/edge data may be duplicated.
     local_node_path = _get_ndata_path(graph_name, 'inner_node')
@@ -85,11 +80,10 @@ def _get_graph_from_shared_mem(graph_name):
     The client can access the graph structure and some metadata on nodes and edges directly
     through shared memory to reduce the overhead of data access.
     '''
-    gidx = from_shared_mem_graph_index(_get_graph_path(graph_name))
-    if gidx is None:
-        return gidx
-
-    g = DGLGraph(gidx)
+    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(graph_name)
+    if g is None:
+        return None
+    g = DGLHeteroGraph(g, ntypes, etypes)
     g.ndata['inner_node'] = _get_shared_mem_ndata(g, graph_name, 'inner_node')
     g.edata['inner_edge'] = _get_shared_mem_edata(g, graph_name, 'inner_edge')
     g.ndata[NID] = _get_shared_mem_ndata(g, graph_name, NID)
Expand Down Expand Up @@ -306,7 +300,7 @@ def __init__(self, ip_config, graph_name, gpb=None, conf_file=None):
'The standalone mode can only work with the graph data with one partition'
if self._gpb is None:
self._gpb = gpb
self._g = as_heterograph(g)
self._g = g
for name in node_feats:
self._client.add_data(_get_data_name(name, NODE_PART_POLICY), node_feats[name])
for name in edge_feats:
@@ -315,11 +309,7 @@ def __init__(self, ip_config, graph_name, gpb=None, conf_file=None):
         else:
             connect_to_server(ip_config=ip_config)
             self._client = KVClient(ip_config)
-            g = _get_graph_from_shared_mem(graph_name)
-            if g is not None:
-                self._g = as_heterograph(g)
-            else:
-                self._g = None
+            self._g = _get_graph_from_shared_mem(graph_name)
             self._gpb = get_shared_mem_partition_book(graph_name, self._g)
             if self._gpb is None:
                 self._gpb = gpb
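
Taken together, the dist_graph.py changes swap the old graph-index shared-memory path for the heterograph one. A minimal sketch of the new round trip, using only calls visible in this diff ('mygraph' is a placeholder name; partition-book and feature wiring omitted):

    # Server side: share the graph structure, materializing only the CSC format.
    shared_g = g.shared_memory('mygraph', formats='csc')

    # Client side: rebuild the graph from the same shared-memory name.
    from dgl import heterograph_index
    from dgl.heterograph import DGLHeteroGraph
    gidx, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory('mygraph')
    g = DGLHeteroGraph(gidx, ntypes, etypes) if gidx is not None else None

The None check works because HeteroGraph::CreateFromSharedMem now returns a null graph when the shared-memory segment does not exist (see the heterograph.cc hunk below).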
3 changes: 1 addition & 2 deletions python/dgl/distributed/graph_services.py
@@ -132,8 +132,7 @@ def merge_graphs(res_list, num_nodes):
         src_tensor = res_list[0].global_src
         dst_tensor = res_list[0].global_dst
         eid_tensor = res_list[0].global_eids
-    g = graph((src_tensor, dst_tensor),
-              restrict_format='coo', num_nodes=num_nodes)
+    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
     g.edata[EID] = eid_tensor
     return g
 
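restrict_format was an argument of the old immutable graph constructor; the new DGLGraph materializes coo/csr/csc on demand, so merge_graphs no longer pins the format at construction (this mirrors the "don't specify formats" commit above). A sketch, assuming the 0.5-style API, of how a caller would restrict formats after the fact if it really wanted to:

    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
    g = g.formats('coo')   # optional: clone with only the COO format allowed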
7 changes: 1 addition & 6 deletions python/dgl/distributed/server_state.py
@@ -1,8 +1,6 @@
 """Server data"""
 
 from .._ffi.function import _init_api
-from ..graph import DGLGraph
-from ..transform import as_heterograph
 
 # Remove C++ bindings for now, since not used
 
@@ -63,10 +61,7 @@ def graph(self):
 
     @graph.setter
     def graph(self, graph):
-        if isinstance(graph, DGLGraph):
-            self._graph = as_heterograph(graph)
-        else:
-            self._graph = graph
+        self._graph = graph
 
 
 _init_api("dgl.distributed.server_state")
6 changes: 4 additions & 2 deletions python/dgl/heterograph.py
@@ -4839,7 +4839,7 @@ def shared_memory(self, name, formats=('coo', 'csr', 'csc')):
         ----------
         name : str
             The name of the shared memory.
-        formats : list of str (optional)
+        formats : str or a list of str (optional)
             Desired formats to be materialized.
 
         Returns
@@ -4849,8 +4849,10 @@ def shared_memory(self, name, formats=('coo', 'csr', 'csc')):
         """
         assert len(name) > 0, "The name of shared memory cannot be empty"
         assert len(formats) > 0
+        if isinstance(formats, str):
+            formats = [formats]
         for fmt in formats:
-            assert fmt in ("coo", "csr", "csc")
+            assert fmt in ("coo", "csr", "csc"), '{} is not coo, csr or csc'.format(fmt)
         gidx = self._graph.shared_memory(name, self.ntypes, self.etypes, formats)
         return DGLHeteroGraph(gidx, self.ntypes, self.etypes)
 
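A small usage sketch of the relaxed signature ('mygraph' is a placeholder; both spellings are now accepted):

    shared_g = g.shared_memory('mygraph', formats='csc')           # a single format as a plain string
    shared_g = g.shared_memory('mygraph', formats=['coo', 'csr'])  # or an explicit list

This is what lets _copy_graph_to_shared_mem in dist_graph.py pass formats='csc' directly.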
4 changes: 4 additions & 0 deletions src/graph/heterograph.cc
@@ -322,6 +322,10 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
 
 std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
 HeteroGraph::CreateFromSharedMem(const std::string &name) {
+  bool exist = SharedMemory::Exist(name);
+  if (!exist) {
+    return std::make_tuple(nullptr, std::vector<std::string>(), std::vector<std::string>());
+  }
   auto mem = std::make_shared<SharedMemory>(name);
   auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
   dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
2 changes: 1 addition & 1 deletion src/graph/shared_mem_manager.cc
@@ -120,8 +120,8 @@ bool SharedMemManager::CreateFromSharedMem<COOMatrix>(COOMatrix *coo,
 template <>
 bool SharedMemManager::CreateFromSharedMem<CSRMatrix>(CSRMatrix *csr,
                                                       std::string name) {
-  CreateFromSharedMem(&csr->indices, name + "_indices");
   CreateFromSharedMem(&csr->indptr, name + "_indptr");
+  CreateFromSharedMem(&csr->indices, name + "_indices");
   CreateFromSharedMem(&csr->data, name + "_data");
   strm_->Read(&csr->num_rows);
   strm_->Read(&csr->num_cols);
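The swap looks cosmetic but is not: although each array lives in its own named segment, the reads also consume per-array metadata from the shared stream (strm_), so they must occur in the same order as the writes. The fix restores the indptr, indices, data order that the writer presumably emits.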
9 changes: 4 additions & 5 deletions tests/distributed/test_dist_graph_store.py
@@ -9,7 +9,7 @@
 from numpy.testing import assert_array_equal
 from multiprocessing import Process, Manager, Condition, Value
 import multiprocessing as mp
-from dgl.graph_index import create_graph_index
+from dgl.heterograph_index import create_unitgraph_from_coo
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import DistGraphServer, DistGraph
 from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
@@ -50,9 +50,8 @@ def get_local_usable_addr():
     return ip_addr + ' ' + str(port)
 
 def create_random_graph(n):
-    arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
-    ig = create_graph_index(arr, readonly=True)
-    return dgl.DGLGraph(ig)
+    arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
+    return dgl.graph(arr)
 
 def run_server(graph_name, server_id, num_clients, shared_mem):
     g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients,
@@ -65,7 +64,7 @@ def emb_init(shape, dtype):
     return F.zeros(shape, dtype, F.cpu())
 
 def rand_init(shape, dtype):
-    return F.tensor(np.random.normal(size=shape))
+    return F.tensor(np.random.normal(size=shape), F.float32)
 
 def run_client(graph_name, part_id, num_nodes, num_edges):
     time.sleep(5)
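Two test fixes worth noting: seeding spsp.random with random_state=100 makes create_random_graph deterministic, presumably so that every process that rebuilds the graph for verification sees the same edges; and np.random.normal returns float64, so the explicit F.float32 cast keeps the initializer's dtype consistent with DGL's default float32 tensors (the same cast appears in test_partition.py below).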
2 changes: 0 additions & 2 deletions tests/distributed/test_distributed_sampling.py
@@ -131,7 +131,6 @@ def test_rpc_sampling_shuffle():
 
 def check_standalone_sampling(tmpdir):
     g = CitationGraphDataset("cora")[0]
-    g.readonly()
     num_parts = 1
     num_hops = 1
     partition_graph(g, 'test_sampling', num_parts, tmpdir,
@@ -193,7 +192,6 @@ def check_rpc_in_subgraph(tmpdir, num_server):
         p.join()
 
     src, dst = sampled_graph.edges()
-    g = dgl.as_heterograph(g)
     assert sampled_graph.number_of_nodes() == g.number_of_nodes()
     subg1 = dgl.in_subgraph(g, nodes)
     src1, dst1 = subg1.edges()
5 changes: 0 additions & 5 deletions tests/distributed/test_new_kvstore.py
@@ -39,11 +39,6 @@ def get_local_usable_addr():
 
     return ip_addr + ' ' + str(port)
 
-def create_random_graph(n):
-    arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
-    ig = create_graph_index(arr, readonly=True)
-    return dgl.DGLGraph(ig)
-
 # Create an one-part Graph
 node_map = F.tensor([0,0,0,0,0,0], F.int64)
 edge_map = F.tensor([0,0,0,0,0,0,0], F.int64)
10 changes: 4 additions & 6 deletions tests/distributed/test_partition.py
@@ -4,7 +4,7 @@
 import numpy as np
 from scipy import sparse as spsp
 from numpy.testing import assert_array_equal
-from dgl.graph_index import create_graph_index
+from dgl.heterograph_index import create_unitgraph_from_coo
 from dgl.distributed import partition_graph, load_partition
 from dgl import function as fn
 import backend as F
@@ -14,13 +14,12 @@
 
 def create_random_graph(n):
     arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
-    ig = create_graph_index(arr, readonly=True)
-    return dgl.DGLGraph(ig)
+    return dgl.graph(arr)
 
 def check_partition(g, part_method, reshuffle):
     g.ndata['labels'] = F.arange(0, g.number_of_nodes())
-    g.ndata['feats'] = F.tensor(np.random.randn(g.number_of_nodes(), 10))
-    g.edata['feats'] = F.tensor(np.random.randn(g.number_of_edges(), 10))
+    g.ndata['feats'] = F.tensor(np.random.randn(g.number_of_nodes(), 10), F.float32)
+    g.edata['feats'] = F.tensor(np.random.randn(g.number_of_edges(), 10), F.float32)
     g.update_all(fn.copy_src('feats', 'msg'), fn.sum('msg', 'h'))
     g.update_all(fn.copy_edge('feats', 'msg'), fn.sum('msg', 'eh'))
     num_parts = 4
@@ -112,7 +111,6 @@ def test_partition():
 
 def test_hetero_partition():
     g = create_random_graph(10000)
-    g = dgl.as_heterograph(g)
     check_partition(g, 'metis', True)
     check_partition(g, 'metis', False)
     check_partition(g, 'random', True)
6 changes: 3 additions & 3 deletions tests/scripts/task_unit_test.sh
@@ -37,6 +37,6 @@ python3 -m pytest -v --junitxml=pytest_gindex.xml tests/graph_index || fail "gra
 python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
 
 export OMP_NUM_THREADS=1
-#if [ $2 != "gpu" ]; then
-#    python3 -m pytest -v --junitxml=pytest_distributed.xml tests/distributed || fail "distributed"
-#fi
+if [ $2 != "gpu" ]; then
+    python3 -m pytest -v --junitxml=pytest_distributed.xml tests/distributed || fail "distributed"
+fi
