[Distributed] add the standalone mode in DistGraph (dmlc#1800)
* add standalone mode

* add comments.

* add tests for sampling.

* fix.

* make the code run in the standalone mode

* fix

* fix

* fix readme.

* fix.

* fix test

Co-authored-by: Chao Ma <[email protected]>
zheng-da and aksnzhy authored Jul 15, 2020
1 parent a5722d0 commit cda0abf
Showing 6 changed files with 193 additions and 29 deletions.
24 changes: 23 additions & 1 deletion examples/pytorch/graphsage/experimental/README.md
@@ -26,7 +26,8 @@ will be able to access the partitioned data.

### Step 3: run servers

We need to run a server on each machine. Before running the servers, we need to update `ip_config.txt` with the right IP addresses.
To perform actual distributed training (running training jobs on multiple machines), we need to run
a server on each machine. Before running the servers, we need to update `ip_config.txt` with the right IP addresses.

On each of the machines, set the following environment variables.

@@ -60,6 +61,8 @@ We run a trainer process on each machine. Here we use Pytorch distributed.
PyTorch distributed requires one of the trainer processes to be the master. Here we use the first machine to run the master process.

```bash
# set DistGraph to the distributed mode
export DGL_DIST_MODE="distributed"
# run client on machine 0
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 1
@@ -69,3 +72,22 @@ python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 3
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
```

## Running the distributed code in the standalone mode

The standalone mode is mainly used for development and testing. The procedure to run the code is much simpler than the full distributed setup above.

### Step 1: construct the graph

When testing the standalone mode of the training script, we should construct a graph with one partition.
```bash
python3 partition_graph.py --dataset ogb-product --num_parts 1
```

### Step 2: run the training script

```bash
python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --conf_path data/ogb-product.json --standalone
```

Note: if you previously set the environment variables shown above for distributed training (e.g., `DGL_DIST_MODE`), make sure they are unset before running in the standalone mode.
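
For reference, the same standalone flow can be driven directly from a Python shell. The snippet below is a minimal sketch, not part of the example scripts; it assumes the partition JSON from Step 1 is at `data/ogb-product.json` (as in the command above) and that the installed DGL includes this commit.

```python
import os
import dgl

# Leaving DGL_DIST_MODE unset has the same effect: the mode added in this
# commit defaults to 'standalone'.
os.environ['DGL_DIST_MODE'] = 'standalone'

# In the standalone mode, DistGraph loads the single partition described by
# conf_file instead of connecting to DistGraph servers.
g = dgl.distributed.DistGraph('ip_config.txt', 'ogb-product',
                              conf_file='data/ogb-product.json')
print('nodes:', g.number_of_nodes(), 'edges:', g.number_of_edges())
```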
27 changes: 16 additions & 11 deletions examples/pytorch/graphsage/experimental/train_dist.py
Expand Up @@ -62,7 +62,8 @@ def run(args, device, data):
# Define model and optimizer
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
model = model.to(device)
model = th.nn.parallel.DistributedDataParallel(model)
if not args.standalone:
model = th.nn.parallel.DistributedDataParallel(model)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
@@ -116,11 +117,12 @@ def run(args, device, data):
backward_time += compute_end - forward_end

# Aggregate gradients in multiple nodes.
for param in model.parameters():
if param.requires_grad and param.grad is not None:
th.distributed.all_reduce(param.grad.data,
op=th.distributed.ReduceOp.SUM)
param.grad.data /= args.num_client
if not args.standalone:
for param in model.parameters():
if param.requires_grad and param.grad is not None:
th.distributed.all_reduce(param.grad.data,
op=th.distributed.ReduceOp.SUM)
param.grad.data /= args.num_client

optimizer.step()
update_time += time.time() - compute_end
@@ -150,13 +152,15 @@ def run(args, device, data):
profiler.stop()
print(profiler.output_text(unicode=True, color=True))
# clean up
g._client.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
if not args.standalone:
g._client.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()

def main(args):
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, conf_file=args.conf_path)
print('rank:', g.rank())

train_nid = dgl.distributed.node_split(g.ndata['train_mask'], g.get_partition_book(), force_even=True)
@@ -196,6 +200,7 @@ def main(args):
parser.add_argument('--num-workers', type=int, default=0,
help="Number of sampling processes. Use 0 for no extra process.")
parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
args = parser.parse_args()

print(args)
55 changes: 43 additions & 12 deletions python/dgl/distributed/dist_graph.py
@@ -1,12 +1,14 @@
"""Define distributed graph."""

from collections.abc import MutableMapping
import os
import numpy as np

from ..graph import DGLGraph
from .. import backend as F
from ..base import NID, EID
from .kvstore import KVServer, KVClient
from .standalone_kvstore import KVClient as SA_KVClient
from ..graph_index import from_shared_mem_graph_index
from .._ffi.ndarray import empty_shared_mem
from ..frame import infer_scheme
@@ -338,7 +340,16 @@ class DistGraph:
This provides the graph interface to access the partitioned graph data for distributed GNN
training. All data of partitions are loaded by the DistGraph server.
By default, `DistGraph` uses shared-memory to access the partition data in the local machine.
DistGraph can run in two modes: the standalone mode and the distributed mode.
* When a user runs the training script normally, DistGraph will be in the standalone mode.
In this mode, the input graph has to be constructed with only one partition. This mode is
used for testing and debugging purposes.
* When a user runs the training script with the distributed launch script, DistGraph will
be set into the distributed mode. This is used for actual distributed training.
When running in the distributed mode, `DistGraph` uses shared-memory to access
the partition data in the local machine.
This gives the best performance for distributed training when we run `DistGraphServer`
and `DistGraph` on the same machine. However, a user may want to run them in separate
machines. In this case, a user may want to disable shared memory by passing
@@ -353,20 +364,40 @@ class DistGraph:
The name of the graph. This name has to be the same as the one used in DistGraphServer.
gpb : PartitionBook
The partition book object
conf_file : str
The partition config file. It's used in the standalone mode.
'''
def __init__(self, ip_config, graph_name, gpb=None):
connect_to_server(ip_config=ip_config)
self._client = KVClient(ip_config)
g = _get_graph_from_shared_mem(graph_name)
if g is not None:
def __init__(self, ip_config, graph_name, gpb=None, conf_file=None):
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
assert conf_file is not None, \
'When running in the standalone mode, the partition config file is required'
self._client = SA_KVClient()
# Load graph partition data.
g, node_feats, edge_feats, self._gpb, _ = load_partition(conf_file, 0)
assert self._gpb.num_partitions() == 1, \
'The standalone mode only works with graph data that has one partition'
if self._gpb is None:
self._gpb = gpb
self._g = as_heterograph(g)
for name in node_feats:
self._client.add_data(_get_ndata_name(name), node_feats[name])
for name in edge_feats:
self._client.add_data(_get_edata_name(name), edge_feats[name])
rpc.set_num_client(1)
else:
self._g = None
self._gpb = get_shared_mem_partition_book(graph_name, self._g)
if self._gpb is None:
self._gpb = gpb
self._client.barrier()
self._client.map_shared_data(self._gpb)
connect_to_server(ip_config=ip_config)
self._client = KVClient(ip_config)
g = _get_graph_from_shared_mem(graph_name)
if g is not None:
self._g = as_heterograph(g)
else:
self._g = None
self._gpb = get_shared_mem_partition_book(graph_name, self._g)
if self._gpb is None:
self._gpb = gpb
self._client.barrier()
self._client.map_shared_data(self._gpb)

self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)

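The new constructor logic above can be summarized with a short sketch. `build_dist_graph` below is a hypothetical helper written only for illustration; it is not part of DGL and simply mirrors the branch that this commit adds to `DistGraph.__init__`.

```python
import os
import dgl

def build_dist_graph(ip_config, graph_name, conf_file=None):
    """Hypothetical helper mirroring the mode selection in DistGraph.__init__."""
    if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
        # Standalone: no servers and no RPC. The partition config file is
        # required and the graph must consist of exactly one partition.
        assert conf_file is not None, 'the standalone mode needs conf_file'
        return dgl.distributed.DistGraph(ip_config, graph_name,
                                         conf_file=conf_file)
    # Distributed: DistGraph servers listed in ip_config must already be
    # running; partition data is accessed through shared memory or RPC.
    return dgl.distributed.DistGraph(ip_config, graph_name)
```

Because the environment variable defaults to `'standalone'`, a training script launched without the distributed launcher (and without exporting `DGL_DIST_MODE`) falls back to the single-partition path.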
58 changes: 58 additions & 0 deletions python/dgl/distributed/standalone_kvstore.py
@@ -0,0 +1,58 @@
"""Define a fake kvstore
This kvstore is used when running in the standalone mode
"""

from .. import backend as F

class KVClient(object):
''' The fake KVStore client.
This is to mimic the distributed KVStore client. It's used for DistGraph
in standalone mode.
'''
def __init__(self):
self._data = {}
self._push_handlers = {}
self._pull_handlers = {}

def barrier(self):
'''barrier'''

def register_push_handler(self, name, func):
'''register push handler'''
self._push_handlers[name] = func

def register_pull_handler(self, name, func):
'''register pull handler'''
self._pull_handlers[name] = func

def add_data(self, name, tensor):
'''add data to the client'''
self._data[name] = tensor

def init_data(self, name, shape, dtype, _, init_func):
'''add new data to the client'''
self._data[name] = init_func(shape, dtype)

def data_name_list(self):
'''get the names of all data'''
return list(self._data.keys())

def get_data_meta(self, name):
'''get the metadata of data'''
return F.dtype(self._data[name]), F.shape(self._data[name]), None

def push(self, name, id_tensor, data_tensor):
'''push data to kvstore'''
if name in self._push_handlers:
self._push_handlers[name](self._data, name, id_tensor, data_tensor)
else:
F.scatter_row_inplace(self._data[name], id_tensor, data_tensor)

def pull(self, name, id_tensor):
'''pull data from kvstore'''
if name in self._pull_handlers:
return self._pull_handlers[name](self._data, name, id_tensor)
else:
return F.gather_row(self._data[name], id_tensor)
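
A small usage sketch of the fake client above, assuming a PyTorch-backed DGL build that contains this module; it only illustrates the in-process push/pull semantics and is not part of the test suite.

```python
import torch as th
from dgl.distributed.standalone_kvstore import KVClient

kv = KVClient()
kv.add_data('feat', th.zeros(5, 2))

# push() scatters rows into the locally held tensor ...
kv.push('feat', th.tensor([0, 3]), th.ones(2, 2))
# ... and pull() gathers them back, all without any RPC.
print(kv.pull('feat', th.tensor([0, 1, 3])))  # rows 0 and 3 are ones, row 1 stays zero
print(kv.get_data_meta('feat'))               # (dtype, shape, None)
```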
23 changes: 21 additions & 2 deletions tests/distributed/test_dist_graph_store.py
@@ -69,7 +69,9 @@ def run_client(graph_name, part_id, num_nodes, num_edges):
gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
check_dist_graph(g, num_nodes, num_edges)

def check_dist_graph(g, num_nodes, num_edges):
# Test API
assert g.number_of_nodes() == num_nodes
assert g.number_of_edges() == num_edges
@@ -163,8 +165,9 @@ def run_client(graph_name, part_id, num_nodes, num_edges):
assert n in local_nids

# clean up
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
if os.environ['DGL_DIST_MODE'] == 'distributed':
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
print('end')

def check_server_client(shared_mem):
@@ -205,9 +208,24 @@ def check_server_client(shared_mem):

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_server_client():
    os.environ['DGL_DIST_MODE'] = 'distributed'
    check_server_client(True)
    check_server_client(False)

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_standalone():
    os.environ['DGL_DIST_MODE'] = 'standalone'
    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
    dist_g = DistGraph("kv_ip_config.txt", graph_name,
                       conf_file='/tmp/dist_graph/{}.json'.format(graph_name))
    check_dist_graph(dist_g, g.number_of_nodes(), g.number_of_edges())

def test_split():
prepare_dist()
g = create_random_graph(10000)
@@ -323,3 +341,4 @@ def prepare_dist():
test_split()
test_split_even()
test_server_client()
test_standalone()
35 changes: 32 additions & 3 deletions tests/distributed/test_distributed_sampling.py
Expand Up @@ -16,14 +16,12 @@


def start_server(rank, tmpdir, disable_shared_mem, graph_name):
import dgl
g = DistGraphServer(rank, "rpc_ip_config.txt", 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
g.start()


def start_sample_client(rank, tmpdir, disable_shared_mem):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
@@ -74,6 +72,7 @@ def check_rpc_sampling(tmpdir, num_server):
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling():
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling(Path(tmpdirname), 2)

@@ -126,12 +125,38 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling_shuffle():
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_sampling_shuffle(Path(tmpdirname), 1)

def check_standalone_sampling(tmpdir):
    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = 1
    num_hops = 1
    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    dist_graph = DistGraph(None, "test_sampling", conf_file=tmpdir / 'test_sampling.json')
    sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_standalone_sampling():
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'standalone'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))

def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
import dgl
gpb = None
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
@@ -184,12 +209,16 @@ def check_rpc_in_subgraph(tmpdir, num_server):
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_in_subgraph(Path(tmpdirname), 2)

if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ['DGL_DIST_MODE'] = 'standalone'
        check_standalone_sampling(Path(tmpdirname))
        os.environ['DGL_DIST_MODE'] = 'distributed'
        check_rpc_in_subgraph(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
