-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from Vicky-51/main
A Modified Version for Batch Training
- Loading branch information
Showing
17 changed files
with
14,928 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from layers import cheb_conv | ||
import numpy as np | ||
|
||
|
||
class GCN(nn.Module): | ||
def __init__(self, nfeat, nhid, nclass, adj, cheb_K, dropout): | ||
super(GCN, self).__init__() | ||
|
||
self.gc1 = cheb_conv(nfeat, nhid, adj, cheb_K) | ||
self.gc2 = cheb_conv(nhid, nclass, adj, cheb_K) | ||
self.dropout = dropout | ||
|
||
def forward(self, x, adj): | ||
x = F.relu(self.gc1(x)) | ||
x = F.dropout(x, self.dropout, training=self.training) | ||
x = self.gc2(x) | ||
return F.log_softmax(x, dim=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
V_25文件:取PEMS数据集25个传感器。 | ||
行:表示时间,共12672行。时间间隔5min,一天288个5min,2012年5月和6月工作日共44天。288*44(天)=12672(行) | ||
列:传感器数量,共25列,代表25个传感器。 | ||
|
||
|
||
W_25文件:邻接矩阵,无向图。 | ||
边表示传感器间距离。25行,25列。 | ||
|
||
|
||
github限制,完整 PEMSD7 加QQ:935140093 |
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
0,3170,8730,11900,7760,19900,18400,2210,5890,16100,13600,4270,9100,25400,13500,5370,16000,15800,18100,17600,15500,15200,19400,15200,18700 | ||
3170,0,5630,8750,4700,16700,15300,1040,2910,13000,10500,1330,5980,22300,10400,2390,13500,13200,15000,15300,12300,12300,16300,12300,15600 | ||
8730,5630,0,3280,1040,11500,10000,6660,2910,7630,5070,4470,387,16900,4970,3390,8920,8340,10200,11000,6740,7080,11800,6880,10900 | ||
11900,8750,3280,0,4320,8190,6760,9760,6170,4360,1790,7690,2890,13600,1690,6640,8210,7390,7130,10300,3640,5590,8770,4980,7880 | ||
7760,4700,1040,4320,0,12500,11100,5730,1890,8670,6100,3490,1420,17900,6000,2390,9350,8850,11200,11300,7740,7770,12700,7660,11900 | ||
19900,16700,11500,8190,12500,0,1450,17700,14300,3840,6400,15800,11100,5560,6500,14800,11900,11000,2950,13400,5410,9230,3470,8290,3110 | ||
18400,15300,10000,6760,11100,1450,0,16200,12900,2410,4970,14300,9640,6990,5060,13300,11000,10100,2370,12700,4170,8240,3550,7270,2860 | ||
2210,1040,6660,9760,5730,17700,16200,0,3940,13900,11400,2330,7010,23200,11400,3420,14500,14100,15800,16200,13400,13400,17200,13300,16500 | ||
5890,2910,2910,6170,1890,14300,12900,3940,0,10500,7940,1630,3290,19800,7850,523,10700,10300,12900,12600,9630,9420,14400,9410,13600 | ||
16100,13000,7630,4360,8670,3840,2410,13900,10500,0,2570,12000,7240,9310,2660,11000,9440,8490,3480,11300,2190,6510,5120,5540,4240 | ||
13600,10500,5070,1790,6100,6400,4970,11400,7940,2570,0,9430,4680,11900,96.5,8410,8530,7630,5480,10600,2330,5660,7140,4820,6240 | ||
4270,1330,4470,7690,3490,15800,14300,2330,1630,12000,9430,0,4840,21300,9340,1110,12200,11800,14200,14000,11200,11000,15600,11000,14800 | ||
9100,5980,387,2890,1420,11100,9640,7010,3290,7240,4680,4840,0,16500,4580,3770,8790,8170,9820,10800,6360,6850,11400,6610,10500 | ||
25400,22300,16900,13600,17900,5560,6990,23200,19800,9310,11900,21300,16500,0,12000,20300,15500,14700,7980,16500,10400,13400,7410,12600,7690 | ||
13500,10400,4970,1690,6000,6500,5060,11400,7850,2660,96.5,9340,4580,12000,0,8310,8520,7610,5560,10600,2390,5650,7220,4820,6320 | ||
5370,2390,3390,6640,2390,14800,13300,3420,523,11000,8410,1110,3770,20300,8310,0,11200,10800,13300,13000,10100,9940,14800,9930,13900 | ||
16000,13500,8920,8210,9350,11900,11000,14500,10700,9440,8530,12200,8790,15500,8520,11200,0,950,12900,2110,7260,2940,14400,3900,13600 | ||
15800,13200,8340,7390,8850,11000,10100,14100,10300,8490,7630,11800,8170,14700,7610,10800,950,0,11900,2960,6310,2000,13500,2950,12600 | ||
18100,15000,10200,7130,11200,2950,2370,15800,12900,3480,5480,14200,9820,7980,5560,13300,12900,11900,0,14700,5660,9970,1670,8990,773 | ||
17600,15300,11000,10300,11300,13400,12700,16200,12600,11300,10600,14000,10800,16500,10600,13000,2110,2960,14700,0,9140,4930,16200,5830,15400 | ||
15500,12300,6740,3640,7740,5410,4170,13400,9630,2190,2330,11200,6360,10400,2390,10100,7260,6310,5660,9140,0,4330,7280,3360,6410 | ||
15200,12300,7080,5590,7770,9230,8240,13400,9420,6510,5660,11000,6850,13400,5650,9940,2940,2000,9970,4930,4330,0,11600,979,10700 | ||
19400,16300,11800,8770,12700,3470,3550,17200,14400,5120,7140,15600,11400,7410,7220,14800,14400,13500,1670,16200,7280,11600,0,10600,902 | ||
15200,12300,6880,4980,7660,8290,7270,13300,9410,5540,4820,11000,6610,12600,4820,9930,3900,2950,8990,5830,3360,979,10600,0,9730 | ||
18700,15600,10900,7880,11900,3110,2860,16500,13600,4240,6240,14800,10500,7690,6320,13900,13600,12600,773,15400,6410,10700,902,9730,0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Modified Version for STTN Model | ||
|
||
**Several modifications have been made based on the code from https://github.com/wubin5/STTN** | ||
|
||
A Pytorch version for paper | ||
|
||
[Spatial-Temporal Transformer Networks for Traffic Flow Forecasting]: https://arxiv.org/pdf/2001.02908.pdf | ||
|
||
|
||
|
||
## Main Modification | ||
|
||
- This version allows batch training to improve training efficiency and performance. | ||
- Fix the bug in multi-head attention. | ||
- Use *Chebyshev Polynomials* to model *Fixed Graph Convolution Layer*, which is the same as original paper. | ||
- Provide positional embedding options with sine/cosine functions, which is the same as [Attention is All Your Need]: https://arxiv.org/abs/1706.03762. | ||
- Add positional embedding to all three components: key, query, value. | ||
|
||
|
||
|
||
## How to Run? | ||
|
||
### prepareData.py | ||
|
||
Data preprocessing. Adapted from [ASTGCN]: https://github.com/guoshnBJTU/ASTGCN-r-pytorch. | ||
|
||
- graph_signal_matrix_filename: The path for raw data, shape: [L, N]. Each row represents a record, N means there are N traffic sensors in total. | ||
- This code will return and save a processed .npz file in the same path. | ||
- Other parameters you may refer to the comments in python file. | ||
|
||
|
||
|
||
### train_batch.py | ||
|
||
Training process. Adapted from [ASTGCN]: https://github.com/guoshnBJTU/ASTGCN-r-pytorch. | ||
|
||
- adj_mx: The path for adjacency matrix. | ||
|
||
- params_path: The path for saving model parameters and training log. | ||
|
||
- filename: The data we got from prepareData.py. You may also use your own preprocess function. The shape of our dataloader should be [Batch_size(B), number_of_sensors(N), input_channels(C), input_length(T)] | ||
|
||
- All the training process will be saved. | ||
|
||
- You should write down the best epochs for further test. | ||
|
||
|
||
|
||
### predict_batch.py | ||
|
||
Testing process. | ||
|
||
- Similar setting as train_batch.py. | ||
- It will return and save the prediction results for test set. | ||
|
Oops, something went wrong.