Skip to content

Commit

Permalink
[Dist] format dtypes when loading graph in server (dmlc#4228)
Browse files Browse the repository at this point in the history
* [Dist] format dtypes when loading graph in server

* add test

* refine

* add comments
  • Loading branch information
Rhett-Ying authored Jul 11, 2022
1 parent 1feec87 commit c65d6fa
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,17 @@ def __init__(self, server_id, ip_config, num_servers,
self.client_g, _, _, self.gpb, graph_name, \
ntypes, etypes = load_partition(part_config, self.part_id, load_feats=False)
print('load ' + graph_name)
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in FIELD_DICT.items():
if k in self.client_g.ndata:
self.client_g.ndata[k] = F.astype(
self.client_g.ndata[k], dtype)
if k in self.client_g.edata:
self.client_g.edata[k] = F.astype(
self.client_g.edata[k], dtype)
# Create the graph formats specified the users.
self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_()
Expand Down
9 changes: 9 additions & 0 deletions tests/distributed/test_dist_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee
disable_shared_mem=not shared_mem,
graph_format=['csc', 'coo'], keep_alive=keep_alive)
print('start server', server_id)
# verify dtype of underlying graph
cg = g.client_g
for k, dtype in dgl.distributed.dist_graph.FIELD_DICT.items():
if k in cg.ndata:
assert F.dtype(
cg.ndata[k]) == dtype, "Data type of {} in ndata should be {}.".format(k, dtype)
if k in cg.edata:
assert F.dtype(
cg.edata[k]) == dtype, "Data type of {} in edata should be {}.".format(k, dtype)
g.start()

def emb_init(shape, dtype):
Expand Down

0 comments on commit c65d6fa

Please sign in to comment.