[Distributed] add distributed evaluation. (dmlc#1810)
* add eval.

* extend DistTensor.

* fix.

* add barrier.

* add more print.

* add more checks in kvstore.

* fix lint.

* get all neighbors for eval.

* reorganize.

* fix.

* fix.

* fix.

* fix test.

* add reuse_if_exist.

* add test for reuse_if_exist.

* fix lint.

* fix bugs.

* fix.

* print errors of tcp socket.

* support delete tensors.

* fix lint.

* fix

* fix example

Co-authored-by: Ubuntu <[email protected]>
zheng-da and Ubuntu authored Jul 22, 2020
1 parent 9d85397 commit 562871e
Showing 11 changed files with 326 additions and 202 deletions.
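The commit message above references several distributed primitives that the diffs below exercise (DistTensor, barriers, node_split). As context, here is a minimal sketch of that pattern, using only calls that appear in the diffs; the function name materialize_hidden, the dim argument, and the exact lifetime semantics of persistent tensors are illustrative assumptions, not repository code:

import numpy as np
import torch as th
import dgl

def materialize_hidden(ip_config, graph_name, conf_path, dim=16):
    # Connect to the distributed graph (arguments as in main() below).
    g = dgl.distributed.DistGraph(ip_config, graph_name, conf_file=conf_path)
    # A named distributed tensor spanning all nodes. persistent=True is assumed
    # to keep the tensor alive in the kvstore beyond this Python object's life;
    # non-persistent tensors can be deleted (per the commit message).
    h = dgl.distributed.DistTensor(g, (g.number_of_nodes(), dim), th.float32, 'h',
                                   persistent=True)
    # Evenly split the nodes across trainers, write only this trainer's slice,
    # then synchronize so every trainer sees the fully materialized tensor.
    nodes = dgl.distributed.node_split(np.arange(g.number_of_nodes()),
                                       g.get_partition_book(), force_even=True)
    h[nodes] = th.zeros(len(nodes), dim)
    g.barrier()
    return h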
100 changes: 83 additions & 17 deletions examples/pytorch/graphsage/experimental/train_dist.py
@@ -43,6 +43,63 @@ def sample_blocks(self, seeds):
blocks.insert(0, block)
return blocks

class DistSAGE(SAGE):
def __init__(self, in_feats, n_hidden, n_classes, n_layers,
activation, dropout):
super(DistSAGE, self).__init__(in_feats, n_hidden, n_classes, n_layers,
activation, dropout)

def inference(self, g, x, batch_size, device):
"""
        Inference with the GraphSAGE model on full neighbors (i.e., without neighbor sampling).
        g : the entire graph.
        x : the input features of the entire node set.
        The inference code is written so that it can handle any number of nodes and layers.
"""
# During inference with sampling, multi-layer blocks are very inefficient because
# lots of computations in the first few layers are repeated.
# Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
# TODO: can we standardize this?
nodes = dgl.distributed.node_split(np.arange(g.number_of_nodes()),
g.get_partition_book(), force_even=True)
y = dgl.distributed.DistTensor(g, (g.number_of_nodes(), self.n_hidden), th.float32, 'h',
persistent=True)
for l, layer in enumerate(self.layers):
if l == len(self.layers) - 1:
y = dgl.distributed.DistTensor(g, (g.number_of_nodes(), self.n_classes),
th.float32, 'h_last', persistent=True)

sampler = NeighborSampler(g, [-1], dgl.distributed.sample_neighbors)
print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataset=nodes,
batch_size=batch_size,
collate_fn=sampler.sample_blocks,
shuffle=False,
drop_last=False,
num_workers=args.num_workers)

for blocks in tqdm.tqdm(dataloader):
block = blocks[0]
input_nodes = block.srcdata[dgl.NID]
output_nodes = block.dstdata[dgl.NID]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)

y[output_nodes] = h.cpu()

x = y
g.barrier()
return y

def run(args, device, data):
# Unpack data
train_nid, val_nid, in_feats, n_classes, g = data
@@ -60,7 +117,7 @@ def run(args, device, data):
num_workers=args.num_workers)

# Define model and optimizer
-    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
+    model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
model = model.to(device)
if not args.standalone:
model = th.nn.parallel.DistributedDataParallel(model)
@@ -101,6 +158,8 @@ def run(args, device, data):
# Load the input features as well as output labels
start = time.time()
batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)
+            assert th.all(th.logical_not(th.isnan(batch_labels)))
+            batch_labels = batch_labels.long()
copy_time += time.time() - start

num_seeds += len(blocks[-1].dstdata[dgl.NID])
@@ -122,7 +181,7 @@ def run(args, device, data):
if param.requires_grad and param.grad is not None:
th.distributed.all_reduce(param.grad.data,
op=th.distributed.ReduceOp.SUM)
-                    param.grad.data /= args.num_client
+                    param.grad.data /= dgl.distributed.get_num_client()

optimizer.step()
update_time += time.time() - compute_end
@@ -133,21 +192,21 @@ def run(args, device, data):
if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
-                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB | time {:.3f} s'.format(
-                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
+                print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB | time {:.3f} s'.format(
+                    g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
start = time.time()

toc = time.time()
-        print('Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
-            toc - tic, sample_time, copy_time, forward_time, backward_time, update_time, num_seeds, num_inputs))
+        print('Part {}, Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
+            g.rank(), toc - tic, sample_time, copy_time, forward_time, backward_time, update_time, num_seeds, num_inputs))
epoch += 1


toc = time.time()
print('Epoch Time(s): {:.4f}'.format(toc - tic))
-        #if epoch % args.eval_every == 0 and epoch != 0:
-        #    eval_acc = evaluate(model, g, g.ndata['features'], g.ndata['labels'], val_nid, args.batch_size, device)
-        #    print('Eval Acc {:.4f}'.format(eval_acc))
+        if epoch % args.eval_every == 0 and epoch != 0:
+            start = time.time()
+            eval_acc = evaluate(model.module, g, g.ndata['features'],
+                                g.ndata['labels'], val_nid, args.batch_size_eval, device)
+            print('Part {}, Eval Acc {:.4f}, time: {:.4f}'.format(g.rank(), eval_acc, time.time() - start))

profiler.stop()
print(profiler.output_text(unicode=True, color=True))
@@ -161,13 +220,19 @@ def main(args):
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, conf_file=args.conf_path)
print('rank:', g.rank())

-    train_nid = dgl.distributed.node_split(g.ndata['train_mask'], g.get_partition_book(), force_even=True)
-    val_nid = dgl.distributed.node_split(g.ndata['val_mask'], g.get_partition_book(), force_even=True)
-    test_nid = dgl.distributed.node_split(g.ndata['test_mask'], g.get_partition_book(), force_even=True)
-    print('part {}, train: {}, val: {}, test: {}'.format(g.rank(), len(train_nid),
-                                                          len(val_nid), len(test_nid)))
+    pb = g.get_partition_book()
+    train_nid = dgl.distributed.node_split(g.ndata['train_mask'], pb, force_even=True)
+    val_nid = dgl.distributed.node_split(g.ndata['val_mask'], pb, force_even=True)
+    test_nid = dgl.distributed.node_split(g.ndata['test_mask'], pb, force_even=True)
+    local_nid = pb.partid2nids(pb.partid).detach().numpy()
+    print('part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})'.format(
+        g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
+        len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)),
+        len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid))))
device = th.device('cpu')
-    n_classes = len(th.unique(g.ndata['labels'][np.arange(g.number_of_nodes())]))
+    labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
+    n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
print('#labels:', n_classes)

# Pack data
in_feats = g.ndata['features'].shape[1]
@@ -191,6 +256,7 @@ def main(args):
parser.add_argument('--num-layers', type=int, default=2)
parser.add_argument('--fan-out', type=str, default='10,25')
parser.add_argument('--batch-size', type=int, default=1000)
+    parser.add_argument('--batch-size-eval', type=int, default=100000)
parser.add_argument('--log-every', type=int, default=20)
parser.add_argument('--eval-every', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.003)
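The diff above calls two helpers whose bodies fall outside the shown hunks: load_subtensor(g, seeds, input_nodes, device) and evaluate(model.module, g, features, labels, val_nid, batch_size_eval, device). Plausible sketches consistent with those call sites and with the single-machine example below (assuming the file's existing imports, torch as th, and the compute_acc shown further down) — not the exact repository code — would be:

def load_subtensor(g, seeds, input_nodes, device):
    # Sketch: pull the input features of the sampled neighborhood and the
    # labels of the seed nodes out of the (distributed) graph.
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    batch_labels = g.ndata['labels'][seeds].to(device)
    return batch_inputs, batch_labels

def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
    # Sketch: run the layer-wise, full-neighbor inference defined above once,
    # then score the predictions on the validation nodes only.
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid])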
2 changes: 1 addition & 1 deletion examples/pytorch/graphsage/train_sampling.py
@@ -68,7 +68,6 @@ def inference(self, g, x, batch_size, device):
# Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
# TODO: can we standardize this?
-        nodes = th.arange(g.number_of_nodes())
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

@@ -113,6 +112,7 @@ def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
+    labels = labels.long()
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
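The labels.long() cast added above matters because the distributed graph stores labels as floats (so unlabeled nodes can hold NaN, which the training loop now asserts against). A small, self-contained illustration of why the cast is needed before computing the loss or accuracy (values here are made up):

import torch as th
import torch.nn.functional as F

logits = th.randn(4, 3)
labels = th.tensor([0., 2., 1., 2.])  # float labels, as read from g.ndata['labels']

# F.cross_entropy(logits, labels)     # raises a dtype error: class targets must be Long
loss = F.cross_entropy(logits, labels.long())
acc = (th.argmax(logits, dim=1) == labels.long()).float().mean()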
