[Distributed] Fix a bug for graphs without node/edge data. (dmlc#2838)
* fix.

* test distributed graph without node/edge data.

* remove some tests.

* fix lint
zheng-da authored Apr 13, 2021
1 parent afc83aa commit de5e8e2
Showing 2 changed files with 85 additions and 5 deletions.
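
Context for the diff below: before this commit, a KVServer holding no node/edge tensors answered GetSharedDataRequest by raising RuntimeError('There is no data on kvserver.'), so a DistGraph built from a partition without any node/edge data could never finish initialization. The fix removes that server-side error and has the client derive its partition policies from the partition book instead of from server-reported tensor metadata. A minimal, self-contained sketch of the server-side half (plain dicts stand in for the kvstore and the rpc response type; these are not the real DGL signatures):

def process_get_shared(data_store, part_policy):
    # Collect (shape, dtype, policy) metadata for every shared tensor,
    # mirroring GetSharedDataRequest.process_request after the fix.
    meta = {}
    for name, data in data_store.items():
        meta[name] = (data.shape, str(data.dtype), part_policy[name])
    # Pre-fix behavior, now deleted in the hunk below:
    #     if len(meta) == 0:
    #         raise RuntimeError('There is no data on kvserver.')
    return meta

assert process_get_shared({}, {}) == {}   # an empty response is now legal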
16 changes: 11 additions & 5 deletions python/dgl/distributed/kvstore.py
@@ -4,7 +4,7 @@
 import numpy as np
 
 from . import rpc
-from .graph_partition_book import PartitionPolicy
+from .graph_partition_book import NodePartitionPolicy, EdgePartitionPolicy
 from .standalone_kvstore import KVClient as SA_KVClient
 
 from .. import backend as F
@@ -365,8 +365,6 @@ def process_request(self, server_state):
             meta[name] = (F.shape(data),
                           F.reverse_data_type_dict[F.dtype(data)],
                           kv_store.part_policy[name].policy_str)
-        if len(meta) == 0:
-            raise RuntimeError('There is no data on kvserver.')
         res = GetSharedDataResponse(meta)
         return res

@@ -1058,6 +1056,14 @@ def map_shared_data(self, partition_book):
         partition_book : GraphPartitionBook
             Store the partition information
         """
+        # Get all partition policies
+        for ntype in partition_book.ntypes:
+            policy = NodePartitionPolicy(partition_book, ntype)
+            self._all_possible_part_policy[policy.policy_str] = policy
+        for etype in partition_book.etypes:
+            policy = EdgePartitionPolicy(partition_book, etype)
+            self._all_possible_part_policy[policy.policy_str] = policy
+
         # Get shared data from server side
         self.barrier()
         request = GetSharedDataRequest(GET_SHARED_MSG)
@@ -1066,11 +1072,11 @@
         for name, meta in response.meta.items():
             if name not in self._data_name_list:
                 shape, dtype, policy_str = meta
+                assert policy_str in self._all_possible_part_policy
                 shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
                 dlpack = shared_data.to_dlpack()
                 self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
-                self._part_policy[name] = PartitionPolicy(policy_str, partition_book)
-                self._all_possible_part_policy[policy_str] = self._part_policy[name]
+                self._part_policy[name] = self._all_possible_part_policy[policy_str]
                 self._pull_handlers[name] = default_pull_handler
                 self._push_handlers[name] = default_push_handler
         # Get full data shape across servers
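The client-side half is a policy registry: map_shared_data now enumerates every node and edge type in the partition book up front, keys the resulting policies by their policy_str, and reuses those shared objects when mapping server metadata, rather than constructing a fresh PartitionPolicy per tensor. A distilled sketch of that pattern under stand-in classes (the policy_str key format here is an assumption, not the exact DGL string):

class NodePolicy:                            # stands in for NodePartitionPolicy
    def __init__(self, book, ntype):
        self.policy_str = 'node~' + ntype    # assumed key format

class EdgePolicy:                            # stands in for EdgePartitionPolicy
    def __init__(self, book, etype):
        self.policy_str = 'edge~' + etype    # assumed key format

def build_policy_registry(book):
    # Pre-register one policy per node/edge type, keyed by policy_str,
    # as the first new hunk in map_shared_data does.
    registry = {}
    for ntype in book.ntypes:
        policy = NodePolicy(book, ntype)
        registry[policy.policy_str] = policy
    for etype in book.etypes:
        policy = EdgePolicy(book, etype)
        registry[policy.policy_str] = policy
    return registry

def map_shared_meta(meta, registry):
    # Map server metadata onto pre-registered policies; the assert is the
    # same guard the second hunk adds to map_shared_data.
    part_policy = {}
    for name, (_shape, _dtype, policy_str) in meta.items():
        assert policy_str in registry, policy_str
        part_policy[name] = registry[policy_str]   # reuse the shared object
    return part_policy

class Book:                                  # minimal partition-book stand-in
    ntypes = ['_N']
    etypes = ['_E']

assert map_shared_meta({}, build_policy_registry(Book())) == {}

Because the registry is built from the partition book alone, a graph with no node/edge data still ends up with a complete policy table, which is exactly what the new test below exercises.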
74 changes: 74 additions & 0 deletions tests/distributed/test_dist_graph_store.py
@@ -66,6 +66,79 @@ def emb_init(shape, dtype):
 def rand_init(shape, dtype):
     return F.tensor(np.random.normal(size=shape), F.float32)
 
+def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
+    # Test API
+    assert g.number_of_nodes() == num_nodes
+    assert g.number_of_edges() == num_edges
+
+    # Test init node data
+    new_shape = (g.number_of_nodes(), 2)
+    g.ndata['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
+    nids = F.arange(0, int(g.number_of_nodes() / 2))
+    feats = g.ndata['test1'][nids]
+    assert np.all(F.asnumpy(feats) == 0)
+
+    # create a tensor and destroy a tensor and create it again.
+    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
+    del test3
+    test3 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test3')
+    del test3
+
+    # Test write data
+    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
+    g.ndata['test1'][nids] = new_feats
+    feats = g.ndata['test1'][nids]
+    assert np.all(F.asnumpy(feats) == 1)
+
+    # Test metadata operations.
+    assert g.node_attr_schemes()['test1'].dtype == F.int32
+
+    print('end')
+
+def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
+    time.sleep(5)
+    os.environ['DGL_NUM_SERVER'] = str(server_count)
+    dgl.distributed.initialize("kv_ip_config.txt")
+    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
+                                                part_id, None)
+    g = DistGraph(graph_name, gpb=gpb)
+    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)
+
+def check_server_client_empty(shared_mem, num_servers, num_clients):
+    prepare_dist()
+    g = create_random_graph(10000)
+
+    # Partition the graph
+    num_parts = 1
+    graph_name = 'dist_graph_test_1'
+    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
+
+    # let's just test on one partition for now.
+    # We cannot run multiple servers and clients on the same machine.
+    serv_ps = []
+    ctx = mp.get_context('spawn')
+    for serv_id in range(num_servers):
+        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
+                                                 num_clients, shared_mem))
+        serv_ps.append(p)
+        p.start()
+
+    cli_ps = []
+    for cli_id in range(num_clients):
+        print('start client', cli_id)
+        p = ctx.Process(target=run_client_empty, args=(graph_name, 0, num_servers, num_clients,
+                                                       g.number_of_nodes(), g.number_of_edges()))
+        p.start()
+        cli_ps.append(p)
+
+    for p in cli_ps:
+        p.join()
+
+    for p in serv_ps:
+        p.join()
+
+    print('clients have terminated')
+
 def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
     time.sleep(5)
     os.environ['DGL_NUM_SERVER'] = str(server_count)
@@ -380,6 +453,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 def test_server_client():
     os.environ['DGL_DIST_MODE'] = 'distributed'
+    check_server_client_empty(True, 1, 1)
     check_server_client_hetero(True, 1, 1)
     check_server_client_hetero(False, 1, 1)
     check_server_client(True, 1, 1)
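The new coverage is wired into test_server_client above; run in isolation, it amounts to the following (a sketch mirroring the environment setup the test itself performs):

import os
os.environ['DGL_DIST_MODE'] = 'distributed'
check_server_client_empty(True, 1, 1)   # shared_mem=True, 1 server, 1 client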
