-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Example] DCRNN and GaAN (dmlc#2858)
* Ready for PR * refractor code Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Tianjun Xiao <[email protected]>
- Loading branch information
1 parent
3075b27
commit 3c38798
Showing
8 changed files
with
866 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Discrete Temporal Dynamic Graph with recurrent structure | ||
## DGL Implementation of DCRNN and GaAN paper. | ||
|
||
This DGL example implements the GNN model proposed in the paper [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) and [GaAN:Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](https://arxiv.org/pdf/1803.07294). | ||
|
||
Model implementor | ||
---------------------- | ||
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his internship at the AWS Shanghai AI Lab. | ||
|
||
The graph dataset used in this example | ||
--------------------------------------- | ||
METR-LA dataset. Dataset summary: | ||
- NumNodes: 207 | ||
- NumEdges: 1722 | ||
- NumFeats: 2 | ||
- TrainingSamples: 70% | ||
- ValidationSamples: 20% | ||
- TestSamples: 10% | ||
|
||
PEMS-BAY dataset. Dataset Summary: | ||
|
||
- NumNodes: 325 | ||
- NumEdges: 2694 | ||
- NumFeats: 2 | ||
- TrainingSamples: 70% | ||
- ValidationSamples: 20% | ||
- TestSamples: 10% | ||
|
||
How to run example files | ||
-------------------------------- | ||
In the dtdg folder, run | ||
|
||
**Please use `train.py`** | ||
|
||
Train the DCRNN model on METR-LA Dataset | ||
|
||
```python | ||
python train.py --dataset LA --model dcrnn | ||
``` | ||
|
||
If you want to use a GPU, run | ||
|
||
```python | ||
python train.py --gpu 0 --dataset LA --model dcrnn | ||
``` | ||
|
||
If you want to use the PEMS-BAY dataset, run | ||
|
||
```python | ||
python train.py --gpu 0 --dataset BAY --model dcrnn | ||
``` | ||
|
||
Train GaAN model | ||
|
||
```python | ||
python train.py --gpu 0 --model gaan --dataset <LA/BAY> | ||
``` | ||
|
||
|
||
Performance on METR-LA | ||
------------------------- | ||
| Models/Datasets | Test MAE | | ||
| :-------------- | --------:| | ||
| DCRNN in DGL | 2.91 | | ||
| DCRNN paper | 3.17 | | ||
| GaAN in DGL | 3.20 | | ||
| GaAN paper | 3.16 | | ||
|
||
|
||
Note that any graph convolution module can be plugged into the recurrent discrete temporal dynamic graph template to test its performance; simply replace DiffConv or GaAN. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import os | ||
import ssl | ||
from six.moves import urllib | ||
import torch | ||
import numpy as np | ||
import dgl | ||
from torch.utils.data import Dataset, DataLoader | ||
|
||
|
||
def download_file(dataset):
    """Download ``dataset`` from the DGL S3 bucket into ``./data/``.

    The payload is streamed to ``./data/<dataset>.part`` and renamed to its
    final name only after the transfer finished, so an interrupted download
    never leaves a truncated file under the final name (callers in this
    module decide whether to download by testing that the file exists).

    Parameters
    ----------
    dataset : str
        File name below ``dgl-data/dataset/`` in the bucket; the same name is
        used for the local copy inside ``./data``.
    """
    import shutil  # local import: only needed while downloading

    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset)
    print("Start Downloading File....")
    # NOTE(review): certificate verification is switched off here, so the
    # download is open to man-in-the-middle tampering. Kept as in the
    # original to avoid breaking hosts without a CA bundle; prefer
    # ssl.create_default_context() once that is confirmed to work.
    context = ssl._create_unverified_context()
    os.makedirs("./data", exist_ok=True)
    target = "./data/{}".format(dataset)
    partial = target + ".part"
    # Both the HTTP response and the output file are closed on exit, and the
    # body is copied in chunks instead of being read into memory at once.
    with urllib.request.urlopen(url, context=context) as response, \
            open(partial, "wb") as handle:
        shutil.copyfileobj(response, handle)
    os.replace(partial, target)
|
||
|
||
class SnapShotDataset(Dataset):
    """Dataset backed by the ``x`` and ``y`` arrays of an ``.npz`` archive.

    Parameters
    ----------
    path : str
        Directory expected to contain ``npz_file``. It is created when
        missing.
    npz_file : str
        Name of the archive. When ``path/npz_file`` does not exist it is
        fetched with ``download_file``.

    Notes
    -----
    ``download_file`` always writes into ``./data``, so a ``path`` other
    than ``'data'`` only works when the archive is already present there.
    """

    def __init__(self, path, npz_file):
        npz_path = os.path.join(path, npz_file)
        if not os.path.exists(npz_path):
            # makedirs also copes with nested paths and an existing directory.
            os.makedirs(path, exist_ok=True)
            download_file(npz_file)
        # Close the archive once both arrays are materialised; the original
        # kept the underlying file handle open for the object's lifetime.
        with np.load(npz_path) as archive:
            self.x = archive['x']
            self.y = archive['y']

    def __len__(self):
        """Return the number of samples (length of ``x`` along axis 0)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair(s) selected by ``idx`` along axis 0.

        ``idx`` may be anything NumPy accepts as an index, or a torch tensor,
        which is converted to plain Python integers first.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.x[idx, ...], self.y[idx, ...]
|
||
|
||
def METR_LAGraphDataset():
    """Return the first graph stored in ``data/graph_la.bin``.

    The file (and the ``data`` directory) is created via ``download_file``
    when it is not present yet.
    """
    graph_file = 'data/graph_la.bin'
    if not os.path.exists(graph_file):
        if not os.path.exists('data'):
            os.mkdir('data')
        download_file('graph_la.bin')
    graphs, _ = dgl.load_graphs(graph_file)
    return graphs[0]
|
||
|
||
class METR_LATrainDataset(SnapShotDataset):
    """Training split read from ``data/metr_la_train.npz``.

    Besides the samples, the dataset exposes ``mean`` and ``std``: the mean
    and standard deviation of feature channel 0 (last axis) of ``x``.
    """

    def __init__(self):
        super().__init__('data', 'metr_la_train.npz')
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()
|
||
|
||
class METR_LATestDataset(SnapShotDataset):
    """Test split read from ``data/metr_la_test.npz``."""

    def __init__(self):
        super().__init__('data', 'metr_la_test.npz')
|
||
|
||
class METR_LAValidDataset(SnapShotDataset):
    """Validation split read from ``data/metr_la_valid.npz``."""

    def __init__(self):
        super().__init__('data', 'metr_la_valid.npz')
|
||
|
||
def PEMS_BAYGraphDataset():
    """Return the first graph stored in ``data/graph_bay.bin``.

    The file (and the ``data`` directory) is created via ``download_file``
    when it is not present yet.
    """
    graph_file = 'data/graph_bay.bin'
    if not os.path.exists(graph_file):
        if not os.path.exists('data'):
            os.mkdir('data')
        download_file('graph_bay.bin')
    graphs, _ = dgl.load_graphs(graph_file)
    return graphs[0]
|
||
|
||
class PEMS_BAYTrainDataset(SnapShotDataset):
    """Training split read from ``data/pems_bay_train.npz``.

    Besides the samples, the dataset exposes ``mean`` and ``std``: the mean
    and standard deviation of feature channel 0 (last axis) of ``x``.
    """

    def __init__(self):
        super().__init__('data', 'pems_bay_train.npz')
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()
|
||
|
||
class PEMS_BAYTestDataset(SnapShotDataset):
    """Test split read from ``data/pems_bay_test.npz``."""

    def __init__(self):
        super().__init__('data', 'pems_bay_test.npz')
|
||
|
||
class PEMS_BAYValidDataset(SnapShotDataset):
    """Validation split read from ``data/pems_bay_valid.npz``."""

    def __init__(self):
        super().__init__('data', 'pems_bay_valid.npz')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import numpy as np | ||
import scipy.sparse as sparse | ||
import torch | ||
import torch.nn as nn | ||
import dgl | ||
from dgl.base import DGLError | ||
import dgl.function as fn | ||
|
||
|
||
class DiffConv(nn.Module):
    '''Diffusion convolution layer from the DCRNN paper.

    The layer runs weighted message passing over pre-computed diffusion
    graphs (see :meth:`attach_graph`), projects the input with one bias-free
    linear layer per graph and merges the per-graph results with a learnable
    weight vector. It can be used for traffic prediction or pandemic models.

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    k : int
        Number of diffusion steps.
    in_graph_list : list of DGLGraph
        Diffusion graphs for the in direction; every graph needs an edge
        feature named ``'weight'``.
    out_graph_list : list of DGLGraph
        Diffusion graphs for the out direction; every graph needs an edge
        feature named ``'weight'``.
    dir : str [both/in/out]
        Direction of diffusion convolution. The paper default is both
        directions.

    Raises
    ------
    DGLError
        If ``dir`` is not one of ``'both'``, ``'in'`` or ``'out'``.
    '''

    def __init__(self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir='both'):
        super(DiffConv, self).__init__()
        # Reject unknown directions here; previously they only surfaced as an
        # UnboundLocalError inside forward().
        if dir not in ('both', 'in', 'out'):
            raise DGLError(
                "dir must be one of 'both', 'in' or 'out', got {!r}".format(dir))
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.dir = dir
        # NOTE(review): 'both' uses k-1 graphs while a single direction uses
        # 2k-2. forward() reads graph_list[0:num_graphs], so with lists of k
        # graphs (what attach_graph builds) 'in'/'out' runs past the end of
        # the list once k > 2, and 'both' only ever touches in_graph_list.
        # The two branches look swapped, but changing them alters the
        # parameter count of already trained models, so the formula is kept
        # unchanged -- confirm the intent before touching it.
        self.num_graphs = self.k-1 if self.dir == 'both' else 2*self.k-2
        # One bias-free projection per diffusion graph.
        self.project_fcs = nn.ModuleList()
        for _ in range(self.num_graphs):
            self.project_fcs.append(
                nn.Linear(self.in_feats, self.out_feats, bias=False))
        # One merge weight per graph plus one for the extra projection term
        # appended in forward().
        self.merger = nn.Parameter(torch.randn(self.num_graphs+1))
        self.in_graph_list = in_graph_list
        self.out_graph_list = out_graph_list

    @staticmethod
    def attach_graph(g, k):
        '''Build the diffusion graphs for both directions.

        Parameters
        ----------
        g : DGLGraph
            Graph with an edge feature ``'weight'``.
        k : int
            Number of graphs to build per direction.

        Returns
        -------
        (list of DGLGraph, list of DGLGraph)
            ``(out_graph_list, in_graph_list)`` with ``k`` graphs each, on
            the device of ``g``. Note the order: the constructor takes the
            in list first and the out list second.
        '''
        device = g.device
        out_graph_list = []
        in_graph_list = []
        wadj, ind, outd = DiffConv.get_weight_matrix(g)
        # First out-direction graph: weighted adjacency divided by the
        # out-degree vector.
        # NOTE(review): a node with degree 0 gives a division by zero here --
        # presumably the datasets have none; confirm.
        adj = sparse.coo_matrix(wadj/outd.cpu().numpy())
        outg = dgl.from_scipy(adj, eweight_name='weight').to(device)
        outg.edata['weight'] = outg.edata['weight'].float().to(device)
        out_graph_list.append(outg)
        # Every further graph is the previous one diffused one more step.
        for _ in range(k-1):
            out_graph_list.append(DiffConv.diffuse(
                out_graph_list[-1], wadj, outd))
        # Same construction for the in direction, on the transposed adjacency
        # and the in-degrees.
        adj = sparse.coo_matrix(wadj.T/ind.cpu().numpy())
        ing = dgl.from_scipy(adj, eweight_name='weight').to(device)
        ing.edata['weight'] = ing.edata['weight'].float().to(device)
        in_graph_list.append(ing)
        for _ in range(k-1):
            in_graph_list.append(DiffConv.diffuse(
                in_graph_list[-1], wadj.T, ind))
        return out_graph_list, in_graph_list

    @staticmethod
    def get_weight_matrix(g):
        '''Return ``(adj, in_degrees, out_degrees)`` for ``g``.

        ``adj`` is the SciPy COO adjacency of ``g`` whose data array has been
        replaced by the ``'weight'`` edge feature.
        '''
        adj = g.adj(scipy_fmt='coo')
        ind = g.in_degrees()
        outd = g.out_degrees()
        weight = g.edata['weight']
        adj.data = weight.cpu().numpy()
        return adj, ind, outd

    @staticmethod
    def diffuse(progress_g, weighted_adj, degree):
        '''Advance a diffusion graph by one step.

        Multiplies the weighted adjacency of ``progress_g`` with
        ``weighted_adj`` divided by ``degree`` and returns the product as a
        new graph (float ``'weight'`` edge feature) on the device of
        ``progress_g``.
        '''
        device = progress_g.device
        progress_adj = progress_g.adj(scipy_fmt='coo')
        progress_adj.data = progress_g.edata['weight'].cpu().numpy()
        ret_adj = sparse.coo_matrix(progress_adj@(
            weighted_adj/degree.cpu().numpy()))
        ret_graph = dgl.from_scipy(ret_adj, eweight_name='weight').to(device)
        ret_graph.edata['weight'] = ret_graph.edata['weight'].float().to(
            device)
        return ret_graph

    def forward(self, g, x):
        '''Apply the diffusion convolution to ``x``.

        Parameters
        ----------
        g : DGLGraph
            Unused; message passing runs on the graph lists handed to the
            constructor. Kept for interface compatibility.
        x : torch.Tensor
            Node features whose last dimension is ``in_feats``.

        Returns
        -------
        torch.Tensor
            Merged features with last dimension ``out_feats``.
        '''
        feat_list = []
        if self.dir == 'both':
            graph_list = self.in_graph_list+self.out_graph_list
        elif self.dir == 'in':
            graph_list = self.in_graph_list
        elif self.dir == 'out':
            graph_list = self.out_graph_list
        else:
            # Only reachable when self.dir was reassigned after construction.
            raise DGLError(
                "dir must be one of 'both', 'in' or 'out', got {!r}".format(
                    self.dir))

        for i in range(self.num_graphs):
            diff_g = graph_list[i]
            with diff_g.local_scope():
                diff_g.ndata['n'] = self.project_fcs[i](x)
                # Weighted sum of projected source features over in-edges.
                diff_g.update_all(fn.u_mul_e('n', 'weight', 'e'),
                                  fn.sum('e', 'feat'))
                # Read the result before local_scope discards it.
                feat_list.append(diff_g.ndata['feat'])
                # Each feat has shape [N,q_feats]
        # NOTE(review): the extra term reuses the last projection instead of
        # having a layer of its own, and fails with an IndexError when
        # num_graphs is 0 -- confirm whether num_graphs+1 layers were meant.
        feat_list.append(self.project_fcs[-1](x))
        # Stack to [num_terms, -1, out_feats], scale each term by its merge
        # weight and average over the terms.
        feat_list = torch.cat(feat_list).view(
            len(feat_list), -1, self.out_feats)
        ret = (self.merger*feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
        return ret
Oops, something went wrong.