forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Steve <[email protected]> Co-authored-by: Mufei Li <[email protected]>
- Loading branch information
1 parent
3103f8c
commit 53895b9
Showing
1 changed file
with
120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
""" | ||
[SIGN: Scalable Inception Graph Neural Networks] | ||
(https://arxiv.org/abs/2004.11198) | ||
This example shows a simplified version of SIGN: a precomputed 2-hops diffusion | ||
operator on top of symmetrically normalized adjacency matrix A_hat. | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.optim import Adam | ||
|
||
from dgl.data import CoraGraphDataset | ||
from dgl.mock_sparse import create_from_coo, diag, identity | ||
|
||
################################################################################ | ||
# (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement the feature | ||
# diffusion in SIGN laconically. | ||
################################################################################ | ||
def sign_diffusion(A, X, r): | ||
# Perform the r-hop diffusion operation. | ||
X_sign = [X] | ||
for _ in range(r): | ||
X = A @ X | ||
X_sign.append(X) | ||
return X_sign | ||
|
||
|
||
class SIGN(nn.Module): | ||
def __init__(self, in_size, out_size, r, hidden_size=256): | ||
super().__init__() | ||
# Note that theta and omega refer to the learnable matrices in the | ||
# original paper correspondingly. The variable r refers to subscript to | ||
# theta. | ||
self.theta = nn.ModuleList( | ||
[nn.Linear(in_size, hidden_size) for _ in range(r + 1)] | ||
) | ||
self.omega = nn.Linear(hidden_size * (r + 1), out_size) | ||
|
||
def forward(self, X_sign): | ||
results = [] | ||
for i in range(len(X_sign)): | ||
results.append(self.theta[i](X_sign[i])) | ||
Z = F.relu(torch.cat(results, dim=1)) | ||
return self.omega(Z) | ||
|
||
|
||
def evaluate(g, pred): | ||
label = g.ndata["label"] | ||
val_mask = g.ndata["val_mask"] | ||
test_mask = g.ndata["test_mask"] | ||
|
||
# Compute accuracy on validation/test set. | ||
val_acc = (pred[val_mask] == label[val_mask]).float().mean() | ||
test_acc = (pred[test_mask] == label[test_mask]).float().mean() | ||
return val_acc, test_acc | ||
|
||
|
||
def train(g, model): | ||
labels = g.ndata["label"] | ||
train_mask = g.ndata["train_mask"] | ||
optimizer = Adam(model.parameters(), lr=3e-3) | ||
|
||
for epoch in range(10): | ||
# Forward. | ||
logits = model(X_sign) | ||
|
||
# Compute loss with nodes in training set. | ||
loss = F.cross_entropy(logits[train_mask], labels[train_mask]) | ||
|
||
# Backward. | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
# Compute prediction. | ||
pred = logits.argmax(1) | ||
|
||
# Evaluate the prediction. | ||
val_acc, test_acc = evaluate(g, pred) | ||
print( | ||
f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test" | ||
f" acc: {test_acc:.3f}" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
# If CUDA is available, use GPU to accelerate the training, use CPU | ||
# otherwise. | ||
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
# Load graph from the existing dataset. | ||
dataset = CoraGraphDataset() | ||
g = dataset[0].to(dev) | ||
|
||
# Create the sparse adjacency matrix A (note that W was used as the notation | ||
# for adjacency matrix in the original paper). | ||
src, dst = g.edges() | ||
N = g.num_nodes() | ||
A = create_from_coo(dst, src, shape=(N, N)) | ||
|
||
# Calculate the symmetrically normalized adjacency matrix. | ||
I = identity(A.shape, device=dev) | ||
A_hat = A + I | ||
D_hat = diag(A_hat.sum(dim=1)) ** -0.5 | ||
A_hat = D_hat @ A_hat @ D_hat | ||
|
||
# 2-hop diffusion. | ||
r = 2 | ||
X = g.ndata["feat"] | ||
X_sign = sign_diffusion(A_hat, X, r) | ||
|
||
# Create SIGN model. | ||
in_size = X.shape[1] | ||
out_size = dataset.num_classes | ||
model = SIGN(in_size, out_size, r).to(dev) | ||
|
||
# Kick off training. | ||
train(g, model) |