[Sparse] SIGN example. (dmlc#4908)
Co-authored-by: Steve <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
3 people authored Nov 17, 2022
1 parent 3103f8c commit 53895b9
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions examples/sparse/sign.py
@@ -0,0 +1,120 @@
"""
[SIGN: Scalable Inception Graph Neural Networks]
(https://arxiv.org/abs/2004.11198)
This example shows a simplified version of SIGN: a precomputed 2-hop diffusion
operator on top of the symmetrically normalized adjacency matrix A_hat.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from dgl.data import CoraGraphDataset
from dgl.mock_sparse import create_from_coo, diag, identity

################################################################################
# (HIGHLIGHT) Take advantage of DGL sparse APIs to implement the feature
# diffusion in SIGN concisely.
################################################################################
def sign_diffusion(A, X, r):
# Perform the r-hop diffusion operation.
X_sign = [X]
for _ in range(r):
X = A @ X
X_sign.append(X)
return X_sign
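# As an illustration, with r = 2 this returns [X, A_hat @ X, A_hat @ (A_hat @ X)]:
# the raw features plus their 1-hop and 2-hop diffused versions, each with the
# same shape as X. The diffusion is precomputed once and reused in every epoch.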


class SIGN(nn.Module):
def __init__(self, in_size, out_size, r, hidden_size=256):
super().__init__()
        # Note that theta and omega correspond to the learnable matrices Theta
        # and Omega in the original paper. The variable r is the largest
        # subscript of theta, i.e., the number of diffusion hops.
self.theta = nn.ModuleList(
[nn.Linear(in_size, hidden_size) for _ in range(r + 1)]
)
self.omega = nn.Linear(hidden_size * (r + 1), out_size)

def forward(self, X_sign):
results = []
for i in range(len(X_sign)):
results.append(self.theta[i](X_sign[i]))
Z = F.relu(torch.cat(results, dim=1))
return self.omega(Z)
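# Shape walk-through (with N nodes): each X_sign[i] is (N, in_size), each
# theta[i] maps it to (N, hidden_size), the concatenation Z has shape
# (N, hidden_size * (r + 1)), and omega produces (N, out_size) logits.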


def evaluate(g, pred):
label = g.ndata["label"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

# Compute accuracy on validation/test set.
val_acc = (pred[val_mask] == label[val_mask]).float().mean()
test_acc = (pred[test_mask] == label[test_mask]).float().mean()
return val_acc, test_acc


def train(g, model, X_sign):
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
optimizer = Adam(model.parameters(), lr=3e-3)

for epoch in range(10):
# Forward.
logits = model(X_sign)

# Compute loss with nodes in training set.
loss = F.cross_entropy(logits[train_mask], labels[train_mask])

# Backward.
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Compute prediction.
pred = logits.argmax(1)

# Evaluate the prediction.
val_acc, test_acc = evaluate(g, pred)
print(
f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test"
f" acc: {test_acc:.3f}"
)


if __name__ == "__main__":
    # If CUDA is available, use the GPU to accelerate training; otherwise use
    # the CPU.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load graph from the existing dataset.
dataset = CoraGraphDataset()
g = dataset[0].to(dev)

    # Create the sparse adjacency matrix A (note that the original paper uses W
    # to denote the adjacency matrix).
src, dst = g.edges()
N = g.num_nodes()
A = create_from_coo(dst, src, shape=(N, N))
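    # With rows indexed by dst and columns by src, (A @ X)[i] sums the features
    # of node i's source neighbors, i.e., messages flow from src to dst during
    # diffusion.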

# Calculate the symmetrically normalized adjacency matrix.
I = identity(A.shape, device=dev)
A_hat = A + I
D_hat = diag(A_hat.sum(dim=1)) ** -0.5
A_hat = D_hat @ A_hat @ D_hat
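    # In matrix form: A_hat = D_hat^{-1/2} (A + I) D_hat^{-1/2}, where D_hat is
    # the diagonal degree matrix of A + I. This is the same normalization used
    # in GCN.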

# 2-hop diffusion.
r = 2
X = g.ndata["feat"]
X_sign = sign_diffusion(A_hat, X, r)

# Create SIGN model.
in_size = X.shape[1]
out_size = dataset.num_classes
model = SIGN(in_size, out_size, r).to(dev)

# Kick off training.
    train(g, model, X_sign)
