[Distributed] Specify the graph format for distributed training (dmlc/dgl#2948)

* explicitly set the graph format.

* fix.

* fix.

* fix launch script.

* fix readme.

Co-authored-by: Zheng <[email protected]>
Co-authored-by: xiang song(charlie.song) <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
4 people authored May 26, 2021
1 parent 1db4ad4 commit 18dbaeb
Showing 5 changed files with 29 additions and 8 deletions.
3 changes: 3 additions & 0 deletions examples/pytorch/graphsage/experimental/README.md
@@ -135,6 +135,7 @@ python3 ~/workspace/dgl/tools/launch.py \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
--graph_format csc,coo \
"python3 train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000"
```

@@ -183,6 +184,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
--graph_format csc,coo \
"python3 train_dist_unsupervised_transductive.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000 --num_gpus 4"
```

@@ -194,6 +196,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
--graph_format csc,coo \
"python3 train_dist_unsupervised_transductive.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000 --num_gpus 4 --dgl_sparse"
```

5 changes: 4 additions & 1 deletion python/dgl/distributed/dist_context.py
@@ -93,11 +93,14 @@ def initialize(ip_config, num_servers=1, num_workers=0,
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
os.environ.get('DGL_CONF_PATH'),
graph_format=formats)
serv.start()
sys.exit()
else:
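For context, the change above makes the server process read the comma-separated DGL_GRAPH_FORMAT environment variable (exported by launch.py) and turn it into a list of format names before constructing DistGraphServer. A minimal standalone sketch of that parsing step, using an illustrative value rather than one set by the launcher:

```
import os

# Illustrative value only; in a real run launch.py exports this from --graph_format.
os.environ.setdefault('DGL_GRAPH_FORMAT', 'csc, coo')

# Same idea as the added lines in dist_context.py: split on ',' and strip
# whitespace, falling back to 'csc' when the variable is unset.
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
print(formats)  # ['csc', 'coo']
```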
14 changes: 10 additions & 4 deletions python/dgl/distributed/dist_graph.py
@@ -62,8 +62,8 @@ def __getstate__(self):
def __setstate__(self, state):
self._graph_name = state

def _copy_graph_to_shared_mem(g, graph_name):
new_g = g.shared_memory(graph_name, formats='csc')
def _copy_graph_to_shared_mem(g, graph_name, graph_format):
new_g = g.shared_memory(graph_name, formats=graph_format)
# We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated.
new_g.ndata['inner_node'] = _to_shared_mem(g.ndata['inner_node'],
@@ -291,9 +291,12 @@ class DistGraphServer(KVServer):
The path of the config file generated by the partition tool.
disable_shared_mem : bool
Disable shared memory.
graph_format : str or list of str
The graph formats.
'''
def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False):
num_clients, part_config, disable_shared_mem=False,
graph_format='csc'):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
@@ -309,8 +312,11 @@ def __init__(self, server_id, ip_config, num_servers,
self.client_g, node_feats, edge_feats, self.gpb, graph_name, \
ntypes, etypes = load_partition(part_config, self.part_id)
print('load ' + graph_name)
# Create the graph formats specified by the users.
self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_()
if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)

if not disable_shared_mem:
self.gpb.shared_memory(graph_name)
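The test changes below exercise the new keyword directly. As a rough illustration, starting a server by hand with an explicit format list might look like the sketch that follows; the partition JSON and IP config paths are placeholders borrowed from the README example, not values this commit ships:

```
from dgl.distributed import DistGraphServer

# Placeholder paths; a real deployment gets these from the partition tool
# and from the launch script.
part_config = 'data/ogb-product.json'
ip_config = 'ip_config.txt'

# Build both CSC and COO for this server's partition. A single string such
# as 'csc' is also accepted, matching the new default.
serv = DistGraphServer(0, ip_config, 1, 1, part_config,
                       graph_format=['csc', 'coo'])
serv.start()
```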
8 changes: 5 additions & 3 deletions tests/distributed/test_distributed_sampling.py
@@ -16,9 +16,10 @@
from dgl.distributed import DistGraphServer, DistGraph


def start_server(rank, tmpdir, disable_shared_mem, graph_name):
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format='csc'):
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
graph_format=graph_format)
g.start()


@@ -119,7 +120,8 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_find_edges'))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1,
'test_find_edges', ['csr', 'coo']))
p.start()
time.sleep(1)
pserver_list.append(p)
7 changes: 7 additions & 0 deletions tools/launch.py
@@ -167,12 +167,14 @@ def submit_jobs(args, udf_command):
server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
server_cmd = server_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
server_cmd = server_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
for i in range(len(hosts)*server_count_per_machine):
ip, _ = hosts[int(i / server_count_per_machine)]
cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
cmd = cmd + ' ' + udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, args.ssh_port, thread_list)

# launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_SAMPLER=' + str(args.num_samplers)
client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
@@ -185,6 +187,7 @@ def submit_jobs(args, udf_command):
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_omp_threads)
if os.environ.get('PYTHONPATH') is not None:
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
client_cmd = client_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)

torch_cmd = '-m torch.distributed.launch'
torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(args.num_trainers)
@@ -248,6 +251,10 @@ def main():
help='The number of OMP threads in the server process. \
It should be small if server processes and trainer processes run on \
the same machine. By default, it is 1.')
parser.add_argument('--graph_format', type=str, default='csc',
help='The format of the graph structure of each partition. \
The allowed formats are csr, csc and coo. A user can specify multiple \
formats, separated by ",". For example, the graph format is "csr,csc".')
args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers is not None and args.num_trainers > 0, \
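launch.py itself simply forwards the --graph_format value verbatim into DGL_GRAPH_FORMAT for both the server and client commands. Below is a simplified sketch (not the actual launch.py code) of that forwarding, with an optional sanity check on the format names listed in the new help text; the helper name and the validation step are assumptions added for illustration:

```
# Hypothetical helper mirroring how launch.py builds the environment prefix;
# the real script does not validate the value, it just appends it.
ALLOWED_FORMATS = {'csr', 'csc', 'coo'}

def graph_format_env(graph_format='csc'):
    formats = [f.strip() for f in graph_format.split(',')]
    assert all(f in ALLOWED_FORMATS for f in formats), \
        'graph_format must be a comma-separated subset of csr, csc, coo'
    return 'DGL_GRAPH_FORMAT=' + ','.join(formats)

server_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=server'
server_cmd = server_cmd + ' ' + graph_format_env('csc,coo')
print(server_cmd)
# DGL_DIST_MODE="distributed" DGL_ROLE=server DGL_GRAPH_FORMAT=csc,coo
```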
