[AIR] Add distributed torch_geometric example (ray-project#23580)
Add an example for distributed PyTorch Geometric (graph learning) training with Ray AIR.

This example only showcases distributed training; the data is small enough that each training worker can load it individually. Distributed data ingest is out of scope for this PR.

Co-authored-by: matthewdeng <[email protected]>
amogkam and matthewdeng authored Apr 21, 2022
1 parent e4a66c0 commit 732175e
Showing 6 changed files with 253 additions and 4 deletions.
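
At a high level, the new example follows the Ray AIR pattern sketched below: the dataset is small, so each training worker loads its own copy via a factory function passed through `train_loop_config`, and `TorchTrainer` handles the distributed setup. This is a minimal, illustrative sketch based on the example file added in this commit; `load_small_dataset` is a placeholder helper, not an API from this PR.

from ray import train
from ray.ml.train.integrations.torch import TorchTrainer


def load_small_dataset():
    # Placeholder: any dataset small enough to fit on every worker.
    ...


def train_loop_per_worker(config):
    # Each worker loads its own copy of the data; there is no distributed ingest.
    dataset = config["dataset_fn"]()
    for epoch in range(config["num_epochs"]):
        ...  # train on the local copy
        train.report(epoch=epoch)


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"dataset_fn": load_small_dataset, "num_epochs": 1},
    scaling_config={"num_workers": 2, "use_gpu": False},
)
result = trainer.fit()
print(result.metrics)
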
10 changes: 10 additions & 0 deletions python/ray/ml/BUILD
@@ -11,6 +11,16 @@ py_test(
deps = [":ml_lib"]
)

py_test(
name = "distributed_sage_example",
size = "large",
main = "examples/pytorch_geometric/distributed_sage_example.py",
srcs = ["examples/pytorch_geometric/distributed_sage_example.py"],
tags = ["team:ml", "exclusive", "gpu"],
deps = [":ml_lib"],
args = ["--use-gpu", "--num-workers=2", "--epochs=1", "--dataset=fake"]
)

py_test (
name = "lightgbm_example",
size = "medium",
Empty file.
231 changes: 231 additions & 0 deletions python/ray/ml/examples/pytorch_geometric/distributed_sage_example.py
@@ -0,0 +1,231 @@
# Adapted from https://github.com/pyg-team/pytorch_geometric/blob/master/examples
# /multi_gpu/distributed_sampling.py.

import os
import argparse
from filelock import FileLock

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Reddit, FakeDataset
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

from ray import train
from ray.ml.train.integrations.torch import TorchTrainer
from torch_geometric.transforms import RandomNodeSplit


class SAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
super().__init__()
self.num_layers = num_layers

self.convs = torch.nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(self.num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
self.convs.append(SAGEConv(hidden_channels, out_channels))

def forward(self, x, adjs):
for i, (edge_index, _, size) in enumerate(adjs):
x_target = x[: size[1]] # Target nodes are always placed first.
x = self.convs[i]((x, x_target), edge_index)
if i != self.num_layers - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return x.log_softmax(dim=-1)

@torch.no_grad()
def inference(self, x_all, subgraph_loader):
for i in range(self.num_layers):
xs = []
for batch_size, n_id, adj in subgraph_loader:

edge_index, _, size = adj
x = x_all[n_id].to(train.torch.get_device())
x_target = x[: size[1]]
x = self.convs[i]((x, x_target), edge_index)
if i != self.num_layers - 1:
x = F.relu(x)
xs.append(x.cpu())

x_all = torch.cat(xs, dim=0)

return x_all


def train_loop_per_worker(train_loop_config):
dataset = train_loop_config["dataset_fn"]()
batch_size = train_loop_config["batch_size"]
num_epochs = train_loop_config["num_epochs"]

data = dataset[0]
train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
train_idx = train_idx.split(train_idx.size(0) // train.world_size())[
train.world_rank()
]

train_loader = NeighborSampler(
data.edge_index,
node_idx=train_idx,
sizes=[25, 10],
batch_size=batch_size,
shuffle=True,
)

# Disable distributed sampler since the train_loader has already been split above.
train_loader = train.torch.prepare_data_loader(train_loader, add_dist_sampler=False)

# Do validation on rank 0 worker only.
if train.world_rank() == 0:
subgraph_loader = NeighborSampler(
data.edge_index, node_idx=None, sizes=[-1], batch_size=2048, shuffle=False
)
subgraph_loader = train.torch.prepare_data_loader(
subgraph_loader, add_dist_sampler=False
)

model = SAGE(dataset.num_features, 256, dataset.num_classes)
model = train.torch.prepare_model(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

x, y = data.x.to(train.torch.get_device()), data.y.to(train.torch.get_device())

for epoch in range(num_epochs):
model.train()

# ``batch_size`` is the number of samples in the current batch.
# ``n_id`` are the ids of all the nodes used in the computation. This is
# needed to pull in the necessary features just for the current batch that is
# being trained on.
        # ``adjs`` is a list of 3-element tuples of the form ``(edge_index, e_id,
        # size)``, one per layer of the sampled subgraph, where ``edge_index``
        # represents the edges of the sampled subgraph, ``e_id`` are the ids of the
        # edges in the sample, and ``size`` holds the shape of the subgraph.
        # See ``torch_geometric.loader.neighbor_sampler.NeighborSampler`` for more info.
for batch_size, n_id, adjs in train_loader:
optimizer.zero_grad()
out = model(x[n_id], adjs)
loss = F.nll_loss(out, y[n_id[:batch_size]])
loss.backward()
optimizer.step()

if train.world_rank() == 0:
print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

train_accuracy = validation_accuracy = test_accuracy = None

# Do validation on rank 0 worker only.
if train.world_rank() == 0:
model.eval()
with torch.no_grad():
out = model.module.inference(x, subgraph_loader)
res = out.argmax(dim=-1) == data.y
train_accuracy = int(res[data.train_mask].sum()) / int(
data.train_mask.sum()
)
validation_accuracy = int(res[data.val_mask].sum()) / int(
data.val_mask.sum()
)
test_accuracy = int(res[data.test_mask].sum()) / int(data.test_mask.sum())

train.report(
train_accuracy=train_accuracy,
validation_accuracy=validation_accuracy,
test_accuracy=test_accuracy,
)


def gen_fake_dataset():
"""Returns a function to be called on each worker that returns a Fake Dataset."""

# For fake dataset, since the dataset is randomized, we create it once on the
# driver, and then send the same dataset to all the training workers.
# Use 10% of nodes for validation and 10% for testing.
fake_dataset = FakeDataset(transform=RandomNodeSplit(num_val=0.1, num_test=0.1))

def gen_dataset():
return fake_dataset

return gen_dataset


def gen_reddit_dataset():
"""Returns a function to be called on each worker that returns Reddit Dataset."""

# For Reddit dataset, we have to download the data on each node, so we create the
# dataset on each training worker.
def gen_dataset():
with FileLock(os.path.expanduser("~/.reddit_dataset_lock")):
dataset = Reddit("./data/Reddit")
return dataset

return gen_dataset


def train_gnn(
num_workers=2, use_gpu=False, epochs=3, global_batch_size=32, dataset="reddit"
):

per_worker_batch_size = global_batch_size // num_workers

trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config={
"num_epochs": epochs,
"batch_size": per_worker_batch_size,
"dataset_fn": gen_reddit_dataset()
if dataset == "reddit"
else gen_fake_dataset(),
},
scaling_config={"num_workers": num_workers, "use_gpu": use_gpu},
)
result = trainer.fit()
print(result.metrics)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", required=False, type=str, help="the address to use for Ray"
)
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.",
)
parser.add_argument(
"--use-gpu", action="store_true", help="Whether to use GPU for training."
)
parser.add_argument(
"--epochs", type=int, default=3, help="Number of epochs to train for."
)
parser.add_argument(
"--global-batch-size",
"-b",
type=int,
default=32,
help="Global batch size to use for training.",
)
parser.add_argument(
"--dataset",
"-d",
type=str,
choices=["reddit", "fake"],
default="reddit",
help="The dataset to use. Either 'reddit' or 'fake' Defaults to 'reddit'.",
)

args, _ = parser.parse_known_args()

train_gnn(
num_workers=args.num_workers,
use_gpu=args.use_gpu,
epochs=args.epochs,
global_batch_size=args.global_batch_size,
dataset=args.dataset,
)
3 changes: 2 additions & 1 deletion python/ray/train/torch.py
@@ -504,7 +504,8 @@ def _wait_for_batch(self, item):
# the tensor might be freed once it is no longer used by
# the creator stream.
for i in item:
-            i.record_stream(curr_stream)
+            if isinstance(i, torch.Tensor):
+                i.record_stream(curr_stream)

def __len__(self):
return len(self._dataloader)
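
For context, this guard is presumably needed because batches from torch_geometric's `NeighborSampler` are not plain tensors: each item is a `(batch_size, n_id, adjs)` tuple in which `batch_size` is an int and `adjs` is a list, and `record_stream()` only exists on `torch.Tensor`. A small illustrative sketch (not part of the PR) of the kind of mixed batch the guarded loop now handles:

import torch

batch_size = 1024                # plain int
n_id = torch.arange(5_000)       # ids of all nodes used in the computation
adjs = [                         # one (edge_index, e_id, size) tuple per layer
    (torch.zeros(2, 10, dtype=torch.long), torch.arange(10), (5_000, 1_024)),
]
item = (batch_size, n_id, adjs)

# Mirrors the guarded loop above: only tensor items would have record_stream() called.
tensors = [i for i in item if isinstance(i, torch.Tensor)]
assert len(tensors) == 1 and tensors[0] is n_id
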
2 changes: 1 addition & 1 deletion python/requirements/ml/requirements_dl.txt
@@ -4,7 +4,7 @@
tensorflow==2.6.2
tensorflow-probability==0.14.1

-# If you make changes below this line, please also change `requirements_ml_docker.txt`!
+# If you make changes to the torch versions below, please also make the corresponding changes to `requirements_ml_docker.txt`!

torch==1.9.0;sys_platform=="darwin"
torchvision==0.10.0;sys_platform=="darwin"
11 changes: 9 additions & 2 deletions python/requirements_ml_docker.txt
@@ -3,8 +3,15 @@ ipython
# Needed for Ray Client error message serialization/deserialization.
tblib

-# If you make changes to anything below this line, please also change `requirements_dl.txt`!
+# If you make changes to the torch versions, please also make the corresponding changes to `requirements_dl.txt`!
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cu111
torchvision==0.10.0+cu111

-f https://data.pyg.org/whl/torch-1.9.0+cu111.html
torch-scatter==2.0.9
torch-sparse==0.6.12

# torch-geometric has to be installed after torch-scatter and torch-sparse.
torch-geometric==2.0.3; python_version < '3.7'
torch-geometric==2.0.4; python_version >= '3.7'
