add HAN sampling demo (dmlc#3005)
Co-authored-by: maqy1995 <maqingyun0718>
Co-authored-by: Quan (Andy) Gan <[email protected]>
maqy1995 and BarclayII authored Jun 23, 2021
1 parent e56bbaf commit 31f4483
Showing 1 changed file with 271 additions and 0 deletions.
271 changes: 271 additions & 0 deletions examples/pytorch/han/train_sampling.py
@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
"""
HAN mini-batch training with RandomWalkNeighborSampler.
Note: this demo uses RandomWalkNeighborSampler to sample neighbors. It is hard to gather all
neighbors during validation/testing, so we sample twice as many neighbors for val/test as for
training.
"""
import dgl
import numpy
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv

from dgl.sampling import RandomWalkNeighborSampler
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader

from model_hetero import SemanticAttention
from utils import EarlyStopping, set_random_seed


class HANLayer(torch.nn.Module):
"""
HAN layer.
Arguments
---------
num_metapath : number of metapath-based sub-graphs
in_size : input feature dimension
out_size : output feature dimension
layer_num_heads : number of attention heads
dropout : Dropout probability
Inputs
------
block_list : list of DGLBlock
Sampled blocks, one per metapath
h_list : list of tensors
Input features for each block
Outputs
-------
tensor
The output feature
"""

def __init__(self, num_metapath, in_size, out_size, layer_num_heads, dropout):
super(HANLayer, self).__init__()

# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(num_metapath):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu,
allow_zero_in_degree=True))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.num_metapath = num_metapath

def forward(self, block_list, h_list):
semantic_embeddings = []

for i, block in enumerate(block_list):
semantic_embeddings.append(self.gat_layers[i](block, h_list[i]).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)

return self.semantic_attention(semantic_embeddings) # (N, D * K)


class HAN(nn.Module):
def __init__(self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout):
super(HAN, self).__init__()

self.layers = nn.ModuleList()
self.layers.append(HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout))
for l in range(1, len(num_heads)):
self.layers.append(HANLayer(num_metapath, hidden_size * num_heads[l - 1],
hidden_size, num_heads[l], dropout))
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

def forward(self, g, h):
for gnn in self.layers:
h = gnn(g, h)

return self.predict(h)


class HANSampler(object):
def __init__(self, g, metapath_list, num_neighbors):
self.sampler_list = []
for metapath in metapath_list:
# note: a random walk may repeat the same route (i.e. the same edge); duplicates are removed
# in the sampled graph, so the sampled graph may have fewer edges than num_random_walks (num_neighbors).
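# Each RandomWalkNeighborSampler builds a per-metapath neighbor graph: starting from the
# seed nodes it performs metapath-guided random walks and connects each seed to the nodes
# the walks visit, keeping at most num_neighbors neighbors per seed.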
self.sampler_list.append(RandomWalkNeighborSampler(G=g,
num_traversals=1,
termination_prob=0,
num_random_walks=num_neighbors,
num_neighbors=num_neighbors,
metapath=metapath))

def sample_blocks(self, seeds):
block_list = []
for sampler in self.sampler_list:
frontier = sampler(seeds)
# make sure each seed node has exactly one self-loop
frontier = dgl.remove_self_loop(frontier)
frontier.add_edges(torch.tensor(seeds), torch.tensor(seeds))
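# convert the frontier into a bipartite block whose destination nodes are the seeds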
block = dgl.to_block(frontier, seeds)
block_list.append(block)

return seeds, block_list


def score(logits, labels):
_, indices = torch.max(logits, dim=1)
prediction = indices.long().cpu().numpy()
labels = labels.cpu().numpy()

accuracy = (prediction == labels).sum() / len(prediction)
micro_f1 = f1_score(labels, prediction, average='micro')
macro_f1 = f1_score(labels, prediction, average='macro')

return accuracy, micro_f1, macro_f1


def evaluate(model, g, metapath_list, num_neighbors, features, labels, val_nid, loss_fcn, batch_size):
model.eval()

han_valid_sampler = HANSampler(g, metapath_list, num_neighbors=num_neighbors * 2)
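# note: twice as many neighbors are sampled for validation/testing as for training (see the module docstring)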
dataloader = DataLoader(
dataset=val_nid,
batch_size=batch_size,
collate_fn=han_valid_sampler.sample_blocks,
shuffle=False,
drop_last=False,
num_workers=4)
correct = total = 0
prediction_list = []
labels_list = []
with torch.no_grad():
for step, (seeds, blocks) in enumerate(dataloader):
h_list = load_subtensors(blocks, features)
blocks = [block.to(args['device']) for block in blocks]
hs = [h.to(args['device']) for h in h_list]

logits = model(blocks, hs)
loss = loss_fcn(logits, labels[numpy.asarray(seeds)].to(args['device']))
# get the predicted label of each node
_, indices = torch.max(logits, dim=1)
prediction = indices.long().cpu().numpy()
labels_batch = labels[numpy.asarray(seeds)].cpu().numpy()

prediction_list.append(prediction)
labels_list.append(labels_batch)

correct += (prediction == labels_batch).sum()
total += prediction.shape[0]

total_prediction = numpy.concatenate(prediction_list)
total_labels = numpy.concatenate(labels_list)
micro_f1 = f1_score(total_labels, total_prediction, average='micro')
macro_f1 = f1_score(total_labels, total_prediction, average='macro')
accuracy = correct / total

return loss, accuracy, micro_f1, macro_f1


def load_subtensors(blocks, features):
h_list = []
for block in blocks:
input_nodes = block.srcdata[dgl.NID]
h_list.append(features[input_nodes])
return h_list


def main(args):
# acm data
if args['dataset'] == 'ACMRaw':
from utils import load_data
g, features, labels, n_classes, train_nid, val_nid, test_nid, train_mask, \
val_mask, test_mask = load_data('ACMRaw')
metapath_list = [['pa', 'ap'], ['pf', 'fp']]
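# metapaths on ACM: paper-author-paper ('pa' -> 'ap') and paper-field-paper ('pf' -> 'fp')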
else:
raise NotImplementedError('Unsupported dataset {}'.format(args['dataset']))

# TODO: should different numbers of neighbors be used for different metapath-based graphs?
num_neighbors = args['num_neighbors']
han_sampler = HANSampler(g, metapath_list, num_neighbors)
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataset=train_nid,
batch_size=args['batch_size'],
collate_fn=han_sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=4)

model = HAN(num_metapath=len(metapath_list),
in_size=features.shape[1],
hidden_size=args['hidden_units'],
out_size=n_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])

total_params = sum(p.numel() for p in model.parameters())
print("total_params: {:d}".format(total_params))
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("total trainable params: {:d}".format(total_trainable_params))

stopper = EarlyStopping(patience=args['patience'])
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
weight_decay=args['weight_decay'])

for epoch in range(args['num_epochs']):
model.train()
for step, (seeds, blocks) in enumerate(dataloader):
h_list = load_subtensors(blocks, features)
blocks = [block.to(args['device']) for block in blocks]
hs = [h.to(args['device']) for h in h_list]

logits = model(blocks, hs)
loss = loss_fn(logits, labels[numpy.asarray(seeds)].to(args['device']))

optimizer.zero_grad()
loss.backward()
optimizer.step()

# print info in each batch
train_acc, train_micro_f1, train_macro_f1 = score(logits, labels[numpy.asarray(seeds)])
print(
"Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}".format(
epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1
))
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, metapath_list, num_neighbors, features,
labels, val_nid, loss_fn, args['batch_size'])
early_stop = stopper.step(val_loss.data.item(), val_acc, model)

print('Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1))

if early_stop:
break

stopper.load_checkpoint(model)
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, metapath_list, num_neighbors, features,
labels, test_nid, loss_fn, args['batch_size'])
print('Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
test_loss.item(), test_acc, test_micro_f1, test_macro_f1))


if __name__ == '__main__':
parser = argparse.ArgumentParser('mini-batch HAN')
parser.add_argument('-s', '--seed', type=int, default=1,
help='Random seed')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_neighbors', type=int, default=20)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--num_heads', type=int, nargs='+', default=[8])
parser.add_argument('--hidden_units', type=int, default=8)
parser.add_argument('--dropout', type=float, default=0.6)
parser.add_argument('--weight_decay', type=float, default=0.001)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--patience', type=int, default=10)
parser.add_argument('--dataset', type=str, default='ACMRaw')
parser.add_argument('--device', type=str, default='cuda:0')

args = parser.parse_args().__dict__
# set_random_seed(args['seed'])

main(args)
