[Examples] refine dist train example (dmlc#5763)
Rhett-Ying authored Jun 6, 2023
1 parent 041f78b commit 0202598
Showing 1 changed file with 145 additions and 91 deletions.
236 changes: 145 additions & 91 deletions examples/distributed/graphsage/node_classification.py
@@ -12,18 +12,26 @@
import tqdm


-def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
-"""
-Copys features and labels of a set of nodes onto GPU.
-"""
-batch_inputs = (
-g.ndata["features"][input_nodes].to(device) if load_feat else None
-)
-batch_labels = g.ndata["labels"][seeds].to(device)
-return batch_inputs, batch_labels
-
-
-class DistSAGE(nn.Module):
+class DistSAGE(nn.Module):
+"""
+SAGE model for distributed train and evaluation.
+Parameters
+----------
+in_feats : int
+Feature dimension.
+n_hidden : int
+Hidden layer dimension.
+n_classes : int
+Number of classes.
+n_layers : int
+Number of layers.
+activation : callable
+Activation function.
+dropout : float
+Dropout value.
+"""
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
@@ -40,6 +48,16 @@ def __init__(
self.activation = activation
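The rest of the constructor is not shown in this diff. Below is a minimal sketch of a typical body for this model, assuming the example's usual imports (`import dgl.nn.pytorch as dglnn`, `from torch import nn`); the aggregator choice and attribute names are illustrative, not the committed code.

# Sketch of DistSAGE.__init__ internals (illustrative, not part of this commit).
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
# Input layer, hidden layers, output layer of GraphSAGE convolutions.
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for _ in range(1, n_layers - 1):
    self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation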

def forward(self, blocks, x):
"""
Forward function.
Parameters
----------
blocks : List[DGLBlock]
Sampled blocks.
x : DistTensor
Feature data.
"""
h = x
for i, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
@@ -50,41 +68,46 @@ def forward(self, blocks, x):

def inference(self, g, x, batch_size, device):
"""
-Inference with the GraphSAGE model on full neighbors (i.e. without
-neighbor sampling).
-g : the entire graph.
-x : the input of entire node set.
+Distributed layer-wise inference with the GraphSAGE model on full
+neighbors.
+Parameters
+----------
+g : DistGraph
+Input Graph for inference.
+x : DistTensor
+Node feature data of input graph.
+Returns
+-------
+DistTensor
+Inference results.
"""
-# During inference with sampling, multi-layer blocks are very
-# inefficient because lots of computations in the first few layers
-# are repeated. Therefore, we compute the representation of all nodes
-# layer by layer. The nodes on each layer are of course splitted in
-# batches.
-# TODO: can we standardize this?
+# Split nodes to each trainer.
nodes = dgl.distributed.node_split(
np.arange(g.num_nodes()),
g.get_partition_book(),
force_even=True,
)
-y = dgl.distributed.DistTensor(
-(g.num_nodes(), self.n_hidden),
-th.float32,
-"h",
-persistent=True,
-)

for i, layer in enumerate(self.layers):
+# Create DistTensor to save forward results.
if i == len(self.layers) - 1:
-y = dgl.distributed.DistTensor(
-(g.num_nodes(), self.n_classes),
-th.float32,
-"h_last",
-persistent=True,
-)
-print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")
+out_dim = self.n_classes
+name = "h_last"
+else:
+out_dim = self.n_hidden
+name = "h"
+y = dgl.distributed.DistTensor(
+(g.num_nodes(), out_dim),
+th.float32,
+name,
+persistent=True,
+)
+print(f"|V|={g.num_nodes()}, inference batch size: {batch_size}")

+# `-1` indicates all inbound edges will be included, namely, full
+# neighbor sampling.
sampler = dgl.dataloading.NeighborSampler([-1])
dataloader = dgl.dataloading.DistNodeDataLoader(
g,
@@ -103,31 +126,64 @@ def inference(self, g, x, batch_size, device):
if i != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)

# Copy back to CPU as DistTensor requires data reside on CPU.
y[output_nodes] = h.cpu()

x = y
# Synchronize trainers.
g.barrier()
-return y
+return x


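The `NeighborSampler([-1])` used inside `inference` requests all in-neighbors, so every layer is computed exactly once for every node; training instead samples a small fanout per layer. A short sketch of the contrast, assuming `g` is the `DistGraph` and `train_nid` the local training IDs (fanout and batch-size values are illustrative):

# Full-neighbor sampling, one layer per block: used for layer-wise inference.
full_sampler = dgl.dataloading.NeighborSampler([-1])

# Fixed fanout per layer (two layers here): used for mini-batch training,
# mirroring the --fan_out argument parsed in run() below.
train_sampler = dgl.dataloading.NeighborSampler([10, 25])

# Both samplers plug into the same distributed dataloader.
dataloader = dgl.dataloading.DistNodeDataLoader(
    g,
    train_nid,
    train_sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
)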
def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
Parameters
----------
pred : torch.Tensor
Predicted labels.
labels : torch.Tensor
Ground-truth labels.
Returns
-------
float
Accuracy.
"""
labels = labels.long()
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
"""
-Evaluate the model on the validation set specified by ``val_nid``.
-g : The entire graph.
-inputs : The features of all the nodes.
-labels : The labels of all the nodes.
-val_nid : the node Ids for validation.
-batch_size : Number of nodes to compute at the same time.
-device : The GPU device to evaluate on.
+Evaluate the model on the validation and test set.
+Parameters
+----------
+model : DistSAGE
+The model to be evaluated.
+g : DistGraph
+The entire graph.
+inputs : DistTensor
+The feature data of all the nodes.
+labels : DistTensor
+The labels of all the nodes.
+val_nid : torch.Tensor
+The node IDs for validation.
+test_nid : torch.Tensor
+The node IDs for test.
+batch_size : int
+Batch size for evaluation.
+device : torch.Device
+The target device to evaluate on.
+Returns
+-------
+float
+Validation accuracy.
+float
+Test accuracy.
"""
model.eval()
with th.no_grad():
@@ -139,6 +195,19 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):


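The body of `evaluate` is collapsed above (only `model.eval()` and the `th.no_grad()` context are visible). A rough sketch of the usual pattern behind it, reusing the `inference` method and `compute_acc` defined above; an approximation, not the verbatim committed code:

# Inside evaluate(): run full-neighbor inference once, then score the
# validation and test slices of the predictions.
model.eval()
with th.no_grad():
    pred = model.inference(g, inputs, batch_size, device)
model.train()
val_acc = compute_acc(pred[val_nid], labels[val_nid])
test_acc = compute_acc(pred[test_nid], labels[test_nid])
# evaluate() returns (val_acc, test_acc).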
def run(args, device, data):
"""
Train and evaluate DistSAGE.
Parameters
----------
args : argparse.Args
Arguments for train and evaluate.
device : torch.Device
Target device for train and evaluate.
data : Packed Data
Packed data includes train/val/test IDs, feature dimension,
number of classes, graph.
"""
train_nid, val_nid, test_nid, in_feats, n_classes, g = data
sampler = dgl.dataloading.NeighborSampler(
[int(fanout) for fanout in args.fan_out.split(",")]
@@ -178,25 +247,23 @@ def run(args, device, data):
for _ in range(args.num_epochs):
epoch += 1
tic = time.time()
# Various time statistics.
sample_time = 0
forward_time = 0
backward_time = 0
update_time = 0
num_seeds = 0
num_inputs = 0
start = time.time()
# Loop over the dataloader to sample the computation dependency graph
# as a list of blocks.
step_time = []

with model.join():
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
sample_time += tic_step - start
-batch_inputs, batch_labels = load_subtensor(
-g, seeds, input_nodes, "cpu"
-)
-batch_labels = batch_labels.long()
+# Slice feature and label.
+batch_inputs = g.ndata["features"][input_nodes]
+batch_labels = g.ndata["labels"][seeds].long()
num_seeds += len(blocks[-1].dstdata[dgl.NID])
num_inputs += len(blocks[0].srcdata[dgl.NID])
# Move to target device.
@@ -227,36 +294,23 @@ def run(args, device, data):
if th.cuda.is_available()
else 0
)
sample_speed = np.mean(iter_tput[-args.log_every :])
mean_step_time = np.mean(step_time[-args.log_every :])
print(
-"Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
-"Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
-"{:.1f} MB | time {:.3f} s".format(
-g.rank(),
-epoch,
-step,
-loss.item(),
-acc.item(),
-np.mean(iter_tput[3:]),
-gpu_mem_alloc,
-np.mean(step_time[-args.log_every :]),
-)
+f"Part {g.rank()} | Epoch {epoch:05d} | Step {step:05d}"
+f" | Loss {loss.item():.4f} | Train Acc {acc.item():.4f}"
+f" | Speed (samples/sec) {sample_speed:.4f}"
+f" | GPU {gpu_mem_alloc:.1f} MB | "
+f"Mean step time {mean_step_time:.3f} s"
)
start = time.time()

toc = time.time()
print(
-"Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
-"forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
-"#inputs: {}".format(
-g.rank(),
-toc - tic,
-sample_time,
-forward_time,
-backward_time,
-update_time,
-num_seeds,
-num_inputs,
-)
+f"Part {g.rank()}, Epoch Time(s): {toc - tic:.4f}, "
+f"sample+data_copy: {sample_time:.4f}, forward: {forward_time:.4f},"
+f" backward: {backward_time:.4f}, update: {update_time:.4f}, "
+f"#seeds: {num_seeds}, #inputs: {num_inputs}"
)
epoch_time.append(toc - tic)

Expand All @@ -273,23 +327,27 @@ def run(args, device, data):
device,
)
print(
-"Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
-g.rank(), val_acc, test_acc, time.time() - start
-)
+f"Part {g.rank()}, Val Acc {val_acc:.4f}, "
+f"Test Acc {test_acc:.4f}, time: {time.time() - start:.4f}"
)

return np.mean(epoch_time[-int(args.num_epochs * 0.8) :]), test_acc
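Much of `run` is collapsed in the hunks above: the model construction, the DistributedDataParallel wrapper, the optimizer, and the per-step forward/backward. A compact sketch of that wiring, assuming argument names such as `args.num_hidden`, `args.num_layers`, `args.lr` and the example's usual imports (`torch.nn.functional as F`, `torch.optim as optim`); details may differ from the committed file:

# Build the model on the target device and wrap it for multi-trainer training.
model = DistSAGE(
    in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout
).to(device)
model = th.nn.parallel.DistributedDataParallel(model)
loss_fcn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# One training step on a sampled mini-batch; blocks, batch_inputs and
# batch_labels come from the dataloader loop shown above.
blocks = [block.to(device) for block in blocks]
batch_inputs = batch_inputs.to(device)
batch_labels = batch_labels.to(device)
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()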


def main(args):
-print(socket.gethostname(), "Initializing DistDGL.")
+"""
+Main function.
+"""
+host_name = socket.gethostname()
+print(f"{host_name}: Initializing DistDGL.")
dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
-print(socket.gethostname(), "Initializing PyTorch process group.")
+print(f"{host_name}: Initializing PyTorch process group.")
th.distributed.init_process_group(backend=args.backend)
-print(socket.gethostname(), "Initializing DistGraph.")
+print(f"{host_name}: Initializing DistGraph.")
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
-print(socket.gethostname(), "rank:", g.rank())
+print(f"Rank of {host_name}: {g.rank()}")

# Split train/val/test IDs for each trainer.
pb = g.get_partition_book()
if "trainer_id" in g.ndata:
train_nid = dgl.distributed.node_split(
@@ -321,17 +379,13 @@ def main(args):
g.ndata["test_mask"], pb, force_even=True
)
local_nid = pb.partid2nids(pb.partid).detach().numpy()
num_train_local = len(np.intersect1d(train_nid.numpy(), local_nid))
num_val_local = len(np.intersect1d(val_nid.numpy(), local_nid))
num_test_local = len(np.intersect1d(test_nid.numpy(), local_nid))
print(
-"part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
-"(local: {})".format(
-g.rank(),
-len(train_nid),
-len(np.intersect1d(train_nid.numpy(), local_nid)),
-len(val_nid),
-len(np.intersect1d(val_nid.numpy(), local_nid)),
-len(test_nid),
-len(np.intersect1d(test_nid.numpy(), local_nid)),
-)
+f"part {g.rank()}, train: {len(train_nid)} (local: {num_train_local}), "
+f"val: {len(val_nid)} (local: {num_val_local}), "
+f"test: {len(test_nid)} (local: {num_test_local})"
)
del local_nid
if args.num_gpus == 0:

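The tail of `main` and the argument parsing are collapsed above (the diff stops at `if args.num_gpus == 0:`). A minimal sketch of how the function typically finishes, packing the data tuple unpacked at the top of `run`; the device selection and the way `in_feats`/`n_classes` are derived are assumptions, not the verbatim committed code:

if args.num_gpus == 0:
    device = th.device("cpu")
else:
    # One trainer per GPU, chosen round-robin by the trainer's rank.
    device = th.device("cuda:" + str(g.rank() % args.num_gpus))
in_feats = g.ndata["features"].shape[1]
# Assumes dense integer class labels starting at 0.
n_classes = int(g.ndata["labels"][th.arange(g.num_nodes())].max().item()) + 1
data = train_nid, val_nid, test_nid, in_feats, n_classes, g
run(args, device, data)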