Commit
Co-authored-by: maqy1995 <maqingyun0718>
Co-authored-by: Quan (Andy) Gan <[email protected]>
Showing 1 changed file with 271 additions and 0 deletions.
@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
"""
HAN mini-batch training with RandomWalkNeighborSampler.

Note: this demo uses RandomWalkNeighborSampler to sample neighbors. Since it is
hard to gather all neighbors at validation/test time, it samples twice as many
neighbors during validation/test as during training.
"""
import dgl
import numpy
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv

from dgl.sampling import RandomWalkNeighborSampler
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader

from model_hetero import SemanticAttention
from utils import EarlyStopping, set_random_seed
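# SemanticAttention, EarlyStopping, set_random_seed (and load_data used in main)
# are assumed to come from the model_hetero.py and utils.py of the existing HAN
# example in the same directory.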


class HANLayer(torch.nn.Module):
    """
    HAN layer.

    Arguments
    ---------
    num_metapath : number of metapath-based sub-graphs
    in_size : input feature dimension
    out_size : output feature dimension
    layer_num_heads : number of attention heads
    dropout : dropout probability

    Inputs
    ------
    block_list : list of DGL blocks
        One sampled block (bipartite graph) per metapath
    h_list : list of tensors
        Input features of the source nodes of each block

    Outputs
    -------
    tensor
        The output feature
    """

    def __init__(self, num_metapath, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()

        # One GAT layer for each metapath-based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_metapath):
            self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
                                           dropout, dropout, activation=F.elu,
                                           allow_zero_in_degree=True))
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        self.num_metapath = num_metapath

    def forward(self, block_list, h_list):
        semantic_embeddings = []

        for i, block in enumerate(block_list):
            semantic_embeddings.append(self.gat_layers[i](block, h_list[i]).flatten(1))
        semantic_embeddings = torch.stack(semantic_embeddings, dim=1)  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


class HAN(nn.Module):
    def __init__(self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout))
        for l in range(1, len(num_heads)):
            self.layers.append(HANLayer(num_metapath, hidden_size * num_heads[l - 1],
                                        hidden_size, num_heads[l], dropout))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        for gnn in self.layers:
            h = gnn(g, h)

        return self.predict(h)


class HANSampler(object):
    def __init__(self, g, metapath_list, num_neighbors):
        self.sampler_list = []
        for metapath in metapath_list:
            # Note: a random walk may repeat the same route (and thus the same edge);
            # duplicates are removed in the sampled graph, so a seed may end up with
            # fewer than num_random_walks (= num_neighbors) sampled edges.
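            # Per DGL's docs: with num_traversals=1 and termination_prob=0, each of
            # the num_random_walks random walks traverses the metapath exactly once,
            # and the num_neighbors most frequently reached nodes are kept as the
            # seed's neighbors.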
            self.sampler_list.append(RandomWalkNeighborSampler(G=g,
                                                               num_traversals=1,
                                                               termination_prob=0,
                                                               num_random_walks=num_neighbors,
                                                               num_neighbors=num_neighbors,
                                                               metapath=metapath))

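    # Each returned block is a bipartite graph whose destination nodes are the
    # seeds and whose source nodes are the sampled metapath-based neighbors;
    # block.srcdata[dgl.NID] keeps the original node IDs so load_subtensors()
    # can look up the corresponding input features.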
    def sample_blocks(self, seeds):
        block_list = []
        for sampler in self.sampler_list:
            frontier = sampler(seeds)
            # Make sure every seed node has exactly one self-loop
            frontier = dgl.remove_self_loop(frontier)
            frontier.add_edges(torch.tensor(seeds), torch.tensor(seeds))
            block = dgl.to_block(frontier, seeds)
            block_list.append(block)

        return seeds, block_list


def score(logits, labels):
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()

    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average='micro')
    macro_f1 = f1_score(labels, prediction, average='macro')

    return accuracy, micro_f1, macro_f1


def evaluate(model, g, metapath_list, num_neighbors, features, labels, val_nid, loss_fcn, batch_size):
    model.eval()

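    # As noted in the module docstring, sample twice as many neighbors at
    # validation/test time as during training.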
    han_valid_sampler = HANSampler(g, metapath_list, num_neighbors=num_neighbors * 2)
    dataloader = DataLoader(
        dataset=val_nid,
        batch_size=batch_size,
        collate_fn=han_valid_sampler.sample_blocks,
        shuffle=False,
        drop_last=False,
        num_workers=4)
    correct = total = 0
    prediction_list = []
    labels_list = []
    with torch.no_grad():
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args['device']) for block in blocks]
            hs = [h.to(args['device']) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fcn(logits, labels[numpy.asarray(seeds)].to(args['device']))
            # Collect per-batch predictions and labels
            _, indices = torch.max(logits, dim=1)
            prediction = indices.long().cpu().numpy()
            labels_batch = labels[numpy.asarray(seeds)].cpu().numpy()

            prediction_list.append(prediction)
            labels_list.append(labels_batch)

            correct += (prediction == labels_batch).sum()
            total += prediction.shape[0]

    total_prediction = numpy.concatenate(prediction_list)
    total_labels = numpy.concatenate(labels_list)
    micro_f1 = f1_score(total_labels, total_prediction, average='micro')
    macro_f1 = f1_score(total_labels, total_prediction, average='macro')
    accuracy = correct / total

    # Note: the returned loss is that of the last validation batch
    return loss, accuracy, micro_f1, macro_f1


def load_subtensors(blocks, features):
    h_list = []
    for block in blocks:
        input_nodes = block.srcdata[dgl.NID]
        h_list.append(features[input_nodes])
    return h_list


def main(args):
    # ACM dataset
    if args['dataset'] == 'ACMRaw':
        from utils import load_data
        g, features, labels, n_classes, train_nid, val_nid, test_nid, train_mask, \
            val_mask, test_mask = load_data('ACMRaw')
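        # Each metapath is a list of edge types in the ACM graph:
        # ['pa', 'ap'] is paper-author-paper (PAP) and ['pf', 'fp'] is
        # paper-field-paper (PFP).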
        metapath_list = [['pa', 'ap'], ['pf', 'fp']]
    else:
        raise NotImplementedError('Unsupported dataset {}'.format(args['dataset']))

    # TODO: should different metapath-based graphs use different numbers of neighbors?
    num_neighbors = args['num_neighbors']
    han_sampler = HANSampler(g, metapath_list, num_neighbors)
    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid,
        batch_size=args['batch_size'],
        collate_fn=han_sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=4)

    model = HAN(num_metapath=len(metapath_list),
                in_size=features.shape[1],
                hidden_size=args['hidden_units'],
                out_size=n_classes,
                num_heads=args['num_heads'],
                dropout=args['dropout']).to(args['device'])

    total_params = sum(p.numel() for p in model.parameters())
    print("total_params: {:d}".format(total_params))
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("total trainable params: {:d}".format(total_trainable_params))

    stopper = EarlyStopping(patience=args['patience'])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])

    for epoch in range(args['num_epochs']):
        model.train()
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args['device']) for block in blocks]
            hs = [h.to(args['device']) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fn(logits, labels[numpy.asarray(seeds)].to(args['device']))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print training metrics for every batch
            train_acc, train_micro_f1, train_macro_f1 = score(logits, labels[numpy.asarray(seeds)])
            print(
                "Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}".format(
                    epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1
                ))
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, metapath_list, num_neighbors, features,
                                                                 labels, val_nid, loss_fn, args['batch_size'])
        early_stop = stopper.step(val_loss.data.item(), val_acc, model)

        print('Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
            epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1))

        if early_stop:
            break

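    # Restore the best model saved by EarlyStopping before evaluating on the test set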
    stopper.load_checkpoint(model)
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, metapath_list, num_neighbors, features,
                                                                 labels, test_nid, loss_fn, args['batch_size'])
    print('Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
        test_loss.item(), test_acc, test_micro_f1, test_macro_f1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('mini-batch HAN')
    parser.add_argument('-s', '--seed', type=int, default=1,
                        help='Random seed')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_neighbors', type=int, default=20)
    parser.add_argument('--lr', type=float, default=0.001)
    # Number of attention heads per HAN layer (one entry per layer)
    parser.add_argument('--num_heads', type=int, nargs='+', default=[8])
    parser.add_argument('--hidden_units', type=int, default=8)
    parser.add_argument('--dropout', type=float, default=0.6)
    parser.add_argument('--weight_decay', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--dataset', type=str, default='ACMRaw')
    parser.add_argument('--device', type=str, default='cuda:0')

    args = parser.parse_args().__dict__
    # set_random_seed(args['seed'])

    main(args)