Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Jun 5, 2018
1 parent a6f59f9 commit e1f6da7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
work_dir: ./work_dir/demo
weights: ./work_dir/recognition/kinetics_skeleton/ST_GCN//epoch10_model.pt
weights: ./models/kinetics-st_gcn.pt

# model
model: net.st_gcn.Model
Expand All @@ -12,5 +11,4 @@ model_args:
strategy: 'spatial'

# training
device: [0]

device: [0]
12 changes: 10 additions & 2 deletions net/st_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,17 @@ def extract_feature(self, x):
x, _ = gcn(x, self.A * importance)

_, c, t, v = x.size()
x = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)
feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)

return x
# global pooling
x = F.avg_pool2d(x, x.size()[2:])
x = x.view(N, M, -1, 1, 1).mean(dim=1)

# prediction
x = self.fcn(x)
x = x.view(x.size(0), -1)

return x, feature

class st_gcn(nn.Module):
r"""Applies a spatial temporal graph convolution over an input graph sequence.
Expand Down
6 changes: 3 additions & 3 deletions processor/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def start(self):
with open(output_sequence_path, 'w') as outfile:
json.dump(video_info, outfile)
if len(video_info['data']) == 0:
print('Openpose can not find human skeletons from this video.')
print('Can not find pose estimation results.')
return
else:
print('Pose estimation complete.')
Expand All @@ -60,7 +60,7 @@ def start(self):

# extract feature
print('Network forwad.')
feature = self.model.extract_feature(data)[0]
output, feature = self.model.extract_feature(data)[0]
intensity = feature.abs().sum(dim=0)
intensity = intensity.cpu().detach().numpy()

Expand Down Expand Up @@ -95,7 +95,7 @@ def get_parser(add_help=False):
parser.add_argument('--output_dir',
default='./data/demo_result',
help='Path to save results')
parser.set_defaults(config='./config/demo/demo.yaml')
parser.set_defaults(config='./config/st_gcn/kinetics_skeleton/demo.yaml')
parser.set_defaults(print_log=False)
# endregion yapf: enable

Expand Down

0 comments on commit e1f6da7

Please sign in to comment.