diff --git a/net/st_gcn.py b/net/st_gcn.py index 2aafd0480..d06c94222 100644 --- a/net/st_gcn.py +++ b/net/st_gcn.py @@ -41,8 +41,10 @@ def __init__(self, in_channels, num_class, graph_args, temporal_kernel_size = 9 kernel_size = (temporal_kernel_size, spatial_kernel_size) self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} + print(kwargs0) self.st_gcn_networks = nn.ModuleList(( - st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs), + st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0), st_gcn(64, 64, kernel_size, 1, **kwargs), st_gcn(64, 64, kernel_size, 1, **kwargs), st_gcn(64, 64, kernel_size, 1, **kwargs),