-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Distributed] add copy_partitions.py (dmlc#1866)
* fix bugs. * eval on both validation and testing. * add script. * update. * update launch. * make train_dist.py independent. * update readme. * update readme. * update readme. * update readme. * generate undirected graph. * rename conf_file to part_config. * use rsync. * make train_dist independent. Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: xiang song(charlie.song) <[email protected]>
- Loading branch information
1 parent
8b64037
commit 4be4b13
Showing
8 changed files
with
207 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
"""Copy the partitions to a cluster of machines.""" | ||
import os | ||
import stat | ||
import sys | ||
import subprocess | ||
import argparse | ||
import signal | ||
import logging | ||
import json | ||
import copy | ||
|
||
def copy_file(file_name, ip, workspace):
    """Copy a local file into a directory on a remote machine with rsync.

    The transfer runs over ssh with host-key checking disabled. The
    destination is always written with a trailing slash, so rsync treats
    ``workspace`` as a directory and keeps the base name of ``file_name``.

    Args:
        file_name: Path of the file on the local machine. It is handed to
            rsync literally; no local shell is involved, so spaces and shell
            metacharacters in the path are safe (and globs or ``~`` are not
            expanded).
        ip: Address of the remote machine.
        workspace: Destination directory on the remote machine.

    Raises:
        subprocess.CalledProcessError: if rsync exits with a non-zero status.
    """
    destination = '{}:{}/'.format(ip, workspace)
    print('copy {} to {}'.format(file_name, destination))
    # An argument list (no shell) keeps each value as exactly one argument,
    # instead of relying on hand-written quoting inside a shell string.
    cmd = [
        'rsync',
        '-e', 'ssh -o StrictHostKeyChecking=no',
        '-arvc',
        file_name,
        destination,
    ]
    subprocess.check_call(cmd)
|
||
def exec_cmd(ip, cmd):
    """Run a command on a remote machine through ssh.

    Host-key checking is disabled for the connection.

    Args:
        ip: Address of the remote machine.
        cmd: Command line to run remotely. It is passed to ssh as a single
            argument, so it reaches the remote side unchanged even when it
            contains single quotes or other characters a local shell would
            interpret.

    Raises:
        subprocess.CalledProcessError: if ssh (or the remote command) exits
            with a non-zero status.
    """
    # No local shell: the remote command does not need to be wrapped in
    # quotes, which previously broke for commands containing a single quote.
    subprocess.check_call(['ssh', '-o', 'StrictHostKeyChecking=no', ip, cmd])
|
||
def main():
    """Copy a partitioned graph and its configuration to every host.

    Reads the host list from ``--ip_config`` and the partition metadata from
    ``--part_config``, writes a rewritten copy of the metadata whose paths
    describe the remote layout, and then pushes to the i-th host: the IP
    config, the rewritten partition config, the node/edge map files (only
    when the metadata stores them as file paths rather than lists) and the
    three data files of partition i.

    Raises:
        ValueError: if the number of partitions differs from the number of
            hosts, or a node/edge map file path does not end in ``.npy``.
        subprocess.CalledProcessError: if an ssh or rsync command fails.
    """
    parser = argparse.ArgumentParser(description='Copy data to the servers.')
    parser.add_argument('--workspace', type=str, required=True,
                        help='Path of user directory of distributed tasks. '
                             'This is used to specify a destination location where '
                             'data are copied to on remote machines.')
    parser.add_argument('--rel_data_path', type=str, required=True,
                        help='Relative path in workspace to store the partition data.')
    parser.add_argument('--part_config', type=str, required=True,
                        help='The partition config file. The path is on the local machine.')
    parser.add_argument('--ip_config', type=str, required=True,
                        help='The file of IP configuration for servers. '
                             'The path is on the local machine.')
    args = parser.parse_args()

    # The first whitespace-separated field of every non-empty line is the
    # host address; any remaining fields are ignored here.
    hosts = []
    with open(args.ip_config) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            hosts.append(fields[0])

    # We need to update the partition config file so that the paths match
    # the layout on the remote machines.
    with open(args.part_config) as conf_f:
        part_metadata = json.load(conf_f)
    tmp_part_metadata = copy.deepcopy(part_metadata)
    num_parts = part_metadata['num_parts']
    if num_parts != len(hosts):
        raise ValueError(
            'The number of partitions needs to be the same as the number of hosts.')
    graph_name = part_metadata['graph_name']

    # Directory on every remote machine that receives the partition data.
    remote_data_dir = '{}/{}'.format(args.workspace, args.rel_data_path)

    # Node/edge maps are either inline lists (nothing to copy) or paths to
    # .npy files that must be shipped to every host.
    map_files = []
    for key, label in (('node_map', 'node map'), ('edge_map', 'edge map')):
        map_path = part_metadata[key]
        if isinstance(map_path, list):
            continue
        if not map_path.endswith('.npy'):
            raise ValueError('{} should be stored in a NumPy array.'.format(label))
        map_files.append(map_path)
        # rsync keeps the local base name, so record exactly the path the
        # file will have once it is copied into the remote data directory.
        tmp_part_metadata[key] = '{}/{}'.format(remote_data_dir,
                                                os.path.basename(map_path))

    for part_id in range(num_parts):
        part_files = tmp_part_metadata['part-{}'.format(part_id)]
        part_files['edge_feats'] = '{}/part{}/edge_feat.dgl'.format(args.rel_data_path, part_id)
        part_files['node_feats'] = '{}/part{}/node_feat.dgl'.format(args.rel_data_path, part_id)
        part_files['part_graph'] = '{}/part{}/graph.dgl'.format(args.rel_data_path, part_id)
    tmp_part_config = '/tmp/{}.json'.format(graph_name)
    with open(tmp_part_config, 'w') as outfile:
        json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4)

    # Push the configuration files and the host's own partition to each host.
    for part_id, ip in enumerate(hosts):
        exec_cmd(ip, 'mkdir -p {}'.format(remote_data_dir))

        copy_file(args.ip_config, ip, args.workspace)
        copy_file(tmp_part_config, ip, remote_data_dir)
        # Copy the map files INTO the data directory. Using the target file
        # path as the destination would make rsync create a directory with
        # that name, because copy_file appends a trailing slash.
        for map_path in map_files:
            copy_file(map_path, ip, remote_data_dir)

        remote_part_dir = '{}/part{}'.format(remote_data_dir, part_id)
        exec_cmd(ip, 'mkdir -p {}'.format(remote_part_dir))

        # Local paths come from the original (unmodified) metadata.
        part_files = part_metadata['part-{}'.format(part_id)]
        copy_file(part_files['node_feats'], ip, remote_part_dir)
        copy_file(part_files['edge_feats'], ip, remote_part_dir)
        copy_file(part_files['part_graph'], ip, remote_part_dir)
|
||
|
||
def signal_handler(signal, frame):
    """Log that copying was interrupted and exit the process with status 0.

    Has the two-argument shape the ``signal`` module expects of a handler;
    neither argument is used.
    """
    logging.info('Stop copying')
    # Same effect as sys.exit(0): unwinds the interpreter with exit code 0.
    raise SystemExit(0)
|
||
if __name__ == '__main__':
    # Timestamped INFO-level logging, so the handler's message is emitted.
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO)
    # Route Ctrl-C (SIGINT) through signal_handler before starting the copy.
    signal.signal(signal.SIGINT, signal_handler)
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters