Skip to content

Commit

Permalink
predrop
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Feb 19, 2019
1 parent 107afb1 commit 7f2789b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
87 changes: 86 additions & 1 deletion net/st_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Model(nn.Module):
"""

def __init__(self, in_channels, num_class, graph_args,
edge_importance_weighting, **kwargs):
edge_importance_weighting, pre_drop=False, **kwargs):
super().__init__()

# load graph
Expand All @@ -43,6 +43,7 @@ def __init__(self, in_channels, num_class, graph_args,
self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
print(kwargs0)
if pre_drop: st_gcn = st_gcn_predrop
self.st_gcn_networks = nn.ModuleList((
st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
st_gcn(64, 64, kernel_size, 1, **kwargs),
Expand Down Expand Up @@ -197,3 +198,87 @@ def forward(self, x, A):
x = self.tcn(x) + res

return self.relu(x), A


class st_gcn_predrop(nn.Module):
# the location of dropout layer is different

r"""Applies a spatial temporal graph convolution over an input graph sequence.
Args:
in_channels (int): Number of channels in the input sequence data
out_channels (int): Number of channels produced by the convolution
kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
stride (int, optional): Stride of the temporal convolution. Default: 1
dropout (int, optional): Dropout rate of the final output. Default: 0
residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
Shape:
- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
- Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format
- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
where
:math:`N` is a batch size,
:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
:math:`T_{in}/T_{out}` is a length of input/output sequence,
:math:`V` is the number of graph nodes.
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
dropout=0,
residual=True):
super().__init__()

assert len(kernel_size) == 2
assert kernel_size[0] % 2 == 1
padding = ((kernel_size[0] - 1) // 2, 0)

self.gcn = ConvTemporalGraphical(in_channels, out_channels,
kernel_size[1])

self.tcn = nn.Sequential(
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Conv2d(
out_channels,
out_channels,
(kernel_size[0], 1),
(stride, 1),
padding,
),
nn.BatchNorm2d(out_channels),
)

if not residual:
self.residual = lambda x: 0

elif (in_channels == out_channels) and (stride == 1):
self.residual = lambda x: x

else:
self.residual = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=(stride, 1)),
nn.BatchNorm2d(out_channels),
)

self.relu = nn.ReLU(inplace=True)

def forward(self, x, A):

res = self.residual(x)
x, A = self.gcn(x, A)
x = self.tcn(x) + res

return self.relu(x), A
2 changes: 0 additions & 2 deletions tools/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np
import cv2

import tools

def video_info_parsing(video_info, num_person_in=5, num_person_out=2):
data_numpy = np.zeros((3, len(video_info['data']), 18, num_person_in))
for frame_info in video_info['data']:
Expand Down

0 comments on commit 7f2789b

Please sign in to comment.