[Example][Bug Fix] Improve DiffPool (dmlc#2730)
* change DiffPoolBatchedGraphLayer

* fix bug and add benchmark

* upt

* upt

* upt

* upt

Co-authored-by: Tong He <[email protected]>
lygztq and hetong007 authored Mar 9, 2021
1 parent 3317522 commit 91cb347
Showing 5 changed files with 58 additions and 37 deletions.
36 changes: 33 additions & 3 deletions examples/pytorch/diffpool/README.md
@@ -16,14 +16,44 @@ How to run
----------

```bash
python train.py --dataset ENZYMES --pool_ratio 0.10 --num_pool 1
python train.py --dataset DD --pool_ratio 0.15 --num_pool 1
python train.py --dataset ENZYMES --pool_ratio 0.10 --num_pool 1 --epochs 1000
python train.py --dataset DD --pool_ratio 0.15 --num_pool 1 --batch-size 10
```
Performance
-----------
ENZYMES 63.33% (with early stopping)
DD 79.31% (with early stopping)


## Dependencies
## Update (2021-03-09)

**Changes:**

* Fix a bug in DiffPool: the `assign_dim` parameter was computed incorrectly (it was scaled by the batch size); see the sketch below.
* Improve the efficiency of DiffPool: make the model independent of the batch size and remove redundant computation.
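
To make the first point concrete, here is a minimal sketch of the `assign_dim` computation before and after the fix. It mirrors the `assign_dim` change in `train.py` further down in this diff; the dataset numbers are made-up illustration values.

```python
# Illustration only; max_num_node, pool_ratio and batch_size are made-up values.
max_num_node = 500   # largest graph in the dataset
pool_ratio = 0.15
batch_size = 10

# Before: assign_dim was scaled by the batch size, so the assignment matrix
# (and memory use) grew with the batch, and the model only worked for that batch size.
assign_dim_old = int(max_num_node * pool_ratio) * batch_size   # 750

# After: one assignment dimension per graph, independent of the batch size.
assign_dim_new = int(max_num_node * pool_ratio)                # 75

print(assign_dim_old, assign_dim_new)
```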


**Efficiency:**

Measured on an NVIDIA V100-SXM2 16GB GPU:

| | Train time/epoch (original) (s) | Train time/epoch (improved) (s) |
| ------------------ | ------------------------------: | ------------------------------: |
| DD (batch_size=10) | 21.302 | **17.282** |
| DD (batch_size=20) | OOM | **44.682** |
| ENZYMES | 1.749 | **1.685** |

| | Memory usage (original) (MB) | Memory usage (improved) (MB) |
| ------------------ | ---------------------------: | ---------------------------: |
| DD (batch_size=10) | 5274.620 | **2928.568** |
| DD (batch_size=20) | OOM | **10088.889** |
| ENZYMES | 25.685 | **21.909** |

**Accuracy:**

Each experiment with the improved model was run only once, so the results may be noisy.

| | Original | Improved |
| ------- | ---------: | ---------: |
| DD | **79.31%** | 78.33% |
| ENZYMES | 63.33% | **68.33%** |
27 changes: 7 additions & 20 deletions examples/pytorch/diffpool/model/dgl_layers/gnn.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.linalg import block_diag

@@ -101,27 +102,13 @@ def __init__(self, input_dim, assign_dim, output_feat_dim,
self.reg_loss.append(EntropyLoss())

def forward(self, g, h):
feat = self.feat_gc(g, h)
assign_tensor = self.pool_gc(g, h)
feat = self.feat_gc(g, h) # size = (sum_N, F_out), sum_N is num of nodes in this batch
device = feat.device
assign_tensor_masks = []
batch_size = len(g.batch_num_nodes())
for g_n_nodes in g.batch_num_nodes():
mask = torch.ones((g_n_nodes,
int(assign_tensor.size()[1] / batch_size)))
assign_tensor_masks.append(mask)
"""
The first pooling layer is computed on batched graph.
We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
"""
mask = torch.FloatTensor(
block_diag(
*
assign_tensor_masks)).to(
device=device)
assign_tensor = masked_softmax(assign_tensor, mask,
memory_efficient=False)
assign_tensor = self.pool_gc(g, h) # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
assign_tensor = F.softmax(assign_tensor, dim=1)
assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
assign_tensor = torch.block_diag(*assign_tensor) # size = (sum_N, batch_size * N_a)

h = torch.matmul(torch.t(assign_tensor), feat)
adj = g.adjacency_matrix(transpose=False, ctx=device)
adj_new = torch.sparse.mm(adj, assign_tensor)
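The new forward pass above can be reproduced with a short, standalone sketch (plain PyTorch on random data; the shapes follow the comments in the diff, and `batch_num_nodes` stands in for `g.batch_num_nodes()`):

```python
import torch
import torch.nn.functional as F

# Toy batch: 3 graphs with 4, 3 and 5 nodes (sum_N = 12), each pooled to N_a = 2 clusters.
batch_num_nodes = [4, 3, 5]
sum_N, N_a, F_out = sum(batch_num_nodes), 2, 8

feat = torch.randn(sum_N, F_out)          # output of feat_gc on the batched graph
assign_logits = torch.randn(sum_N, N_a)   # output of pool_gc on the batched graph

# Row-wise softmax: each node gets a distribution over its own graph's clusters.
assign_tensor = F.softmax(assign_logits, dim=1)

# Split per graph and reassemble as one block-diagonal matrix of
# size (sum_N, num_graphs * N_a), so clusters of different graphs never mix.
assign_tensor = torch.block_diag(*torch.split(assign_tensor, batch_num_nodes))

# Pooled node features and pooled adjacency (dense here for simplicity;
# the layer itself uses the sparse adjacency of the batched graph).
adj = (torch.rand(sum_N, sum_N) < 0.2).float()
h_pooled = torch.matmul(assign_tensor.t(), feat)       # (num_graphs * N_a, F_out)
adj_pooled = assign_tensor.t() @ adj @ assign_tensor    # block diagonal again

print(h_pooled.shape, adj_pooled.shape)   # torch.Size([6, 8]) torch.Size([6, 6])
```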
2 changes: 1 addition & 1 deletion examples/pytorch/diffpool/model/encoder.py
@@ -180,7 +180,7 @@ def forward(self, g):
out_all.append(readout)

adj, h = self.first_diffpool_layer(g, g_embedding)
node_per_pool_graph = int(adj.size()[0] / self.batch_size)
node_per_pool_graph = int(adj.size()[0] / len(g.batch_num_nodes()))

h, adj = batch2tensor(adj, h, node_per_pool_graph)
h = self.gcn_forward_tensorized(
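A small sketch of why this one-line change matters, assuming (as the surrounding code suggests) that `batch2tensor` splits the block-diagonal pooled adjacency into equal per-graph blocks; the concrete numbers are made up:

```python
import torch

# Each graph is pooled to N_a = 75 clusters. A full batch has 10 graphs,
# but the last batch may be smaller, e.g. 7 graphs -> a 525 x 525 block-diagonal adj.
N_a, num_graphs_in_batch = 75, 7
adj = torch.zeros(num_graphs_in_batch * N_a, num_graphs_in_batch * N_a)

# Old: divide by the configured batch size -> int(525 / 10) = 52, a wrong block size.
# New: divide by the actual number of graphs in this batch (len(g.batch_num_nodes())).
node_per_pool_graph = adj.size(0) // num_graphs_in_batch
print(node_per_pool_graph)   # 75
```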
7 changes: 4 additions & 3 deletions
@@ -17,14 +17,15 @@ def __init__(self, infeat, outfeat, use_bn=True,
gain=nn.init.calculate_gain('relu'))

def forward(self, x, adj):
num_node_per_graph = adj.size(1)
if self.use_bn and not hasattr(self, 'bn'):
self.bn = nn.BatchNorm1d(adj.size(1)).to(adj.device)
self.bn = nn.BatchNorm1d(num_node_per_graph).to(adj.device)

if self.add_self:
adj = adj + torch.eye(adj.size(0)).to(adj.device)
adj = adj + torch.eye(num_node_per_graph).to(adj.device)

if self.mean:
adj = adj / adj.sum(1, keepdim=True)
adj = adj / adj.sum(-1, keepdim=True)

h_k_N = torch.matmul(adj, x)
h_k = self.W(h_k_N)
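For context, a self-contained sketch of what the tensorized layer above computes after this change (a simplified stand-in rather than the example's actual class; `W` is assumed to be the layer's linear projection, and `adj` is a dense batched adjacency of shape `(batch, N, N)`):

```python
import torch
import torch.nn as nn

def batched_sage_mean(x, adj, W, add_self=True, mean=True):
    """x: (B, N, F_in) node features; adj: (B, N, N) dense adjacency per pooled graph."""
    num_node_per_graph = adj.size(1)
    if add_self:
        # eye(N) broadcasts over the batch dimension, adding a self-loop to every node.
        adj = adj + torch.eye(num_node_per_graph, device=adj.device)
    if mean:
        # Normalize each row by its degree; summing over the last dim is the fix,
        # since sum(1) on a (B, N, N) tensor would normalize by column sums instead.
        adj = adj / adj.sum(-1, keepdim=True)
    h_k_N = torch.matmul(adj, x)   # aggregate neighbour features
    return W(h_k_N)                # project to the output dimension

# Toy usage: 4 pooled graphs, 75 nodes each, 16-dim features.
x = torch.randn(4, 75, 16)
adj = (torch.rand(4, 75, 75) < 0.1).float()
W = nn.Linear(16, 32)
print(batched_sage_mean(x, adj, W).shape)   # torch.Size([4, 75, 32])
```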
23 changes: 13 additions & 10 deletions examples/pytorch/diffpool/train.py
@@ -17,6 +17,7 @@
from model.encoder import DiffPool
from data_utils import pre_process

global_train_time_per_epoch = []

def arg_parse():
'''
@@ -68,7 +69,7 @@ def arg_parse():
'--save_dir',
dest='save_dir',
help='model saving directory: SAVE_DICT/DATASET')
parser.add_argument('--load_epoch', dest='load_epoch', help='load trained model params from\
parser.add_argument('--load_epoch', dest='load_epoch', type=int, help='load trained model params from\
SAVE_DICT/DATASET/model-LOAD_EPOCH')
parser.add_argument('--data_mode', dest='data_mode', help='data\
preprocessing mode: default, id, degree, or one-hot\
@@ -113,7 +114,6 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
return dgl.dataloading.GraphDataLoader(dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
drop_last=True,
num_workers=prog_args.n_worker)


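Because the model no longer bakes the batch size into `assign_dim`, an incomplete final batch is handled correctly, which is why `drop_last=True` could be dropped here. A minimal usage sketch (`dgl.data.TUDataset` is used only to make the snippet self-contained; the example script builds its dataset through its own `pre_process` pipeline instead):

```python
import dgl

dataset = dgl.data.TUDataset('DD')   # any graph-classification dataset works here
dataloader = dgl.dataloading.GraphDataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    num_workers=0)   # drop_last defaults to False, so the smaller final batch is kept

for batched_graph, labels in dataloader:
    # The model reads the true number of graphs from the batched graph itself,
    # so this may be less than 10 for the final batch.
    print(len(batched_graph.batch_num_nodes()))
```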
@@ -148,8 +148,7 @@ def graph_classify_task(prog_args):

# calculate assignment dimension: pool_ratio * largest graph's maximum
# number of nodes in the dataset
assign_dim = int(max_num_node * prog_args.pool_ratio) * \
prog_args.batch_size
assign_dim = int(max_num_node * prog_args.pool_ratio)
print("++++++++++MODEL STATISTICS++++++++")
print("model hidden dim is", hidden_dim)
print("model embedding dim for graph instance embedding", embedding_dim)
@@ -187,7 +186,7 @@ def graph_classify_task(prog_args):
prog_args,
val_dataset=val_dataloader)
result = evaluate(test_dataloader, model, prog_args, logger)
print("test accuracy {}%".format(result * 100))
print("test accuracy {:.2f}%".format(result * 100))


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
@@ -209,7 +208,7 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
model.train()
accum_correct = 0
total = 0
print("EPOCH ###### {} ######".format(epoch))
print("\nEPOCH ###### {} ######".format(epoch))
computation_time = 0.0
for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
@@ -234,21 +233,22 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
optimizer.step()

train_accu = accum_correct / total
print("train accuracy for this epoch {} is {}%".format(epoch,
print("train accuracy for this epoch {} is {:.2f}%".format(epoch,
train_accu * 100))
elapsed_time = time.time() - begin_time
print("loss {} with epoch time {} s & computation time {} s ".format(
print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
loss.item(), elapsed_time, computation_time))
global_train_time_per_epoch.append(elapsed_time)
if val_dataset is not None:
result = evaluate(val_dataset, model, prog_args)
print("validation accuracy {}%".format(result * 100))
print("validation accuracy {:.2f}%".format(result * 100))
if result >= early_stopping_logger['val_acc'] and result <=\
train_accu:
early_stopping_logger.update(best_epoch=epoch, val_acc=result)
if prog_args.save_dir is not None:
torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
+ "/model.iter-" + str(early_stopping_logger['best_epoch']))
print("best epoch is EPOCH {}, val_acc is {}%".format(early_stopping_logger['best_epoch'],
print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'],
early_stopping_logger['val_acc'] * 100))
torch.cuda.empty_cache()
return early_stopping_logger
@@ -287,6 +287,9 @@ def main():
print(prog_args)
graph_classify_task(prog_args)

print("Train time per epoch: {:.4f}".format( sum(global_train_time_per_epoch) / len(global_train_time_per_epoch) ))
print("Max memory usage: {:.4f}".format(torch.cuda.max_memory_allocated(0) / (1024 * 1024)))


if __name__ == "__main__":
main()
