[Distributed] Fix a few bugs in distributed API (dmlc#3094)
* fix.

* fix.

* fix.

* fix.

* Fix test

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
3 people authored Jul 5, 2021
1 parent 595d4e3 commit 485c04c
Showing 6 changed files with 31 additions and 10 deletions.
python/dgl/distributed/dist_graph.py (2 additions, 2 deletions)
@@ -995,7 +995,7 @@ def barrier(self):
def _get_ndata_names(self, ntype=None):
''' Get the names of all node data.
'''
names = self._client.data_name_list()
names = self._client.gdata_name_list()
ndata_names = []
for name in names:
name = parse_hetero_data_name(name)
@@ -1007,7 +1007,7 @@ def _get_edata_names(self, etype=None):
def _get_edata_names(self, etype=None):
''' Get the names of all edge data.
'''
names = self._client.data_name_list()
names = self._client.gdata_name_list()
edata_names = []
for name in names:
name = parse_hetero_data_name(name)
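In dist_graph.py, `_get_ndata_names()` and `_get_edata_names()` now ask the KVStore client for `gdata_name_list()` (only tensors registered as graph data, i.e. ndata/edata) rather than `data_name_list()` (every tensor the client tracks, which can also include non-graph data created with `is_gdata=False`). The returned names are then parsed and filtered by node or edge type. A purely illustrative sketch of that filtering step, with an assumed `"kind:type:name"` encoding and a `parse_name()` helper standing in for DGL's `parse_hetero_data_name`:

```python
# Illustrative only: mimics how _get_ndata_names() filters the KVStore
# name list. The "kind:type:name" scheme and parse_name() are assumptions
# for this sketch, not DGL's actual encoding.
def parse_name(full_name):
    kind, type_name, data_name = full_name.split(":", 2)
    return kind, type_name, data_name

def get_ndata_names(gdata_names, ntype=None):
    """Keep node-data entries, optionally restricted to one node type."""
    out = []
    for full_name in gdata_names:
        kind, type_name, data_name = parse_name(full_name)
        if kind != "node":
            continue
        if ntype is not None and type_name != ntype:
            continue
        out.append(data_name)
    return out

# Non-graph tensors (is_gdata=False) never enter gdata_name_list(), so
# they no longer show up as node or edge data here.
print(get_ndata_names(["node:_N:feat", "edge:_E:weight"]))  # ['feat']
```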
python/dgl/distributed/dist_tensor.py (1 addition, 1 deletion)
@@ -150,8 +150,8 @@ def __init__(self, shape, dtype, name=None, init_func=None, part_policy=None,
self._name = str(data_name)
self._persistent = persistent
if self._name not in exist_names:
self.kvstore.init_data(self._name, shape, dtype, part_policy, init_func, is_gdata)
self._owner = True
self.kvstore.init_data(self._name, shape, dtype, part_policy, init_func, is_gdata)
else:
self._owner = False
dtype1, shape1, _ = self.kvstore.get_data_meta(self._name)
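The dist_tensor.py hunk touches the branch that creates new KVStore data: `_owner` records whether this DistTensor instance created the data and is therefore responsible for removing it later (when it is not persistent), while the `else` branch merely attaches to data that already exists and checks its metadata. A minimal sketch of that ownership rule, using a dict-backed stand-in rather than DGL's real KVStore client:

```python
# Minimal sketch, assuming a dict-backed stand-in for the KVStore.
# Only the handle that created the shared data (the "owner") deletes it,
# and only if it is not persistent; other handles just attach.
import numpy as np

class FakeKVStore:
    def __init__(self):
        self.data = {}

class SharedTensor:
    def __init__(self, kvstore, name, shape, dtype, persistent=False):
        self.kvstore = kvstore
        self._name = name
        self._persistent = persistent
        if name not in kvstore.data:
            self._owner = True                       # creator owns the data
            kvstore.data[name] = np.zeros(shape, dtype)
        else:
            self._owner = False                      # attach to existing data
            assert kvstore.data[name].shape == tuple(shape)

    def delete(self):
        if self._owner and not self._persistent:
            del self.kvstore.data[self._name]

kv = FakeKVStore()
a = SharedTensor(kv, "emb1_sum", (4, 1), np.float32)  # creates and owns
b = SharedTensor(kv, "emb1_sum", (4, 1), np.float32)  # attaches only
b.delete()
assert "emb1_sum" in kv.data      # non-owner delete is a no-op
a.delete()
assert "emb1_sum" not in kv.data  # owner tears it down
```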
python/dgl/distributed/kvstore.py (5 additions, 1 deletion)
@@ -1123,9 +1123,13 @@ def map_shared_data(self, partition_book):
self._gdata_name_list.add(name)
self.barrier()

def gdata_name_list(self):
"""Get all the graph data name"""
return list(self._gdata_name_list)

def data_name_list(self):
"""Get all the data name"""
return list(self._gdata_name_list)
return list(self._data_name_list)

def get_data_meta(self, name):
"""Get meta data (data_type, data_shape, partition_policy)
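This kvstore.py hunk is the core of the fix: `data_name_list()` used to return `self._gdata_name_list`, so callers asking for all data only ever saw graph data. The patch makes `data_name_list()` return the full `_data_name_list` and adds a separate `gdata_name_list()` for graph data only. A hedged sketch of how the two name sets relate (attribute names mirror the diff, but the class itself is illustrative, not DGL's KVClient):

```python
# Hedged sketch of the "all data" vs. "graph data" split in a
# kvstore-like client. Attribute names mirror the diff above; the class
# is illustrative, not DGL's KVClient.
class NameTracker:
    def __init__(self):
        self._data_name_list = set()   # every tensor the client knows about
        self._gdata_name_list = set()  # only the graph's ndata/edata

    def init_data(self, name, is_gdata=True):
        self._data_name_list.add(name)
        if is_gdata:
            self._gdata_name_list.add(name)

    def data_name_list(self):
        """All tensor names, including e.g. optimizer state."""
        return list(self._data_name_list)

    def gdata_name_list(self):
        """Only names registered as graph data."""
        return list(self._gdata_name_list)

tracker = NameTracker()
tracker.init_data("node:_N:feat", is_gdata=True)
tracker.init_data("emb1_sum", is_gdata=False)
assert sorted(tracker.data_name_list()) == ["emb1_sum", "node:_N:feat"]
assert tracker.gdata_name_list() == ["node:_N:feat"]
```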
python/dgl/distributed/optim/pytorch/sparse_optim.py (2 additions, 1 deletion)
@@ -32,7 +32,8 @@ def __init__(self, params, lr):
self._rank = th.distributed.get_rank()
self._world_size = th.distributed.get_world_size()
else:
assert 'th.distributed shoud be initialized'
self._rank = 0
self._world_size = 1

def step(self):
''' The step function.
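In sparse_optim.py, the removed line `assert 'th.distributed shoud be initialized'` asserts a non-empty string literal, which is always truthy, so it could never fail; when no process group existed, execution continued with `_rank` and `_world_size` undefined. The fix falls back to single-process defaults instead. A small sketch of that pattern (the helper function is illustrative, not DGL's API):

```python
# Sketch of the fallback from the hunk above: use the process group when
# one exists, otherwise behave as a single worker. get_rank_and_world_size()
# is an illustrative helper, not part of DGL.
import torch as th
import torch.distributed  # ensure the submodule is loaded

def get_rank_and_world_size():
    if th.distributed.is_available() and th.distributed.is_initialized():
        return th.distributed.get_rank(), th.distributed.get_world_size()
    # No process group: single-process defaults instead of a no-op
    # `assert '<string>'`, which can never fire.
    return 0, 1

rank, world_size = get_rank_and_world_size()
print(rank, world_size)  # 0 1 when torch.distributed is not initialized
```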
python/dgl/distributed/standalone_kvstore.py (14 additions, 1 deletion)
@@ -16,12 +16,19 @@ def __init__(self):
self._all_possible_part_policy = {}
self._push_handlers = {}
self._pull_handlers = {}
# Store all graph data name
self._gdata_name_list = set()

@property
def all_possible_part_policy(self):
"""Get all possible partition policies"""
return self._all_possible_part_policy

@property
def num_servers(self):
"""Get the number of servers"""
return 1

def barrier(self):
'''barrier'''

@@ -39,11 +46,13 @@ def add_data(self, name, tensor, part_policy):
if part_policy.policy_str not in self._all_possible_part_policy:
self._all_possible_part_policy[part_policy.policy_str] = part_policy

def init_data(self, name, shape, dtype, part_policy, init_func):
def init_data(self, name, shape, dtype, part_policy, init_func, is_gdata=True):
'''add new data to the client'''
self._data[name] = init_func(shape, dtype)
if part_policy.policy_str not in self._all_possible_part_policy:
self._all_possible_part_policy[part_policy.policy_str] = part_policy
if is_gdata:
self._gdata_name_list.add(name)

def delete_data(self, name):
'''delete the data'''
@@ -53,6 +62,10 @@ def data_name_list(self):
'''get the names of all data'''
return list(self._data.keys())

def gdata_name_list(self):
'''get the names of graph data'''
return list(self._gdata_name_list)

def get_data_meta(self, name):
'''get the metadata of data'''
return F.dtype(self._data[name]), F.shape(self._data[name]), None
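standalone_kvstore.py backs the single-process "standalone" mode, so its client grows the same surface the distributed client exposes: a `num_servers` property (always 1 here) and `gdata_name_list()` backed by a `_gdata_name_list` set that `init_data(..., is_gdata=True)` populates. The point is interface parity: code written against the distributed client keeps working unchanged in standalone mode. A hedged sketch of that contract (the Protocol and helper below are illustrative, not part of DGL):

```python
# Illustrative interface-parity sketch: code that only talks to this
# surface runs the same whether the backend is the distributed KVStore
# client or the standalone, in-process one.
from typing import List, Protocol

class KVClientLike(Protocol):
    @property
    def num_servers(self) -> int: ...
    def data_name_list(self) -> List[str]: ...
    def gdata_name_list(self) -> List[str]: ...

def describe(client: KVClientLike) -> str:
    # Works for either backend as long as both expose the same members.
    return (f"{client.num_servers} server(s), "
            f"{len(client.gdata_name_list())} graph tensors, "
            f"{len(client.data_name_list())} tensors total")
```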
tests/distributed/test_dist_graph_store.py (7 additions, 4 deletions)
@@ -193,7 +193,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(),), F.float32,
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(), 1), F.float32,
'emb1_sum', policy)
if num_clients == 1:
assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
@@ -216,12 +216,15 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
with F.no_grad():
feats = emb(nids)
if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
except NotImplementedError as e:
pass
except Exception as e:
print(e)
sys.exit(-1)

def check_dist_graph(g, num_clients, num_nodes, num_edges):
# Test API
@@ -332,6 +335,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):

for p in cli_ps:
p.join()
assert p.exitcode == 0

for p in serv_ps:
p.join()
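The test changes work as a pair: inside the spawned client, an unexpected exception is printed and converted into a nonzero exit code via `sys.exit(-1)` (while `NotImplementedError` is still tolerated), and the parent now asserts `p.exitcode == 0` after `join()`, so a crashed client process can no longer leave the test silently green. A self-contained sketch of that pattern (`do_checks()` is a stand-in for the real distributed checks):

```python
# Sketch of the test-hardening pattern from the hunks above: the child
# turns unexpected failures into a nonzero exit code, and the parent
# asserts on it after join().
import multiprocessing as mp
import sys

def do_checks():
    """Stand-in for the real distributed checks."""
    assert 1 + 1 == 2

def run_client():
    try:
        do_checks()
    except NotImplementedError:
        pass                  # an optional code path may be missing
    except Exception as e:    # anything else must fail the whole test
        print(e)
        sys.exit(-1)

if __name__ == "__main__":
    procs = [mp.Process(target=run_client) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0, "client process crashed"
```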
@@ -590,7 +594,6 @@ def test_dist_emb_server_client():
check_dist_emb_server_client(True, 1, 1)
check_dist_emb_server_client(False, 1, 1)
check_dist_emb_server_client(True, 2, 2)
check_dist_emb_server_client(False, 2, 2)

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_standalone():
@@ -765,9 +768,9 @@ def prepare_dist():

if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True)
test_dist_emb_server_client()
test_server_client()
test_split()
test_split_even()
test_standalone()

test_standalone_node_emb()
