diff --git a/config/demo/demo.yaml b/config/st_gcn/kinetics_skeleton/demo.yaml similarity index 62% rename from config/demo/demo.yaml rename to config/st_gcn/kinetics_skeleton/demo.yaml index c40dd3eda..49e1e3489 100644 --- a/config/demo/demo.yaml +++ b/config/st_gcn/kinetics_skeleton/demo.yaml @@ -1,5 +1,4 @@ -work_dir: ./work_dir/demo -weights: ./work_dir/recognition/kinetics_skeleton/ST_GCN//epoch10_model.pt +weights: ./models/kinetics-st_gcn.pt # model model: net.st_gcn.Model @@ -12,5 +11,4 @@ model_args: strategy: 'spatial' # training -device: [0] - +device: [0] \ No newline at end of file diff --git a/net/st_gcn.py b/net/st_gcn.py index ce6003a8f..21fb694c5 100644 --- a/net/st_gcn.py +++ b/net/st_gcn.py @@ -107,9 +107,17 @@ def extract_feature(self, x): x, _ = gcn(x, self.A * importance) _, c, t, v = x.size() - x = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) + feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) - return x + # global pooling + x = F.avg_pool2d(x, x.size()[2:]) + x = x.view(N, M, -1, 1, 1).mean(dim=1) + + # prediction + x = self.fcn(x) + x = x.view(x.size(0), -1) + + return x, feature class st_gcn(nn.Module): r"""Applies a spatial temporal graph convolution over an input graph sequence. diff --git a/processor/demo.py b/processor/demo.py index 0ef653d11..82e50fc68 100644 --- a/processor/demo.py +++ b/processor/demo.py @@ -47,7 +47,7 @@ def start(self): with open(output_sequence_path, 'w') as outfile: json.dump(video_info, outfile) if len(video_info['data']) == 0: - print('Openpose can not find human skeletons from this video.') + print('Can not find pose estimation results.') return else: print('Pose estimation complete.') @@ -60,7 +60,7 @@ def start(self): # extract feature print('Network forwad.') - feature = self.model.extract_feature(data)[0] + output, feature = self.model.extract_feature(data)[0] intensity = feature.abs().sum(dim=0) intensity = intensity.cpu().detach().numpy() @@ -95,7 +95,7 @@ def get_parser(add_help=False): parser.add_argument('--output_dir', default='./data/demo_result', help='Path to save results') - parser.set_defaults(config='./config/demo/demo.yaml') + parser.set_defaults(config='./config/st_gcn/kinetics_skeleton/demo.yaml') parser.set_defaults(print_log=False) # endregion yapf: enable