Skip to content

Commit

Permalink
[GraphBolt][CUDA] Refine the multi-GPU example (dmlc#6980)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mfbalin authored Jan 26, 2024
1 parent b3841c2 commit 2da6ace
Showing 1 changed file with 17 additions and 29 deletions.
46 changes: 17 additions & 29 deletions examples/multigpu/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def create_dataloader(
features,
itemset,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
is_train,
):
############################################################################
# [HIGHLIGHT]
Expand Down Expand Up @@ -122,9 +120,9 @@ def create_dataloader(
datapipe = gb.DistributedItemSampler(
item_set=itemset,
batch_size=args.batch_size,
drop_last=drop_last,
shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs,
drop_last=is_train,
shuffle=is_train,
drop_uneven_inputs=is_train,
)
############################################################################
# [Note]:
Expand Down Expand Up @@ -190,7 +188,7 @@ def train(
epoch_start = time.time()

model.train()
total_loss = torch.tensor(0, dtype=torch.float).to(device)
total_loss = torch.tensor(0, dtype=torch.float, device=device)
########################################################################
# (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
#
Expand Down Expand Up @@ -230,20 +228,17 @@ def train(
loss.backward()
optimizer.step()

total_loss += loss
total_loss += loss.detach()

# Evaluate the model.
if rank == 0:
print("Validating...")
acc = (
evaluate(
rank,
model,
valid_dataloader,
num_classes,
device,
)
/ world_size
acc = evaluate(
rank,
model,
valid_dataloader,
num_classes,
device,
)
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
Expand All @@ -255,14 +250,13 @@ def train(
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)
dist.barrier()

epoch_end = time.time()
if rank == 0:
print(
f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | "
f"Accuracy {acc.item():.4f} | "
f"Accuracy {acc.item() / world_size:.4f} | "
f"Time {epoch_end - epoch_start:.4f}"
)

Expand Down Expand Up @@ -304,29 +298,23 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature,
train_set,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
is_train=True,
)
valid_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
valid_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
is_train=False,
)
test_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
test_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
is_train=False,
)

# Model training.
Expand Down Expand Up @@ -357,7 +345,7 @@ def run(rank, world_size, args, devices, dataset):
/ world_size
)
dist.reduce(tensor=test_acc, dst=0)
dist.barrier()
torch.cuda.synchronize()
if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}")

Expand Down

0 comments on commit 2da6ace

Please sign in to comment.