From 4be4b134247cf79617480e5f4646dfa07bd96a4e Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Fri, 31 Jul 2020 11:11:56 -0700
Subject: [PATCH] [Distributed] add copy_partitions.py (#1866)

* fix bugs.
* eval on both validation and testing.
* add script.
* update.
* update launch.
* make train_dist.py independent.
* update readme.
* update readme.
* update readme.
* update readme.
* generate undirected graph.
* rename conf_file to part_config
* use rsync
* make train_dist independent.

Co-authored-by: Ubuntu
Co-authored-by: Ubuntu
Co-authored-by: xiang song(charlie.song)
---
 .../pytorch/graphsage/experimental/README.md  |  27 +++--
 .../graphsage/experimental/partition_graph.py |  12 +++
 .../graphsage/experimental/train_dist.py      |  70 ++++++++++--
 python/dgl/distributed/dist_graph.py          |  14 +--
 tests/distributed/test_dist_graph_store.py    |   2 +-
 .../distributed/test_distributed_sampling.py  |   2 +-
 tools/copy_partitions.py                      | 102 ++++++++++++++++++
 tools/launch.py                               |  12 +--
 8 files changed, 207 insertions(+), 34 deletions(-)
 create mode 100644 tools/copy_partitions.py

diff --git a/examples/pytorch/graphsage/experimental/README.md b/examples/pytorch/graphsage/experimental/README.md
index 59710e036410..757b94ce28cf 100644
--- a/examples/pytorch/graphsage/experimental/README.md
+++ b/examples/pytorch/graphsage/experimental/README.md
@@ -21,22 +21,33 @@ python3 partition_graph.py --dataset ogb-product --num_parts 4 --balance_train -
 
 ### Step 2: copy the partitioned data to the cluster
 
-When copying data to the cluster, we recommend users to copy the partitioned data to NFS so that all worker machines
-will be able to access the partitioned data.
+DGL provides a script for copying the partitioned data to the cluster. The command below copies the partitioned data
+to the machines in the cluster. The configuration of the cluster is defined by `ip_config.txt`.
+The data is copied to `~/graphsage/ogb-product` on each of the remote machines. `--part_config`
+specifies the location of the partitioned data on the local machine (a user only needs to specify
+the location of the partition configuration file).
+```bash
+python3 ~/dgl/tools/copy_partitions.py --ip_config ip_config.txt \
+        --workspace ~/graphsage --rel_data_path ogb-product \
+        --part_config data/ogb-product.json
+```
 
-### Step 3: Launch distributed jobs
+**Note**: users need to make sure that the master node has the right permission to ssh to all the other nodes.
 
-First make sure that the master node has the right permission to ssh to all the other nodes. Change the `ip_config.txt` file by using your own instance IP.
+Users need to copy the training script to the workspace directory on the remote machines as well.
+
+### Step 3: Launch distributed jobs
 
-Then run script:
+DGL provides a script to launch the training job in the cluster. `part_config` and `ip_config`
+are specified as paths relative to the workspace.
```bash python3 ~/dgl/tools/launch.py \ ---workspace ~/dgl/examples/pytorch/graphsage/experimental \ +--workspace ~/graphsage/ \ --num_client 4 \ ---conf_path data/ogb-product.json \ +--part_config ogb-product/ogb-product.json \ --ip_config ip_config.txt \ -"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000 --lr 0.1 --num-client 4" +"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000" ``` ## Distributed code runs in the standalone mode diff --git a/examples/pytorch/graphsage/experimental/partition_graph.py b/examples/pytorch/graphsage/experimental/partition_graph.py index 5bca33a8399f..9162a28cda32 100644 --- a/examples/pytorch/graphsage/experimental/partition_graph.py +++ b/examples/pytorch/graphsage/experimental/partition_graph.py @@ -12,8 +12,12 @@ help='datasets: reddit, ogb-product, ogb-paper100M') argparser.add_argument('--num_parts', type=int, default=4, help='number of partitions') + argparser.add_argument('--part_method', type=str, default='metis', + help='the partition method') argparser.add_argument('--balance_train', action='store_true', help='balance the training size in each partition.') + argparser.add_argument('--undirected', action='store_true', + help='turn the graph into an undirected graph.') argparser.add_argument('--balance_edges', action='store_true', help='balance the number of edges in each partition.') args = argparser.parse_args() @@ -34,6 +38,14 @@ balance_ntypes = g.ndata['train_mask'] else: balance_ntypes = None + + if args.undirected: + sym_g = dgl.to_bidirected_stale(g, readonly=True) + for key in g.ndata: + sym_g.ndata[key] = g.ndata[key] + g = sym_g + dgl.distributed.partition_graph(g, args.dataset, args.num_parts, 'data', + part_method=args.part_method, balance_ntypes=balance_ntypes, balance_edges=args.balance_edges) diff --git a/examples/pytorch/graphsage/experimental/train_dist.py b/examples/pytorch/graphsage/experimental/train_dist.py index 4df04e684bed..230432069ae8 100644 --- a/examples/pytorch/graphsage/experimental/train_dist.py +++ b/examples/pytorch/graphsage/experimental/train_dist.py @@ -21,8 +21,6 @@ from torch.utils.data import DataLoader from pyinstrument import Profiler -from train_sampling import run, SAGE, compute_acc, evaluate, load_subtensor - class NeighborSampler(object): def __init__(self, g, fanouts, sample_neighbors): self.g = g @@ -43,11 +41,29 @@ def sample_blocks(self, seeds): blocks.insert(0, block) return blocks -class DistSAGE(SAGE): +class DistSAGE(nn.Module): def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout): - super(DistSAGE, self).__init__(in_feats, n_hidden, n_classes, n_layers, - activation, dropout) + super().__init__() + self.n_layers = n_layers + self.n_hidden = n_hidden + self.n_classes = n_classes + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + for i in range(1, n_layers - 1): + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(dropout) + self.activation = activation + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + return h def inference(self, g, x, batch_size, device): """ @@ -100,9 +116,40 @@ def inference(self, g, x, batch_size, device): g.barrier() return y 
+def compute_acc(pred, labels):
+    """
+    Compute the accuracy of prediction given the labels.
+    """
+    labels = labels.long()
+    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
+
+def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
+    """
+    Evaluate the model on the validation and test sets specified by ``val_nid`` and ``test_nid``.
+    g : The entire graph.
+    inputs : The features of all the nodes.
+    labels : The labels of all the nodes.
+    val_nid, test_nid : the node IDs for validation and testing.
+    batch_size : Number of nodes to compute at the same time.
+    device : The GPU device to evaluate on.
+    """
+    model.eval()
+    with th.no_grad():
+        pred = model.inference(g, inputs, batch_size, device)
+    model.train()
+    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid])
+
+def load_subtensor(g, seeds, input_nodes, device):
+    """
+    Copies features and labels of a set of nodes onto the GPU.
+    """
+    batch_inputs = g.ndata['features'][input_nodes].to(device)
+    batch_labels = g.ndata['labels'][seeds].to(device)
+    return batch_inputs, batch_labels
+
 def run(args, device, data):
     # Unpack data
-    train_nid, val_nid, in_feats, n_classes, g = data
+    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
     # Create sampler
     sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                               dgl.distributed.sample_neighbors)
@@ -204,9 +251,10 @@ def run(args, device, data):
 
         if epoch % args.eval_every == 0 and epoch != 0:
             start = time.time()
-            eval_acc = evaluate(model.module, g, g.ndata['features'],
-                                g.ndata['labels'], val_nid, args.batch_size_eval, device)
-            print('Part {}, Eval Acc {:.4f}, time: {:.4f}'.format(g.rank(), eval_acc, time.time() - start))
+            val_acc, test_acc = evaluate(model.module, g, g.ndata['features'],
+                                         g.ndata['labels'], val_nid, test_nid, args.batch_size_eval, device)
+            print('Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}'.format(g.rank(), val_acc, test_acc,
+                                                                                  time.time() - start))
 
     profiler.stop()
     print(profiler.output_text(unicode=True, color=True))
@@ -217,7 +265,7 @@ def run(args, device, data):
 def main(args):
     if not args.standalone:
         th.distributed.init_process_group(backend='gloo')
-    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, conf_file=args.conf_path)
+    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.conf_path)
     print('rank:', g.rank())
 
     pb = g.get_partition_book()
@@ -236,7 +284,7 @@ def main(args):
 
     # Pack data
     in_feats = g.ndata['features'].shape[1]
-    data = train_nid, val_nid, in_feats, n_classes, g
+    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
     run(args, device, data)
     print("parent ends")
 
diff --git a/python/dgl/distributed/dist_graph.py b/python/dgl/distributed/dist_graph.py
index 142938aebfa6..9f1c2aaec27f 100644
--- a/python/dgl/distributed/dist_graph.py
+++ b/python/dgl/distributed/dist_graph.py
@@ -207,17 +207,17 @@ class DistGraphServer(KVServer):
         Path of IP configuration file.
     num_clients : int
         Total number of client nodes.
-    conf_file : string
+    part_config : string
         The path of the config file generated by the partition tool.
     disable_shared_mem : bool
         Disable shared memory.
     '''
-    def __init__(self, server_id, ip_config, num_clients, conf_file, disable_shared_mem=False):
+    def __init__(self, server_id, ip_config, num_clients, part_config, disable_shared_mem=False):
         super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config, num_clients=num_clients)
         self.ip_config = ip_config
         # Load graph partition data.
- self.client_g, node_feats, edge_feats, self.gpb, graph_name = load_partition(conf_file, + self.client_g, node_feats, edge_feats, self.gpb, graph_name = load_partition(part_config, server_id) print('load ' + graph_name) if not disable_shared_mem: @@ -286,16 +286,16 @@ class DistGraph: The name of the graph. This name has to be the same as the one used in DistGraphServer. gpb : PartitionBook The partition book object - conf_file : str + part_config : str The partition config file. It's used in the standalone mode. ''' - def __init__(self, ip_config, graph_name, gpb=None, conf_file=None): + def __init__(self, ip_config, graph_name, gpb=None, part_config=None): if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone': - assert conf_file is not None, \ + assert part_config is not None, \ 'When running in the standalone model, the partition config file is required' self._client = SA_KVClient() # Load graph partition data. - g, node_feats, edge_feats, self._gpb, _ = load_partition(conf_file, 0) + g, node_feats, edge_feats, self._gpb, _ = load_partition(part_config, 0) assert self._gpb.num_partitions() == 1, \ 'The standalone mode can only work with the graph data with one partition' if self._gpb is None: diff --git a/tests/distributed/test_dist_graph_store.py b/tests/distributed/test_dist_graph_store.py index 306d0879eb96..70904add5b4c 100644 --- a/tests/distributed/test_dist_graph_store.py +++ b/tests/distributed/test_dist_graph_store.py @@ -241,7 +241,7 @@ def test_standalone(): g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1) partition_graph(g, graph_name, num_parts, '/tmp/dist_graph') dist_g = DistGraph("kv_ip_config.txt", graph_name, - conf_file='/tmp/dist_graph/{}.json'.format(graph_name)) + part_config='/tmp/dist_graph/{}.json'.format(graph_name)) check_dist_graph(dist_g, g.number_of_nodes(), g.number_of_edges()) def test_split(): diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 4c6ca90667d3..7c8e4e72c7e2 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -136,7 +136,7 @@ def check_standalone_sampling(tmpdir): partition_graph(g, 'test_sampling', num_parts, tmpdir, num_hops=num_hops, part_method='metis', reshuffle=False) - dist_graph = DistGraph(None, "test_sampling", conf_file=tmpdir / 'test_sampling.json') + dist_graph = DistGraph(None, "test_sampling", part_config=tmpdir / 'test_sampling.json') sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3) src, dst = sampled_graph.edges() diff --git a/tools/copy_partitions.py b/tools/copy_partitions.py new file mode 100644 index 000000000000..7a40a5791e2d --- /dev/null +++ b/tools/copy_partitions.py @@ -0,0 +1,102 @@ +"""Copy the partitions to a cluster of machines.""" +import os +import stat +import sys +import subprocess +import argparse +import signal +import logging +import json +import copy + +def copy_file(file_name, ip, workspace): + print('copy {} to {}'.format(file_name, ip + ':' + workspace + '/')) + cmd = 'rsync -e \"ssh -o StrictHostKeyChecking=no\" -arvc ' + file_name + ' ' + ip + ':' + workspace + '/' + subprocess.check_call(cmd, shell = True) + +def exec_cmd(ip, cmd): + cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\'' + subprocess.check_call(cmd, shell = True) + +def main(): + parser = argparse.ArgumentParser(description='Copy data to the servers.') + parser.add_argument('--workspace', type=str, required=True, + help='Path of user 
directory of distributed tasks. \ + This is used to specify a destination location where \ + data are copied to on remote machines.') + parser.add_argument('--rel_data_path', type=str, required=True, + help='Relative path in workspace to store the partition data.') + parser.add_argument('--part_config', type=str, required=True, + help='The partition config file. The path is on the local machine.') + parser.add_argument('--ip_config', type=str, required=True, + help='The file of IP configuration for servers. \ + The path is on the local machine.') + args = parser.parse_args() + + hosts = [] + with open(args.ip_config) as f: + for line in f: + ip, _, _ = line.strip().split(' ') + hosts.append(ip) + + + # We need to update the partition config file so that the paths are relative to + # the workspace in the remote machines. + with open(args.part_config) as conf_f: + part_metadata = json.load(conf_f) + tmp_part_metadata = copy.deepcopy(part_metadata) + num_parts = part_metadata['num_parts'] + assert num_parts == len(hosts), \ + 'The number of partitions needs to be the same as the number of hosts.' + graph_name = part_metadata['graph_name'] + node_map = part_metadata['node_map'] + edge_map = part_metadata['edge_map'] + if not isinstance(node_map, list): + assert node_map[-4:] == '.npy', 'node map should be stored in a NumPy array.' + tmp_part_metadata['node_map'] = '{}/{}/node_map.npy'.format(args.workspace, + args.rel_data_path) + if not isinstance(edge_map, list): + assert edge_map[-4:] == '.npy', 'edge map should be stored in a NumPy array.' + tmp_part_metadata['edge_map'] = '{}/{}/edge_map.npy'.format(args.workspace, + args.rel_data_path) + + for part_id in range(num_parts): + part_files = tmp_part_metadata['part-{}'.format(part_id)] + part_files['edge_feats'] = '{}/part{}/edge_feat.dgl'.format(args.rel_data_path, part_id) + part_files['node_feats'] = '{}/part{}/node_feat.dgl'.format(args.rel_data_path, part_id) + part_files['part_graph'] = '{}/part{}/graph.dgl'.format(args.rel_data_path, part_id) + tmp_part_config = '/tmp/{}.json'.format(graph_name) + with open(tmp_part_config, 'w') as outfile: + json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4) + + # Copy ip config. 
+ for part_id, ip in enumerate(hosts): + remote_path = '{}/{}'.format(args.workspace, args.rel_data_path) + exec_cmd(ip, 'mkdir -p {}'.format(remote_path)) + + copy_file(args.ip_config, ip, args.workspace) + copy_file(tmp_part_config, ip, '{}/{}'.format(args.workspace, args.rel_data_path)) + node_map = part_metadata['node_map'] + edge_map = part_metadata['edge_map'] + if not isinstance(node_map, list): + copy_file(node_map, ip, tmp_part_metadata['node_map']) + if not isinstance(edge_map, list): + copy_file(edge_map, ip, tmp_part_metadata['edge_map']) + remote_path = '{}/{}/part{}'.format(args.workspace, args.rel_data_path, part_id) + exec_cmd(ip, 'mkdir -p {}'.format(remote_path)) + + part_files = part_metadata['part-{}'.format(part_id)] + copy_file(part_files['node_feats'], ip, remote_path) + copy_file(part_files['edge_feats'], ip, remote_path) + copy_file(part_files['part_graph'], ip, remote_path) + + +def signal_handler(signal, frame): + logging.info('Stop copying') + sys.exit(0) + +if __name__ == '__main__': + fmt = '%(asctime)s %(levelname)s %(message)s' + logging.basicConfig(format=fmt, level=logging.INFO) + signal.signal(signal.SIGINT, signal_handler) + main() diff --git a/tools/launch.py b/tools/launch.py index 319e9b0569d5..65e8382a0cb2 100644 --- a/tools/launch.py +++ b/tools/launch.py @@ -39,7 +39,7 @@ def submit_jobs(args, udf_command): # launch server tasks server_cmd = 'DGL_ROLE=server' server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client) - server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.conf_path) + server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config) server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) for i in range(len(hosts)*server_count_per_machine): ip, _ = hosts[int(i / server_count_per_machine)] @@ -50,7 +50,7 @@ def submit_jobs(args, udf_command): # launch client tasks client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client' client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client) - client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.conf_path) + client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config) client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) if os.environ.get('OMP_NUM_THREADS') is not None: client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS') @@ -85,11 +85,11 @@ def main(): help='Path of user directory of distributed tasks. \ This is used to specify a destination location where \ the contents of current directory will be rsyncd') - parser.add_argument('--num_client', type=int, + parser.add_argument('--num_client', type=int, help='Total number of client processes in the cluster') - parser.add_argument('--conf_path', type=str, - help='The file (in workspace) of the partition config file') - parser.add_argument('--ip_config', type=str, + parser.add_argument('--part_config', type=str, + help='The file (in workspace) of the partition config') + parser.add_argument('--ip_config', type=str, help='The file (in workspace) of IP configuration for server processes') args, udf_command = parser.parse_known_args() assert len(udf_command) == 1, 'Please provide user command line.'
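For reference, here is a minimal sketch of the rewritten partition config that `copy_partitions.py` dumps to `/tmp/<graph_name>.json` and uploads alongside the partition data. It assumes `--workspace ~/graphsage`, `--rel_data_path ogb-product`, and two partitions; the values are illustrative rather than output from an actual run, and keys the script does not rewrite are omitted.

```python
# Illustrative sketch (not actual output): the partition config after the path
# rewriting in copy_partitions.py, assuming --workspace ~/graphsage,
# --rel_data_path ogb-product, and two partitions. Keys the script leaves
# untouched are omitted here.
rewritten_config = {
    "graph_name": "ogb-product",
    "num_parts": 2,
    # node_map/edge_map are rewritten only when they are stored as .npy files
    # (not as in-JSON lists); they become paths under workspace/rel_data_path.
    "node_map": "~/graphsage/ogb-product/node_map.npy",
    "edge_map": "~/graphsage/ogb-product/edge_map.npy",
    # Per-partition files become paths relative to the workspace, which matches
    # the relative --part_config used in the README's launch command.
    "part-0": {
        "node_feats": "ogb-product/part0/node_feat.dgl",
        "edge_feats": "ogb-product/part0/edge_feat.dgl",
        "part_graph": "ogb-product/part0/graph.dgl",
    },
    "part-1": {
        "node_feats": "ogb-product/part1/node_feat.dgl",
        "edge_feats": "ogb-product/part1/edge_feat.dgl",
        "part_graph": "ogb-product/part1/graph.dgl",
    },
}
```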