Skip to content

Commit

Permalink
modify in_channels of gcn to spatial_channels
Browse files Browse the repository at this point in the history
  • Loading branch information
xumwen committed Apr 6, 2020
1 parent daf0215 commit bbf5b4f
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions stgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def __init__(self, in_channels, spatial_channels, out_channels,
"""
super(STGCNBlock, self).__init__()
self.temporal1 = TimeBlock(in_channels=in_channels,
out_channels=out_channels)
self.linear = nn.Linear(in_features=out_channels,
out_features=spatial_channels)
out_channels=spatial_channels)
self.gcn = GCNUnit(in_channels=spatial_channels,
out_channels=spatial_channels,
gcn_type=gcn_type,
Expand All @@ -83,7 +81,6 @@ def forward(self, X, A, edge_index, edge_weight):
t1 = self.temporal1(X)
# batch_size * timesteps -> batch_size
t21 = t1.permute(0, 2, 1, 3).contiguous().view(-1, t1.shape[1], t1.shape[3])
t21 = self.linear(t21)
t22 = F.relu(self.gcn(t21, A, edge_index, edge_weight))
# batch_size -> (batch_size, timesteps)
t23 = t22.view(t1.shape[0], t1.shape[2], t22.shape[1], t22.shape[2]).permute(0, 2, 1, 3)
Expand Down

0 comments on commit bbf5b4f

Please sign in to comment.