[GraphBolt] Modify examples to use seeds. (dmlc#7231)
Co-authored-by: Ubuntu <[email protected]>
yxy235 and Ubuntu authored Apr 1, 2024
1 parent b725ee5 commit efe2849
Showing 10 changed files with 73 additions and 56 deletions.
@@ -205,7 +205,7 @@ def val_dataloader(self):
)
args = parser.parse_args()

dataset = gb.BuiltinDataset("ogbn-products").load()
dataset = gb.BuiltinDataset("ogbn-products-seeds").load()
datamodule = DataModule(
dataset,
[10, 10, 10],
52 changes: 24 additions & 28 deletions examples/sampling/graphbolt/link_prediction.py
@@ -51,7 +51,7 @@
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ogb.linkproppred import Evaluator
from torchmetrics.retrieval import RetrievalMRR


class SAGE(nn.Module):
@@ -243,43 +243,38 @@ def create_dataloader(args, graph, features, itemset, is_train=True):


@torch.no_grad()
def compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst):
def compute_mrr(args, model, node_emb, seeds, labels, indexes):
"""Compute the Mean Reciprocal Rank (MRR) for given source and destination
nodes.
This function computes the MRR for a set of node pairs, dividing the task
into batches to handle potentially large graphs.
"""
rr = torch.zeros(src.shape[0])
# Loop over node pairs in batches.
for start in tqdm.trange(
0, src.shape[0], args.eval_batch_size, desc="Evaluate"
):
end = min(start + args.eval_batch_size, src.shape[0])

# Concatenate positive and negative destination nodes.
all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
preds = torch.empty(seeds.shape[0])
mrr = RetrievalMRR()
seeds_src, seeds_dst = seeds.T
# The constant is 1001 because the negative sampling ratio of the
# `ogbl-citation2` dataset is 1000 (1000 negatives per positive).
eval_size = args.eval_batch_size * 1001
# Loop over node pairs in batches.
for start in tqdm.trange(0, seeds_src.shape[0], eval_size, desc="Evaluate"):
end = min(start + eval_size, seeds_src.shape[0])

# Fetch embeddings for current batch of source and destination nodes.
h_src = node_emb[src[start:end]][:, None, :].to(args.device)
h_dst = (
node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(args.device)
)
h_src = node_emb[seeds_src[start:end]].to(args.device)
h_dst = node_emb[seeds_dst[start:end]].to(args.device)

# Compute prediction scores using the model.
pred = model.predictor(h_src * h_dst).squeeze(-1)

# Evaluate the predictions to obtain MRR values.
input_dict = {"y_pred_pos": pred[:, 0], "y_pred_neg": pred[:, 1:]}
rr[start:end] = evaluator.eval(input_dict)["mrr_list"]
return rr.mean()
pred = model.predictor(h_src * h_dst).squeeze()
preds[start:end] = pred
return mrr(preds, labels, indexes=indexes)


@torch.no_grad()
def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
"""Evaluate the model on validation and test sets."""
model.eval()
evaluator = Evaluator(name="ogbl-citation2")

dataloader = create_dataloader(
args, graph, features, all_nodes_set, is_train=False
@@ -292,13 +287,13 @@ def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
# Loop over both validation and test sets.
for split in [valid_set, test_set]:
# Unpack the item set.
src = split._items[0][:, 0].to(node_emb.device)
dst = split._items[0][:, 1].to(node_emb.device)
neg_dst = split._items[1].to(node_emb.device)
seeds = split._items[0].to(node_emb.device)
labels = split._items[1].to(node_emb.device)
indexes = split._items[2].to(node_emb.device)

# Compute MRR values for the current split.
results.append(
compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst)
compute_mrr(args, model, node_emb, seeds, labels, indexes)
)
return results

@@ -313,15 +308,16 @@ def train(args, model, graph, features, train_set):
start_epoch_time = time.time()
for step, data in tqdm.tqdm(enumerate(dataloader)):
# Get node pairs with labels for loss calculation.
compacted_pairs, labels = data.node_pairs_with_labels
compacted_seeds = data.compacted_seeds.T
labels = data.labels

node_feature = data.node_features["feat"]
blocks = data.blocks

# Get the embeddings of the input nodes.
y = model(blocks, node_feature)
logits = model.predictor(
y[compacted_pairs[0]] * y[compacted_pairs[1]]
y[compacted_seeds[0]] * y[compacted_seeds[1]]
).squeeze()

# Compute loss.
@@ -389,7 +385,7 @@ def main(args):

# Load and preprocess dataset.
print("Loading data")
dataset = gb.BuiltinDataset("ogbl-citation2").load()
dataset = gb.BuiltinDataset("ogbl-citation2-seeds").load()

# Move the dataset to the selected storage.
if args.storage_device == "pinned":
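Since the OGB `Evaluator` is replaced by `torchmetrics.retrieval.RetrievalMRR` above, a minimal, self-contained sketch of how that metric uses `indexes` to group predictions per source node may help. The scores below are made up, and the real evaluation has 1000 negatives per positive rather than two.

```python
import torch
from torchmetrics.retrieval import RetrievalMRR

# Two hypothetical source nodes, each with one positive destination followed
# by two negatives (the real loop strides by eval_batch_size * 1001 because
# each positive comes with 1000 negatives).
preds = torch.tensor([0.9, 0.3, 0.2, 0.5, 0.8, 0.4])
labels = torch.tensor([1, 0, 0, 1, 0, 0])      # 1 marks the true destination.
indexes = torch.tensor([0, 0, 0, 1, 1, 1])     # Groups predictions by source node.

mrr = RetrievalMRR()
# Source 0: positive ranked 1st (RR = 1); source 1: positive ranked 2nd (RR = 0.5).
print(mrr(preds, labels, indexes=indexes))     # tensor(0.7500)
```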
10 changes: 7 additions & 3 deletions examples/sampling/graphbolt/node_classification.py
@@ -101,7 +101,7 @@ def create_dataloader(
# ensures that the rest of the operations run on the GPU.
############################################################################
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=device, extra_attrs=["seed_nodes"])
datapipe = datapipe.copy_to(device=device, extra_attrs=["seeds"])

############################################################################
# [Step-3]:
@@ -364,8 +364,12 @@ def parse_args():
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"],
default="ogbn-products-seeds",
choices=[
"ogbn-arxiv-seeds",
"ogbn-products-seeds",
"ogbn-papers100M-seeds",
],
help="The dataset to use for the node classification example. Currently"
" ogbn-products, ogbn-arxiv and ogbn-papers100M are supported.",
)
5 changes: 3 additions & 2 deletions examples/sampling/graphbolt/pyg/node_classification.py
@@ -208,8 +208,9 @@ def main():
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
help='Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")',
default="ogbn-products-seeds",
help='Name of the dataset to use (e.g., "ogbn-products-seeds",'
+ ' "ogbn-arxiv-seeds")',
)
parser.add_argument(
"--epochs", type=int, default=10, help="Number of training epochs."
10 changes: 7 additions & 3 deletions examples/sampling/graphbolt/pyg/node_classification_advanced.py
@@ -175,7 +175,7 @@ def create_dataloader(
)
# Copy the data to the specified device.
if args.graph_device != "cpu":
datapipe = datapipe.copy_to(device=device, extra_attrs=["seed_nodes"])
datapipe = datapipe.copy_to(device=device, extra_attrs=["seeds"])
# Sample neighbors for each node in the mini-batch.
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1]
@@ -320,8 +320,12 @@ def parse_args():
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"],
default="ogbn-products-seeds",
choices=[
"ogbn-arxiv-seeds",
"ogbn-products-seeds",
"ogbn-papers100M-seeds",
],
help="The dataset to use for the node classification example. Currently"
" ogbn-products, ogbn-arxiv and ogbn-papers100M are supported.",
)
12 changes: 7 additions & 5 deletions examples/sampling/graphbolt/quickstart/link_prediction.py
@@ -85,7 +85,8 @@ def evaluate(model, dataset, device):
labels = []
for step, data in enumerate(dataloader):
# Get node pairs with labels for loss calculation.
compacted_pairs, label = data.node_pairs_with_labels
compacted_seeds = data.compacted_seeds.T
label = data.labels

# The features of sampled nodes.
x = data.node_features["feat"]
@@ -94,7 +95,7 @@ def evaluate(model, dataset, device):
y = model(data.blocks, x)
logit = (
model.predictor(
y[compacted_pairs[0].long()] * y[compacted_pairs[1].long()]
y[compacted_seeds[0].long()] * y[compacted_seeds[1].long()]
)
.squeeze()
.detach()
@@ -126,15 +127,16 @@ def train(model, dataset, device):
########################################################################
for step, data in enumerate(dataloader):
# Get node pairs with labels for loss calculation.
compacted_pairs, labels = data.node_pairs_with_labels
compacted_seeds = data.compacted_seeds.T
labels = data.labels

# The features of sampled nodes.
x = data.node_features["feat"]

# Forward.
y = model(data.blocks, x)
logits = model.predictor(
y[compacted_pairs[0].long()] * y[compacted_pairs[1].long()]
y[compacted_seeds[0].long()] * y[compacted_seeds[1].long()]
).squeeze()

# Compute loss.
@@ -156,7 +158,7 @@ def train(model, dataset, device):

# Load and preprocess dataset.
print("Loading data...")
dataset = gb.BuiltinDataset("cora").load()
dataset = gb.BuiltinDataset("cora-seeds").load()

# If a CUDA device is selected, we pin the graph and the features so that
# the GPU can access them.
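A small illustration of the `compacted_seeds.T` convention used above: `compacted_seeds` is an (N, 2) tensor of compacted (src, dst) pairs, so its transpose yields a row of sources and a row of destinations. The IDs, labels, and the dot-product stand-in for `model.predictor` below are all made up.

```python
import torch

# Hypothetical mini-batch: two positive and two negative (src, dst) pairs,
# already compacted to local node IDs.
compacted_seeds = torch.tensor([[0, 3], [1, 4], [0, 5], [1, 6]])
labels = torch.tensor([1.0, 1.0, 0.0, 0.0])

src, dst = compacted_seeds.T   # row 0 = source IDs, row 1 = destination IDs
print(src)                     # tensor([0, 1, 0, 1])
print(dst)                     # tensor([3, 4, 5, 6])

# The example scores each pair as predictor(y[src] * y[dst]); a plain
# dot product stands in for the MLP predictor here.
y = torch.randn(7, 16)         # stand-in for the GNN output embeddings
scores = (y[src] * y[dst]).sum(dim=1)
print(scores.shape)            # torch.Size([4]) -- one logit per seed pair
```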
4 changes: 2 additions & 2 deletions examples/sampling/graphbolt/quickstart/node_classification.py
@@ -18,7 +18,7 @@ def create_dataloader(dataset, itemset, device):
datapipe = gb.ItemSampler(itemset, batch_size=16)

# Copy the mini-batch to the designated device for sampling and training.
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.copy_to(device, extra_attrs=["seeds"])

# Sample neighbors for the seed nodes.
datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])
@@ -117,7 +117,7 @@ def train(model, dataset, device):

# Load and preprocess dataset.
print("Loading data...")
dataset = gb.BuiltinDataset("cora").load()
dataset = gb.BuiltinDataset("cora-seeds").load()

# If a CUDA device is selected, we pin the graph and the features so that
# the GPU can access them.
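The only functional change in this quickstart is the attribute name passed to `copy_to`: mini-batches now expose `seeds` instead of `seed_nodes`. For orientation, a condensed sketch of the datapipe the file builds around that call; names follow the file itself, and it is an illustration rather than a drop-in replacement.

```python
import dgl.graphbolt as gb

def create_dataloader(dataset, itemset, device):
    # Mini-batches of 16 seed nodes.
    datapipe = gb.ItemSampler(itemset, batch_size=16)
    # Move each mini-batch to the target device; `seeds` replaces the old
    # `seed_nodes` attribute name.
    datapipe = datapipe.copy_to(device, extra_attrs=["seeds"])
    # Two-hop neighbor sampling around the seed nodes.
    datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])
    # Attach the "feat" node features consumed by the model.
    datapipe = datapipe.fetch_feature(dataset.feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
```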
16 changes: 8 additions & 8 deletions examples/sampling/graphbolt/rgcn/hetero_rgcn.py
@@ -117,7 +117,7 @@ def create_dataloader(
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.copy_to(device, extra_attrs=["seeds"])

# Sample neighbors for each seed node in the mini-batch.
# `graph`:
@@ -153,7 +153,7 @@ def extract_embed(node_embed, input_nodes):

def extract_node_features(name, block, data, node_embed, device):
"""Extract the node features from embedding layer or raw features."""
if name == "ogbn-mag":
if name == "ogbn-mag-seeds":
input_nodes = {
k: v.to(device) for k, v in block.srcdata[dgl.NID].items()
}
@@ -419,8 +419,8 @@ def evaluate(
model.eval()
category = "paper"
# An evaluator for the dataset.
if name == "ogbn-mag":
evaluator = Evaluator(name=name)
if name == "ogbn-mag-seeds":
evaluator = Evaluator(name="ogbn-mag")
else:
evaluator = MAG240MEvaluator()

@@ -578,7 +578,7 @@ def main(args):
# `institution` are generated in advance and stored in the feature store.
# For `ogbn-mag`, we generate the features on the fly.
embed_layer = None
if args.dataset == "ogbn-mag":
if args.dataset == "ogbn-mag-seeds":
# Create the embedding layer and move it to the appropriate device.
embed_layer = rel_graph_embed(g, feat_size).to(device)
print(
@@ -652,9 +652,9 @@ def main(args):
parser.add_argument(
"--dataset",
type=str,
default="ogbn-mag",
choices=["ogbn-mag", "ogb-lsc-mag240m"],
help="Dataset name. Possible values: ogbn-mag, ogb-lsc-mag240m",
default="ogbn-mag-seeds",
choices=["ogbn-mag-seeds", "ogb-lsc-mag240m"],
help="Dataset name. Possible values: ogbn-mag-seeds, ogb-lsc-mag240m",
)
parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--num_workers", type=int, default=0)
6 changes: 3 additions & 3 deletions examples/sampling/graphbolt/sparse/graphsage.py
@@ -116,7 +116,7 @@ def __init__(self, datapipe, matrix, fanouts):

def sample_subgraphs(self, seeds, seeds_timestamp=None):
sampled_matrices = []
src = seeds
src = seeds.long()

#####################################################################
# (HIGHLIGHT) Using the sparse sample operator to perform random
@@ -242,7 +242,7 @@ def train(device, A, features, dataset, num_classes, model):
# Load and preprocess dataset.
print("Loading data")
device = torch.device("cpu" if args.mode == "cpu" else "cuda")
dataset = gb.BuiltinDataset("ogbn-products").load()
dataset = gb.BuiltinDataset("ogbn-products-seeds").load()
g = dataset.graph
features = dataset.feature

@@ -254,7 +254,7 @@ def train(device, A, features, dataset, num_classes, model):

# Create sparse.
N = g.num_nodes
A = dglsp.from_csc(g.csc_indptr, g.indices, shape=(N, N))
A = dglsp.from_csc(g.csc_indptr.long(), g.indices.long(), shape=(N, N))

# Model training.
print("Training...")
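The added `.long()` casts suggest the GraphBolt CSC tensors are not guaranteed to be int64, which `dgl.sparse` expects for its index arrays. A tiny self-contained sketch of the same construction, using a made-up 3-node graph; int32 stands in for whatever dtype the on-disk graph provides.

```python
import torch
import dgl.sparse as dglsp

# CSC arrays of a hypothetical 3-node graph with 4 edges.
csc_indptr = torch.tensor([0, 2, 3, 4], dtype=torch.int32)
indices = torch.tensor([1, 2, 0, 1], dtype=torch.int32)

N = 3
# Cast to int64 before handing the arrays to dgl.sparse, mirroring the diff.
A = dglsp.from_csc(csc_indptr.long(), indices.long(), shape=(N, N))
print(A.shape, A.nnz)  # (3, 3) 4
```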
12 changes: 11 additions & 1 deletion python/dgl/graphbolt/impl/ondisk_dataset.py
@@ -996,12 +996,22 @@ class BuiltinDataset(OnDiskDataset):
)
_datasets = [
"cora",
"cora-seeds",
"ogbn-mag",
"ogbn-mag-seeds",
"ogbl-citation2",
"ogbl-citation2-seeds",
"ogbn-products",
"ogbn-products-seeds",
"ogbn-arxiv",
"ogbn-arxiv-seeds",
]
_large_datasets = [
"ogb-lsc-mag240m",
"ogb-lsc-mag240m-seeds",
"ogbn-papers100M",
"ogbn-papers100M-seeds",
]
_large_datasets = ["ogb-lsc-mag240m", "ogbn-papers100M"]
_all_datasets = _datasets + _large_datasets

def __init__(self, name: str, root: str = "datasets") -> OnDiskDataset:
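With the registry extended, the `-seeds` variants load exactly like the existing built-ins, and the original names stay registered alongside them. A minimal usage sketch; the attribute access mirrors the examples above, and the item-set layout it prints depends on the task.

```python
import dgl.graphbolt as gb

# Load a "-seeds" variant; it downloads on first use like the other builtins.
dataset = gb.BuiltinDataset("cora-seeds").load()

graph = dataset.graph
features = dataset.feature
train_set = dataset.tasks[0].train_set

# Item sets in the new format carry `seeds` (plus labels/indexes where relevant).
print(train_set.names)
```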
