example gat model
lingfanyu committed Jun 20, 2018
1 parent 1248bd2 commit 9f0d9e4
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions examples/pytorch/gat.py
@@ -0,0 +1,170 @@
import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np
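
# Per-head GAT update (Velickovic et al., "Graph Attention Networks", ICLR 2018):
#   e_uv     = LeakyReLU(a^T [W h_v || W h_u])
#   alpha_uv = softmax of e_uv over u in {v} + neighbors(v)
#   h_v'     = sum_u alpha_uv * W h_u
# Hidden layers concatenate the head outputs; the output layer sums them.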

class NodeUpdateModule(nn.Module):
def __init__(self, input_dim, num_hidden, aggregator, num_heads=3, act=None,
attention_dropout=None, input_dropout=None, residual=False):
super(NodeUpdateModule, self).__init__()
self.num_hidden = num_hidden
self.num_heads = num_heads
self.fc = nn.ModuleList(
[nn.Linear(input_dim, num_hidden, bias=False)
for _ in range(num_heads)])
self.attention = nn.ModuleList(
[nn.Linear(num_hidden * 2, 1, bias=False) for _ in range(num_heads)])
self.act = act
self.attention_dropout = attention_dropout
self.input_dropout = input_dropout
self.aggregator = aggregator
self.residual = residual

def forward(self, node, msgs):
hv = node['h']
hu = torch.cat(msgs, dim=0)

        # number of nodes attended to: all neighbors plus the node itself
        n = len(msgs) + 1
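        # shapes, given the (1, D) per-node features set in main: hv is (1, D)
        # and hu is (n - 1, D); each head projects both to num_hidden below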

out = []
for i in range(self.num_heads):
hvv = hv
huu = hu
            if self.input_dropout is not None:
                hvv = F.dropout(hvv, self.input_dropout, training=self.training)
                huu = F.dropout(huu, self.input_dropout, training=self.training)
            # calc W*h_self and W*h_neigh from the dropped-out inputs
            hvv = self.fc[i](hvv)
            huu = self.fc[i](huu)
            # prepend the node itself so attention runs over self + neighbors
            huu = torch.cat((hvv, huu), dim=0)
            # per-pair attention input W*h_self || W*h_u, shape (n, 2 * num_hidden)
            h = torch.cat((hvv.expand(n, -1), huu), dim=1)
            a = F.leaky_relu(self.attention[i](h))
            # normalize the coefficients over the n attended nodes
            a = F.softmax(a, dim=0)
            if self.attention_dropout is not None:
                a = F.dropout(a, self.attention_dropout, training=self.training)
            if self.input_dropout is not None:
                huu = F.dropout(huu, self.input_dropout, training=self.training)
            # attention-weighted sum over self + neighbors, not the node alone
            h = torch.sum(a * huu, 0, keepdim=True)
# add residual connection
if self.residual:
h += hvv
if self.act is not None:
h = self.act(h)
out.append(h)

# aggregate multi-head results
h = self.aggregator(out)
return {'h': h}


class GAT(nn.Module):
def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
activation, attention_dropout, input_dropout, use_residual=False):
super(GAT, self).__init__()
self.layers = nn.ModuleList()
# update layers
aggregator = lambda x: torch.cat(x, 1)
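        # concatenating num_heads outputs of size num_hidden gives (1, num_heads * num_hidden)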
for i in range(num_layers):
if i == 0:
last_dim = in_dim
residual = False
else:
last_dim = num_hidden * num_heads # because of concat heads
residual = use_residual
self.layers.append(
NodeUpdateModule(last_dim, num_hidden, aggregator, num_heads,
activation, attention_dropout, input_dropout, residual))
# projection layer
# FIXME: does pytorch has something similar to tf.add_n which sum over a list?
aggregator = lambda x: reduce(lambda a, b: a+b, x)
self.layers.append(NodeUpdateModule(num_hidden * 3, num_classes, aggregator,
1, None, attention_dropout, input_dropout, False))

def forward(self, g):
g.register_message_func(lambda src, dst, edge: src['h'])
for layer in self.layers:
g.register_update_func(layer)
g.update_all()
logits = [g.node[n]['h'] for n in g.nodes()]
logits = torch.cat(logits, dim=0)
return logits
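
# Note on the early DGL API used above: register_message_func and
# register_update_func attach per-edge and per-node callables to the graph,
# and update_all() runs one full round of message passing over every edge,
# so each layer in the loop propagates features one hop.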


def main(args):
# dropout parameters
input_dropout = 0.2
attention_dropout = 0.2

# load and preprocess dataset
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
features = preprocess_features(features)

# initialize graph
g = DGLGraph(adj)

# create model
model = GAT(args.num_layers,
features.shape[1],
args.num_hidden,
y_train.shape[1],
args.num_heads,
F.elu,
attention_dropout,
input_dropout,
args.residual)

# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# convert labels and masks to tensor
labels = torch.FloatTensor(y_train)
mask = torch.FloatTensor(train_mask.astype(np.float32))
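    # assuming a Kipf/Welling-style loader: y_train is one-hot with shape
    # (N, num_classes) and zero rows outside the training set, and train_mask
    # flags the training rows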

for epoch in range(args.epochs):
# reset grad
optimizer.zero_grad()

# reset graph states
for n in g.nodes():
g.node[n]['h'] = torch.FloatTensor(features[n].toarray())

# forward
        logits = model(g)

        # masked cross-entropy loss: negated mean of the masked log-likelihood
        # TODO: (lingfan) use gather to speed up
        logp = F.log_softmax(logits, 1)
        loss = -torch.mean(logp * labels * mask.view(-1, 1))
print("epoch {} loss: {}".format(epoch, loss.item()))

loss.backward()
optimizer.step()
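
# Example invocation (the hyperparameter values are illustrative, not from the
# source; dataset.py must be able to load the named dataset):
#   python gat.py --dataset cora --epochs 200 --num-heads 8 --lr 0.005 --residual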

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
parser.add_argument("--dataset", type=str, required=True,
help="dataset name")
parser.add_argument("--epochs", type=int, default=10,
help="training epoch")
parser.add_argument("--num-heads", type=int, default=3,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units")
parser.add_argument("--residual", action="store_true",
help="use residual connection")
parser.add_argument("--lr", type=float, default=0.001,
help="learning rate")
args = parser.parse_args()
print(args)

main(args)
