Skip to content

Commit

Permalink
[Example] DCRNN and GaAN (dmlc#2858)
Browse files Browse the repository at this point in the history
* Ready for PR

* refractor code

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Tianjun Xiao <[email protected]>
  • Loading branch information
3 people authored Apr 23, 2021
1 parent 3075b27 commit 3c38798
Show file tree
Hide file tree
Showing 8 changed files with 866 additions and 1 deletion.
10 changes: 9 additions & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ The folder contains example implementations of selected research papers related
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |
| [Interaction Networks for Learning about Objects, Relations and Physics](#graphsim) | | |:heavy_check_mark: | | |
| [Representation Learning on Graphs with Jumping Knowledge Networks](#jknet) | :heavy_check_mark: | | | | |

| [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forcasting](#dcrnn) | | | :heavy_check_mark: | | |
| [GaAN: Gated Attention Networks for Learning on large and Spatiotemporal Graphs](#gaan) | | | :heavy_check_mark: | | |
## 2021

- <a name="bgnn"></a> Ivanov et al. Boost then Convolve: Gradient Boosting Meets Graph Neural Networks. [Paper link](https://openreview.net/forum?id=ebS5NUfoMKL).
Expand Down Expand Up @@ -268,6 +269,9 @@ The folder contains example implementations of selected research papers related
- Example code: [pytorch](../examples/pytorch/jknet)
- Tags: message passing, neighborhood

- <a name="gaan"></a> Zhang et al. GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs. [Paper link](https://arxiv.org/abs/1803.07294).
- Example code: [pytorch](../examples/pytorch/dtgrnn)
- Tags: Static discrete temporal graph, traffic forcasting

## 2017

Expand Down Expand Up @@ -323,6 +327,10 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/alchemy)
- Tags: molecules, quantum chemistry

- <a name="dcrnn"></a> Li et al. Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forcasting. [Paper link](https://arxiv.org/abs/1707.01926).
- Example code: [Pytorch](../examples/pytorch/dtgrnn)
- Tags: Static discrete temporal graph, traffic forcasting.

## 2016

- <a name="ggnn"></a> Li et al. Gated Graph Sequence Neural Networks. [Paper link](https://arxiv.org/abs/1511.05493).
Expand Down
71 changes: 71 additions & 0 deletions examples/pytorch/dtgrnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Discrete Temporal Dynamic Graph with recurrent structure
## DGL Implementation of DCRNN and GaAN paper.

This DGL example implements the GNN model proposed in the paper [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) and [GaAN:Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](https://arxiv.org/pdf/1803.07294).

Model implementor
----------------------
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his Internship work at the AWS Shanghai AI Lab.

The graph dataset used in this example
---------------------------------------
METR-LA dataset. Dataset summary:
- NumNodes: 207
- NumEdges: 1722
- NumFeats: 2
- TrainingSamples: 70%
- ValidationSamples: 20%
- TestSamples: 10%

PEMS-BAY dataset. Dataset Summary:

- NumNodes: 325
- NumEdges: 2694
- NumFeats: 2
- TrainingSamples: 70%
- ValidationSamples: 20%
- TestSamples: 10%

How to run example files
--------------------------------
In the dtdg folder, run

**Please use `train.py`**

Train the DCRNN model on METR-LA Dataset

```python
python train.py --dataset LA --model dcrnn
```

If want to use a GPU, run

```python
python train.py --gpu 0 --dataset LA --model dcrnn
```

if you want to use PEMS-BAY dataset

```python
python train.py --gpu 0 --dataset BAY --model dcrnn
```

Train GaAN model

```python
python train.py --gpu 0 --model gaan --dataset <LA/BAY>
```


Performance on METR-LA
-------------------------
| Models/Datasets | Test MAE |
| :-------------- | --------:|
| DCRNN in DGL | 2.91 |
| DCRNN paper | 3.17 |
| GaAN in DGL | 3.20 |
| GaAN paper | 3.16 |


Notice that Any Graph Convolution module can be plugged into the recurrent discrete temporal dynamic graph template to test performance; simply replace DiffConv or GaAN.

92 changes: 92 additions & 0 deletions examples/pytorch/dtgrnn/dataloading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import ssl
from six.moves import urllib
import torch
import numpy as np
import dgl
from torch.utils.data import Dataset, DataLoader


def download_file(dataset):
print("Start Downloading data: {}".format(dataset))
url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
dataset)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
with open("./data/{}".format(dataset), "wb") as handle:
handle.write(data.read())


class SnapShotDataset(Dataset):
def __init__(self, path, npz_file):
if not os.path.exists(path+'/'+npz_file):
if not os.path.exists(path):
os.mkdir(path)
download_file(npz_file)
zipfile = np.load(path+'/'+npz_file)
self.x = zipfile['x']
self.y = zipfile['y']

def __len__(self):
return len(self.x)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

return self.x[idx, ...], self.y[idx, ...]


def METR_LAGraphDataset():
if not os.path.exists('data/graph_la.bin'):
if not os.path.exists('data'):
os.mkdir('data')
download_file('graph_la.bin')
g, _ = dgl.load_graphs('data/graph_la.bin')
return g[0]


class METR_LATrainDataset(SnapShotDataset):
def __init__(self):
super(METR_LATrainDataset, self).__init__('data', 'metr_la_train.npz')
self.mean = self.x[..., 0].mean()
self.std = self.x[..., 0].std()


class METR_LATestDataset(SnapShotDataset):
def __init__(self):
super(METR_LATestDataset, self).__init__('data', 'metr_la_test.npz')


class METR_LAValidDataset(SnapShotDataset):
def __init__(self):
super(METR_LAValidDataset, self).__init__('data', 'metr_la_valid.npz')


def PEMS_BAYGraphDataset():
if not os.path.exists('data/graph_bay.bin'):
if not os.path.exists('data'):
os.mkdir('data')
download_file('graph_bay.bin')
g, _ = dgl.load_graphs('data/graph_bay.bin')
return g[0]


class PEMS_BAYTrainDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYTrainDataset, self).__init__(
'data', 'pems_bay_train.npz')
self.mean = self.x[..., 0].mean()
self.std = self.x[..., 0].std()


class PEMS_BAYTestDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYTestDataset, self).__init__('data', 'pems_bay_test.npz')


class PEMS_BAYValidDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYValidDataset, self).__init__(
'data', 'pems_bay_valid.npz')
109 changes: 109 additions & 0 deletions examples/pytorch/dtgrnn/dcrnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
from dgl.base import DGLError
import dgl.function as fn


class DiffConv(nn.Module):
'''DiffConv is the implementation of diffusion convolution from paper DCRNN
It will compute multiple diffusion matrix and perform multiple diffusion conv on it,
this layer can be used for traffic prediction, pedamic model.
Parameter
==========
in_feats : int
number of input feature
out_feats : int
number of output feature
k : int
number of diffusion steps
dir : str [both/in/out]
direction of diffusion convolution
From paper default both direction
'''

def __init__(self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir='both'):
super(DiffConv, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.k = k
self.dir = dir
self.num_graphs = self.k-1 if self.dir == 'both' else 2*self.k-2
self.project_fcs = nn.ModuleList()
for i in range(self.num_graphs):
self.project_fcs.append(
nn.Linear(self.in_feats, self.out_feats, bias=False))
self.merger = nn.Parameter(torch.randn(self.num_graphs+1))
self.in_graph_list = in_graph_list
self.out_graph_list = out_graph_list

@staticmethod
def attach_graph(g, k):
device = g.device
out_graph_list = []
in_graph_list = []
wadj, ind, outd = DiffConv.get_weight_matrix(g)
adj = sparse.coo_matrix(wadj/outd.cpu().numpy())
outg = dgl.from_scipy(adj, eweight_name='weight').to(device)
outg.edata['weight'] = outg.edata['weight'].float().to(device)
out_graph_list.append(outg)
for i in range(k-1):
out_graph_list.append(DiffConv.diffuse(
out_graph_list[-1], wadj, outd))
adj = sparse.coo_matrix(wadj.T/ind.cpu().numpy())
ing = dgl.from_scipy(adj, eweight_name='weight').to(device)
ing.edata['weight'] = ing.edata['weight'].float().to(device)
in_graph_list.append(ing)
for i in range(k-1):
in_graph_list.append(DiffConv.diffuse(
in_graph_list[-1], wadj.T, ind))
return out_graph_list, in_graph_list

@staticmethod
def get_weight_matrix(g):
adj = g.adj(scipy_fmt='coo')
ind = g.in_degrees()
outd = g.out_degrees()
weight = g.edata['weight']
adj.data = weight.cpu().numpy()
return adj, ind, outd

@staticmethod
def diffuse(progress_g, weighted_adj, degree):
device = progress_g.device
progress_adj = progress_g.adj(scipy_fmt='coo')
progress_adj.data = progress_g.edata['weight'].cpu().numpy()
ret_adj = sparse.coo_matrix(progress_adj@(
weighted_adj/degree.cpu().numpy()))
ret_graph = dgl.from_scipy(ret_adj, eweight_name='weight').to(device)
ret_graph.edata['weight'] = ret_graph.edata['weight'].float().to(
device)
return ret_graph

def forward(self, g, x):
feat_list = []
if self.dir == 'both':
graph_list = self.in_graph_list+self.out_graph_list
elif self.dir == 'in':
graph_list = self.in_graph_list
elif self.dir == 'out':
graph_list = self.out_graph_list

for i in range(self.num_graphs):
g = graph_list[i]
with g.local_scope():
g.ndata['n'] = self.project_fcs[i](x)
g.update_all(fn.u_mul_e('n', 'weight', 'e'),
fn.sum('e', 'feat'))
feat_list.append(g.ndata['feat'])
# Each feat has shape [N,q_feats]
feat_list.append(self.project_fcs[-1](x))
feat_list = torch.cat(feat_list).view(
len(feat_list), -1, self.out_feats)
ret = (self.merger*feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
return ret
Loading

0 comments on commit 3c38798

Please sign in to comment.