-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* gnnfilm * update * readme Co-authored-by: zhjwy9343 <[email protected]>
- Loading branch information
1 parent
c448e61
commit a9e16c1
Showing
6 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# DGL Implementation of the GNN-FiLM Model | ||
|
||
This DGL example implements the GNN model proposed in the paper [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](https://arxiv.org/pdf/1906.12192.pdf). | ||
The author's codes of implementation is in [here](https://github.com/Microsoft/tf-gnn-samples) | ||
|
||
|
||
Example implementor | ||
---------------------- | ||
This example was implemented by [Kounianhua Du](https://github.com/KounianhuaDu) during her Software Dev Engineer Intern work at the AWS Shanghai AI Lab. | ||
|
||
|
||
The graph dataset used in this example | ||
--------------------------------------- | ||
The DGL's built-in PPIDataset. This is a Protein-Protein Interaction dataset for inductive node classification. The PPIDataset is a toy Protein-Protein Interaction network dataset. The dataset contains 24 graphs. The average number of nodes per graph is 2372. Each node has 50 features and 121 labels. There are 20 graphs for training, 2 for validation, and 2 for testing. | ||
|
||
NOTE: Following the paper, in addition to the dataset-provided untyped edges, a fresh "self-loop" edge type is added. | ||
|
||
Statistics: | ||
- Train examples: 20 | ||
- Valid examples: 2 | ||
- Test examples: 2 | ||
- AvgNodesPerGraph: 2372 | ||
- NumFeats: 50 | ||
- NumLabels: 121 | ||
|
||
|
||
How to run example files | ||
-------------------------------- | ||
In the GNNFiLM folder, run | ||
|
||
```bash | ||
python main.py | ||
``` | ||
|
||
If want to use a GPU, run | ||
|
||
```bash | ||
python main.py --gpu ${your_device_id_here} | ||
``` | ||
|
||
|
||
Performance | ||
------------------------- | ||
|
||
NOTE: We do not perform grid search or finetune here, so there is a gap between the performance reported in the original paper and this example. Below results, mean(standard deviation), were computed over ten runs. | ||
|
||
**GNN-FiLM results on PPI task** | ||
| Model | Paper (tensorflow) | ours (dgl) | | ||
| ------------- | -------------------------------- | --------------------------- | | ||
| Avg. Micro-F1 | 0.992 (0.000) | 0.983 (0.001) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from torch.utils.data import Dataset, DataLoader | ||
import dgl | ||
from dgl.data import PPIDataset | ||
import collections | ||
|
||
#implement the collate_fn for dgl graph data class | ||
PPIBatch = collections.namedtuple('PPIBatch', ['graph', 'label']) | ||
def batcher(device): | ||
def batcher_dev(batch): | ||
batch_graphs = dgl.batch(batch) | ||
return PPIBatch(graph=batch_graphs, | ||
label=batch_graphs.ndata['label'].to(device)) | ||
return batcher_dev | ||
|
||
#add a fresh "self-loop" edge type to the untyped PPI dataset and prepare train, val, test loaders | ||
def load_PPI(batch_size=1, device='cpu'): | ||
train_set = PPIDataset(mode='train') | ||
valid_set = PPIDataset(mode='valid') | ||
test_set = PPIDataset(mode='test') | ||
#for each graph, add self-loops as a new relation type | ||
#here we reconstruct the graph since the schema of a heterograph cannot be changed once constructed | ||
for i in range(len(train_set)): | ||
g = dgl.heterograph({ | ||
('_N','_E','_N'): train_set[i].edges(), | ||
('_N', 'self', '_N'): (train_set[i].nodes(), train_set[i].nodes()) | ||
}) | ||
g.ndata['label'] = train_set[i].ndata['label'] | ||
g.ndata['feat'] = train_set[i].ndata['feat'] | ||
g.ndata['_ID'] = train_set[i].ndata['_ID'] | ||
g.edges['_E'].data['_ID'] = train_set[i].edata['_ID'] | ||
train_set.graphs[i] = g | ||
for i in range(len(valid_set)): | ||
g = dgl.heterograph({ | ||
('_N','_E','_N'): valid_set[i].edges(), | ||
('_N', 'self', '_N'): (valid_set[i].nodes(), valid_set[i].nodes()) | ||
}) | ||
g.ndata['label'] = valid_set[i].ndata['label'] | ||
g.ndata['feat'] = valid_set[i].ndata['feat'] | ||
g.ndata['_ID'] = valid_set[i].ndata['_ID'] | ||
g.edges['_E'].data['_ID'] = valid_set[i].edata['_ID'] | ||
valid_set.graphs[i] = g | ||
for i in range(len(test_set)): | ||
g = dgl.heterograph({ | ||
('_N','_E','_N'): test_set[i].edges(), | ||
('_N', 'self', '_N'): (test_set[i].nodes(), test_set[i].nodes()) | ||
}) | ||
g.ndata['label'] = test_set[i].ndata['label'] | ||
g.ndata['feat'] = test_set[i].ndata['feat'] | ||
g.ndata['_ID'] = test_set[i].ndata['_ID'] | ||
g.edges['_E'].data['_ID'] = test_set[i].edata['_ID'] | ||
test_set.graphs[i] = g | ||
|
||
etypes = train_set[0].etypes | ||
in_size = train_set[0].ndata['feat'].shape[1] | ||
out_size = train_set[0].ndata['label'].shape[1] | ||
|
||
#prepare train, valid, and test dataloaders | ||
train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=batcher(device), shuffle=True) | ||
valid_loader = DataLoader(valid_set, batch_size=batch_size, collate_fn=batcher(device), shuffle=True) | ||
test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=batcher(device), shuffle=True) | ||
return train_loader, valid_loader, test_loader, etypes, in_size, out_size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import dgl | ||
import dgl.function as fn | ||
from utils import evaluate_f1_score | ||
from data_loader import load_PPI | ||
import argparse | ||
import numpy as np | ||
import os | ||
|
||
class GNNFiLMLayer(nn.Module): | ||
def __init__(self, in_size, out_size, etypes, dropout=0.1): | ||
super(GNNFiLMLayer, self).__init__() | ||
self.in_size = in_size | ||
self.out_size = out_size | ||
|
||
#weights for different types of edges | ||
self.W = nn.ModuleDict({ | ||
name : nn.Linear(in_size, out_size, bias = False) for name in etypes | ||
}) | ||
|
||
#hypernets to learn the affine functions for different types of edges | ||
self.film = nn.ModuleDict({ | ||
name : nn.Linear(in_size, 2*out_size, bias = False) for name in etypes | ||
}) | ||
|
||
#layernorm before each propogation | ||
self.layernorm = nn.LayerNorm(out_size) | ||
|
||
#dropout layer | ||
self.dropout = nn.Dropout(dropout) | ||
|
||
def forward(self, g, feat_dict): | ||
#the input graph is a multi-relational graph, so treated as hetero-graph. | ||
|
||
funcs = {} #message and reduce functions dict | ||
#for each type of edges, compute messages and reduce them all | ||
for srctype, etype, dsttype in g.canonical_etypes: | ||
messages = self.W[etype](feat_dict[srctype]) #apply W_l on src feature | ||
film_weights = self.film[etype](feat_dict[dsttype]) #use dst feature to compute affine function paras | ||
gamma = film_weights[:,:self.out_size] #"gamma" for the affine function | ||
beta = film_weights[:,self.out_size:] #"beta" for the affine function | ||
messages = gamma * messages + beta #compute messages | ||
messages = F.relu_(messages) | ||
g.nodes[srctype].data[etype] = messages #store in ndata | ||
funcs[etype] = (fn.copy_u(etype, 'm'), fn.sum('m', 'h')) #define message and reduce functions | ||
g.multi_update_all(funcs, 'sum') #update all, reduce by first type-wisely then across different types | ||
feat_dict={} | ||
for ntype in g.ntypes: | ||
feat_dict[ntype] = self.dropout(self.layernorm(g.nodes[ntype].data['h'])) #apply layernorm and dropout | ||
return feat_dict | ||
|
||
class GNNFiLM(nn.Module): | ||
def __init__(self, etypes, in_size, hidden_size, out_size, num_layers, dropout=0.1): | ||
super(GNNFiLM, self).__init__() | ||
self.film_layers = nn.ModuleList() | ||
self.film_layers.append( | ||
GNNFiLMLayer(in_size, hidden_size, etypes, dropout) | ||
) | ||
for i in range(num_layers-1): | ||
self.film_layers.append( | ||
GNNFiLMLayer(hidden_size, hidden_size, etypes, dropout) | ||
) | ||
self.predict = nn.Linear(hidden_size, out_size, bias = True) | ||
|
||
def forward(self, g, out_key): | ||
h_dict = {ntype : g.nodes[ntype].data['feat'] for ntype in g.ntypes} #prepare input feature dict | ||
for layer in self.film_layers: | ||
h_dict = layer(g, h_dict) | ||
h = self.predict(h_dict[out_key]) #use the final embed to predict, out_size = num_classes | ||
h = torch.sigmoid(h) | ||
return h | ||
|
||
def main(args): | ||
# Step 1: Prepare graph data and retrieve train/validation/test dataloader ============================= # | ||
if args.gpu >= 0 and torch.cuda.is_available(): | ||
device = 'cuda:{}'.format(args.gpu) | ||
else: | ||
device = 'cpu' | ||
|
||
if args.dataset == 'PPI': | ||
train_set, valid_set, test_set, etypes, in_size, out_size = load_PPI(args.batch_size, device) | ||
|
||
# Step 2: Create model and training components=========================================================== # | ||
model = GNNFiLM(etypes, in_size, args.hidden_size, out_size, args.num_layers).to(device) | ||
criterion = nn.BCELoss() | ||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) | ||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step_size, gamma=args.gamma) | ||
|
||
# Step 4: training epoches ============================================================================== # | ||
lastf1 = 0 | ||
cnt = 0 | ||
best_val_f1 = 0 | ||
for epoch in range(args.max_epoch): | ||
train_loss = [] | ||
train_f1 = [] | ||
val_loss = [] | ||
val_f1 = [] | ||
model.train() | ||
for batch in train_set: | ||
g = batch.graph | ||
g = g.to(device) | ||
logits = model.forward(g, '_N') | ||
labels = batch.label | ||
loss = criterion(logits, labels) | ||
f1 = evaluate_f1_score(logits.detach().cpu().numpy(), labels.detach().cpu().numpy()) | ||
|
||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
train_loss.append(loss.item()) | ||
train_f1.append(f1) | ||
|
||
train_loss = np.mean(train_loss) | ||
train_f1 = np.mean(train_f1) | ||
scheduler.step() | ||
|
||
model.eval() | ||
with torch.no_grad(): | ||
for batch in valid_set: | ||
g = batch.graph | ||
g = g.to(device) | ||
logits = model.forward(g, '_N') | ||
labels = batch.label | ||
loss = criterion(logits, labels) | ||
f1 = evaluate_f1_score(logits.detach().cpu().numpy(), labels.detach().cpu().numpy()) | ||
val_loss.append(loss.item()) | ||
val_f1.append(f1) | ||
|
||
val_loss = np.mean(val_loss) | ||
val_f1 = np.mean(val_f1) | ||
print('Epoch {:d} | Train Loss {:.4f} | Train F1 {:.4f} | Val Loss {:.4f} | Val F1 {:.4f} |'.format(epoch + 1, train_loss, train_f1, val_loss, val_f1)) | ||
if val_f1 > best_val_f1: | ||
best_val_f1 = val_f1 | ||
torch.save(model.state_dict(), os.path.join(args.save_dir, args.name)) | ||
|
||
if val_f1 < lastf1: | ||
cnt += 1 | ||
if cnt == args.early_stopping: | ||
print('Early stop.') | ||
break | ||
else: | ||
cnt = 0 | ||
lastf1 = val_f1 | ||
|
||
model.eval() | ||
test_loss = [] | ||
test_f1 = [] | ||
model.load_state_dict(torch.load(os.path.join(args.save_dir, args.name))) | ||
with torch.no_grad(): | ||
for batch in test_set: | ||
g = batch.graph | ||
g = g.to(device) | ||
logits = model.forward(g, '_N') | ||
labels = batch.label | ||
loss = criterion(logits, labels) | ||
f1 = evaluate_f1_score(logits.detach().cpu().numpy(), labels.detach().cpu().numpy()) | ||
test_loss.append(loss.item()) | ||
test_f1.append(f1) | ||
test_loss = np.mean(test_loss) | ||
test_f1 = np.mean(test_f1) | ||
|
||
print("Test F1: {:.4f} | Test loss: {:.4f}".format(test_f1, test_loss)) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='GNN-FiLM') | ||
parser.add_argument("--dataset", type=str, default="PPI", help="DGL dataset for this GNN-FiLM") | ||
parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.") | ||
parser.add_argument("--in_size", type=int, default=50, help="Input dimensionalities") | ||
parser.add_argument("--hidden_size", type=int, default=320, help="Hidden layer dimensionalities") | ||
parser.add_argument("--out_size", type=int, default=121, help="Output dimensionalities") | ||
parser.add_argument("--num_layers", type=int, default=4, help="Number of GNN layers") | ||
parser.add_argument("--batch_size", type=int, default=5, help="Batch size") | ||
parser.add_argument("--max_epoch", type=int, default=1500, help="The max number of epoches. Default: 500") | ||
parser.add_argument("--early_stopping", type=int, default=80, help="Early stopping. Default: 50") | ||
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: 3e-1") | ||
parser.add_argument("--wd", type=float, default=0.0009, help="Weight decay. Default: 3e-1") | ||
parser.add_argument('--step-size', type=int, default=40, help='Period of learning rate decay.') | ||
parser.add_argument('--gamma', type=float, default=0.8, help='Multiplicative factor of learning rate decay.') | ||
parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate. Default: 0.9") | ||
parser.add_argument('--save_dir', type=str, default='./out', help='Path to save the model.') | ||
parser.add_argument("--name", type=str, default='GNN-FiLM', help="Saved model name.") | ||
|
||
args = parser.parse_args() | ||
print(args) | ||
if not os.path.exists(args.save_dir): | ||
os.mkdir(args.save_dir) | ||
main(args) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
|
||
from sklearn.metrics import f1_score | ||
import numpy as np | ||
|
||
#function to compute f1 score | ||
def evaluate_f1_score(pred, label): | ||
pred = np.round(pred, 0).astype(np.int16) | ||
pred = pred.flatten() | ||
label = label.flatten() | ||
return f1_score(y_pred=pred, y_true=label) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
out/ |