Skip to content

Commit

Permalink
fix issue open-mmlab#304
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed May 24, 2020
1 parent 1e589a5 commit ce0e28a
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 23 deletions.
4 changes: 2 additions & 2 deletions configs/recognition/st_gcn_aaai18/kinetics-skeleton/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ processor_cfg:
# debug: true

# dataloader setting
batch_size: 256
gpus: 4
batch_size: 64
gpus: 1
4 changes: 2 additions & 2 deletions configs/recognition/st_gcn_aaai18/ntu-rgbd-xsub/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ processor_cfg:
# debug: true

# dataloader setting
batch_size: 256
gpus: 4
batch_size: 64
gpus: 1
4 changes: 2 additions & 2 deletions configs/recognition/st_gcn_aaai18/ntu-rgbd-xview/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ processor_cfg:
# debug: true

# dataloader setting
batch_size: 256
gpus: 4
batch_size: 64
gpus: 1
5 changes: 3 additions & 2 deletions doc/GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
```

b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/), e.g.,
b. Install PyTorch and torchvision:
``` shell
conda install pytorch torchvision -c pytorch
conda install pytorch==1.2.0 torchvision==0.4.0 -c pytorch
```
The higher versions are not covered by tests.

c. Clone mmskeleton from github:

Expand Down
31 changes: 16 additions & 15 deletions mmskeleton/models/backbones/st_gcn_aaai18.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class ST_GCN_18(nn.Module):
:math:`V_{in}` is the number of graph nodes,
:math:`M_{in}` is the number of instance in a frame.
"""

def __init__(self,
in_channels,
num_class,
Expand All @@ -37,20 +36,25 @@ def __init__(self,

# load graph
self.graph = Graph(**graph_cfg)
A = torch.tensor(
self.graph.A, dtype=torch.float32, requires_grad=False)
A = torch.tensor(self.graph.A,
dtype=torch.float32,
requires_grad=False)
self.register_buffer('A', A)

# build networks
spatial_kernel_size = A.size(0)
temporal_kernel_size = 9
kernel_size = (temporal_kernel_size, spatial_kernel_size)
self.data_bn = nn.BatchNorm1d(
in_channels * A.size(1)) if data_bn else lambda x: x
self.data_bn = nn.BatchNorm1d(in_channels *
A.size(1)) if data_bn else lambda x: x
kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
self.st_gcn_networks = nn.ModuleList((
st_gcn_block(
in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
st_gcn_block(in_channels,
64,
kernel_size,
1,
residual=False,
**kwargs0),
st_gcn_block(64, 64, kernel_size, 1, **kwargs),
st_gcn_block(64, 64, kernel_size, 1, **kwargs),
st_gcn_block(64, 64, kernel_size, 1, **kwargs),
Expand All @@ -75,7 +79,6 @@ def __init__(self,
self.fcn = nn.Conv2d(256, num_class, kernel_size=1)

def forward(self, x):

# data normalization
N, C, T, V, M = x.size()
x = x.permute(0, 4, 3, 1, 2).contiguous()
Expand All @@ -85,7 +88,7 @@ def forward(self, x):
x = x.permute(0, 1, 3, 4, 2).contiguous()
x = x.view(N * M, C, T, V)

# forwad
# forward
for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
x, _ = gcn(x, self.A * importance)

Expand Down Expand Up @@ -148,7 +151,6 @@ class st_gcn_block(nn.Module):
:math:`V` is the number of graph nodes.
"""

def __init__(self,
in_channels,
out_channels,
Expand Down Expand Up @@ -187,11 +189,10 @@ def __init__(self,

else:
self.residual = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=(stride, 1)),
nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=(stride, 1)),
nn.BatchNorm2d(out_channels),
)

Expand Down
5 changes: 5 additions & 0 deletions mmskeleton/processor/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def test(model_cfg, dataset_cfg, checkpoint, batch_size=64, gpus=1, workers=4):
prog_bar = ProgressBar(len(dataset))
for data, label in data_loader:
with torch.no_grad():
output = model(data)
output = model(data).data.cpu().numpy()

results.append(output)
labels.append(label)
for i in range(len(data)):
Expand Down Expand Up @@ -77,7 +79,9 @@ def train(
else:
model = call_obj(**model_cfg)
model.apply(weights_init)
print(111, len(model.edge_importance))
model = MMDataParallel(model, device_ids=range(gpus)).cuda()
print(222, len(model.module.edge_importance))
loss = call_obj(**loss_cfg)

# build runner
Expand All @@ -92,6 +96,7 @@ def train(

# run
workflow = [tuple(w) for w in workflow]
print(222, len(model.module.edge_importance))
runner.run(data_loaders, workflow, total_epochs, loss=loss)


Expand Down

0 comments on commit ce0e28a

Please sign in to comment.