Merge branch 'main' into tdn
iucario committed Jul 9, 2022
2 parents 649bf4b + b80f22d commit c34a19b
Showing 15 changed files with 649 additions and 681 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -34,4 +34,3 @@ pyrightconfig.json
*.onnx
*.pth
*.pt
uniformer
77 changes: 43 additions & 34 deletions README.md
@@ -1,35 +1,35 @@
# Workout Detector

This project uses the [MMAction2](https://github.com/open-mmlab/mmaction2) toolbox.

- [x] Clean and process datasets
- [x] Action recognition
- [ ] Train on more datasets
- [x] Action detection
- [ ] Use pose estimation
- [x] Repetition counting
- [ ] Action accessment
- [ ] Action assessment

## Installation

```
git clone --recursive https://github.com/iucario/workoutdetector.git
git clone https://github.com/iucario/workoutdetector.git
cd WorkoutDetector
conda env create -f conda_env.yml
pip install openmim
mim install mmcv
pip install -r requirements.txt
pip install -e .
```

## Docker

Example of building an image for the dev environment:

```
docker build -t workout/dev docker
docker run -it \
--gpus=all \
--shm-size=16gb \
--shm-size=32gb \
--volume="$PWD:/work" \
--volume="/home/$USER/data:/home/user/data:ro" \
-w /work \
@@ -42,12 +42,13 @@ pip install wandb pytest
sudo pip install -e .
export $PROJ_ROOT=$PWD
```

Example of running the Docker image:

```
docker run --rm -it \
--gpus=all \
--shm-size=16gb \
--shm-size=32gb \
--volume="$PWD:/work" \
--volume="/home/$USER/data:/home/user/data:ro" \
workout/dev:latest python3 workoutdetector/trainer.py
@@ -61,40 +62,45 @@ docker run --rm -it \

<img src="images/demo.gif" alt="React demo" width="800"/>

Known issue: after streaming stops, the WebSocket disconnects. Refresh the page to restart streaming.

A fix in the frontend React code is planned.

## Run Gradio demo

1. Download the ONNX model (the same one used by the React demo): [OneDrive](https://1drv.ms/u/s!AiohV3HRf-34i_VY0jVJGvLeayIdjQ?e=XqAvLa)
2. Copy it to `checkpoints/`
3. Run `python WorkoutDetector/demo.py`
4. Open http://localhost:7860/
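
For reference, a minimal sketch of what a Gradio app like `WorkoutDetector/demo.py` might look like; the `predict` function, its body, and the checkpoint path are assumptions for illustration, not the repo's exact code:

```python
import gradio as gr
import onnxruntime

# Hypothetical checkpoint path; use the model downloaded in step 1.
session = onnxruntime.InferenceSession('checkpoints/model.onnx')

def predict(video_path: str) -> str:
    # Real preprocessing (frame sampling, resize, normalization) omitted.
    return f'Got video at {video_path}'

gr.Interface(fn=predict, inputs=gr.Video(), outputs='text').launch(server_port=7860)
```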

## Inference

### Repetition counting
## Repetition counting

Two model types, image and video, can be used.

Method is naive. The transition of states is counted as one repetition. It's online counting.
The method is naive: a transition between the two states of an action is counted as one repetition. Counting is online; only previous frames are used.
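
To make the counting rule concrete, here is a minimal sketch; the helper name and the `(2k, 2k + 1)` state encoding are assumptions, chosen to be consistent with the `pred_to_count` tests later in this commit, not the repo's exact API:

```python
from typing import List, Tuple

def count_transitions(preds: List[int], step: int = 8) -> Tuple[int, List[int]]:
    """Online rep counting from per-frame state predictions.

    preds: predicted state per sampled frame, -1 meaning "no action".
    step: frame stride between predictions.
    Returns (count, flat list of rep start/end frame indices).
    """
    count = 0
    reps: List[int] = []
    prev = -1   # previously seen state
    start = 0   # frame index where the current start state began
    for i, state in enumerate(preds):
        if state == prev:
            continue
        if state >= 0 and state % 2 == 0:
            start = i * step                                 # entered a start state
        elif state == prev + 1 and prev >= 0 and prev % 2 == 0:
            count += 1                                       # start -> end: one rep
            reps += [start, i * step]
        prev = state
    return count, reps
```

Only frames seen so far are used, so the same loop can run frame by frame on a live stream.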

### Evaluation

1. Prepare an `onnx` model trained with `run.py`.
2. Run the scripts:
1. Run inference on videos and save the results to a directory. The results for each video are saved in a JSON file.
`workoutdetector/utils/inference_count.py`
```python
ckpt = 'checkpoints/model.onnx'
model = onnxruntime.InferenceSession(ckpt, providers=['CUDAExecutionProvider'])
inference_dataset(model, ['train', 'val', 'test'], out_dir='out/tsm_rep_scores', checkpoint=ckpt)
```
python utils/inference_count.py \
--onnx ../checkpoints/tsm_video_binary_jump_jack.onnx \
--video path/to/input/video.mp4 \
-o path/to/output/video.mp4
Scores of each video are saved in `out/tsm_rep_scores/video.mp4.score.json`.
2. Evaluate mean absolute error and off-by-one accuracy.
`workoutdetector/utils/eval_count.py`
```python
json_dir = 'out/tsm_rep_scores'
anno_path = 'data/RepCount/annotation.csv'
out_csv = 'out/tsm_rep_scores.csv'
main(json_dir, anno_path, out_csv, softmax=True)
analyze_count(out_csv, out_csv.replace('.csv', '_meta.csv'))
```
Results of every video are saved in `out/tsm_rep_scores.csv`.
Metrics are saved in `out/tsm_rep_scores_meta.csv`.
3. Visualization
`notebooks/rep_analysis.ipynb`

## Train an action recognition model

### Colab

Check `WorkoutDetector/tutorial.ipynb` in [Google Colab](https://colab.research.google.com/github/iucario/WorkoutDetector/blob/main/WorkoutDetector/tutorial.ipynb)

### Local

Be sure to set your project root in `WorkoutDetector/settings/global_settings.py`.
@@ -115,12 +121,12 @@ See `WorkoutDetector/scripts/build_datasets.py` for details.
data/Workouts/
├── rawframes
│   ├── Countix
│   │   ├── train -> countix/rawframes/train
│   │   └── val -> countix/rawframes/val
│   │   ├── train
│   │   └── val
│   ├── RepCount
│   │   ├── test -> RepCount/rawframes/test
│   │   ├── train -> RepCount/rawframes/train
│   │   └── val -> RepCount/rawframes/val
│   │   ├── test
│   │   ├── train
│   │   └── val
│   ├── test.txt
│   ├── test_repcount.txt
│   ├── train.txt
@@ -196,12 +202,8 @@ Configs are in `workoutdetector/configs`.

Uses PyTorch Lightning to train a model.
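
For context, a minimal sketch of the PyTorch Lightning wiring; the module below is illustrative (all names are assumptions), not the actual class in `workoutdetector/trainer.py`:

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class LitClassifier(pl.LightningModule):
    """Illustrative Lightning module; the real one lives in workoutdetector/trainer.py."""

    def __init__(self, backbone: torch.nn.Module, lr: float = 1e-3):
        super().__init__()
        self.backbone = backbone
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.backbone(x), y)
        self.log('train/loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


# trainer = pl.Trainer(max_epochs=10, accelerator='gpu', devices=1)
# trainer.fit(LitClassifier(backbone), datamodule=data_module)
```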

## Count repetitions

### workoutdetector/utils/inference_count.py

It does not work.

- Runs inference on every frame of a video using the image model. Writes the count to the `--output` file
and saves the predicted scores to a JSON file in the `--output` directory.
```
@@ -218,7 +220,7 @@ It does not work.
`workoutdetector/scripts/`

- `mpvscreenshot_process.py`
Before I create or find a video segment tool, I'll use this script to annotate videos.
Until I create or find a nice video segment tool, I'll use this script to annotate videos.
How to use:

1. The mpv screenshot filename template config is `screenshot-template=~/Desktop/%f_%P`
@@ -231,3 +233,10 @@ It does not work.
- `relabeled_csv_to_rawframe_list`
Use this together with `mpvscreenshot_process.py`.
Generates label files for mmaction rawframe datasets.
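
As an illustration of the `%f_%P` template above, a sketch of parsing one screenshot filename back into `(video name, seconds)`; the helper name and the exact timestamp format are assumptions:

```python
import os
import re


def parse_screenshot(path: str):
    """Split an mpv screenshot named with '%f_%P' into (video name, seconds).

    Example name assumed: 'workout.mp4_00:01:23.456.png'.
    """
    name = os.path.basename(path)
    m = re.match(r'(?P<video>.+)_(?P<ts>\d+:\d+:\d+\.\d+)\.\w+$', name)
    if m is None:
        raise ValueError(f'unexpected screenshot name: {name}')
    hours, minutes, seconds = m.group('ts').split(':')
    total = int(hours) * 3600 + int(minutes) * 60 + float(seconds)
    return m.group('video'), total
```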

## Acknowledgements

This project uses pretrained models from:

- [MMAction2](https://github.com/open-mmlab/mmaction2)
- [TSM](https://hanlab.mit.edu/projects/tsm/)
319 changes: 141 additions & 178 deletions notebooks/rep_analysis.ipynb

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions tests/test_inference_count.py
@@ -1,7 +1,7 @@
from typing import List

from pytest import fixture
from workoutdetector.utils.inference_count import eval_dataset, pred_to_count, parse_args, main
from workoutdetector.utils.inference_count import eval_dataset, pred_to_count


# TODO: load annotation.csv
@@ -28,5 +28,21 @@ def test_pred_to_count():
y4_reps = [0, 3, 7, 9]
assert pred_to_count(step=step, preds=x4) == (y4_count, [x * step for x in y4_reps])

x5 = [0,0,0,1,1,1,1,0,0,1,1,0,0,1,1,0,1,1]
y5_count = 3
x5 = [
-1, -1, 9, 9, 8, -1, -1, -1, -1, -1, -1, 6, 6, 7, 6, 6, 7, 6, 6, 7, -1, -1, -1,
-1, -1, -1, -1
]
y5_count = 3
pred_count, pred_rep = pred_to_count(preds=x5, step=8)
assert pred_count == y5_count

x6 = [
2, 3, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 2,
3, 3, 2, 2, 3, 3, 2, 2, 3, 3, -1
]
y5_count = 10
y5_rep = [
0, 8, 24, 32, 56, 64, 80, 96, 112, 128, 144, 160, 176, 184, 200, 216, 232, 248,
264, 280
]
assert pred_to_count(preds=x6, step=8) == (y5_count, y5_rep)
3 changes: 2 additions & 1 deletion tests/test_repcount_dataset.py
@@ -9,11 +9,11 @@
ACTIONS = [
'situp', 'push_up', 'pull_up', 'bench_pressing', 'jump_jack', 'squat', 'front_raise'
]
DATA_ROOT = os.path.join(PROJ_ROOT, 'data/RepCount')


class TestRepcountHelper:
"""Test RepcountHelper"""
DATA_ROOT = os.path.join(PROJ_ROOT, 'data/RepCount')
helper = RepcountHelper(DATA_ROOT, REPCOUNT_ANNO_PATH)
all_ = helper.get_rep_data(split=SPLITS, action=ACTIONS)

@@ -86,6 +86,7 @@ def test_RepcountHelper_eval_count(self):


def test_RepcountRecognitionDataset():
DATA_ROOT = os.path.join(PROJ_ROOT, 'data/RepCount')
actions = ['push_up', 'situp', 'squat', 'jump_jack', 'pull_up']
for split in ['train', 'val', 'test']:
dataset = RepcountRecognitionDataset(DATA_ROOT,
13 changes: 6 additions & 7 deletions tests/test_trainer.py
@@ -1,14 +1,13 @@
from workoutdetector.settings import PROJ_ROOT, REPCOUNT_ANNO_PATH
import torch
from torchvision.io import read_image
import torchvision.transforms.functional as TF
import os
import os.path as osp
from os.path import join as osj
from fvcore.common.config import CfgNode
from workoutdetector.trainer import DataModule
from tempfile import TemporaryDirectory
from workoutdetector.trainer import train, DataModule

import torch
import torchvision.transforms.functional as TF
from fvcore.common.config import CfgNode
from torchvision.io import read_image
from workoutdetector.trainer import DataModule, train


def _check_data(loader: torch.utils.data.DataLoader, num_class: int = 12):
9 changes: 5 additions & 4 deletions tests/test_transform.py
@@ -1,12 +1,13 @@
import os
from os.path import join as osj

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_image
import os
from os.path import join as osj
from workoutdetector.datasets.transform import Detector, PersonCrop
from workoutdetector.settings import PROJ_ROOT
import pytest


def test_Detector():
@@ -34,4 +35,4 @@ def test_PersonCrop():
y = func(img)
except Exception as e:
pytest.fail(e)
assert y.shape[:-2] == img.shape[:-2]
assert y.shape[:-2] == img.shape[:-2]
43 changes: 19 additions & 24 deletions workoutdetector/models/tsm.py
@@ -16,7 +16,11 @@

class TemporalShift(nn.Module):

def __init__(self, net, n_segment=3, n_div=8, inplace=False):
def __init__(self,
net: nn.Module,
n_segment: int = 3,
n_div: int = 8,
inplace: bool = False):
super(TemporalShift, self).__init__()
self.net = net
self.n_segment = n_segment
@@ -97,49 +101,40 @@ def temporal_pool(x, n_segment):
return x


def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
def make_temporal_shift(net: nn.Module,
n_segment: int,
n_div=8,
place='blockres',
temporal_pool=False):
if temporal_pool:
n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
else:
n_segment_list = [n_segment] * 4
assert n_segment_list[-1] > 0
# print('=> n_segment per stage: {}'.format(n_segment_list))

if isinstance(net, torchvision.models.ResNet):
if place == 'block':

def make_block_temporal(stage, this_segment):
blocks = list(stage.children())
for j, seg in enumerate(n_segment_list, 1):
blocks = list(getattr(net, f'layer{j}').children())
# print('=> Processing stage with {} blocks'.format(len(blocks)))
for i, b in enumerate(blocks):
blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
return nn.Sequential(*(blocks))
blocks[i] = TemporalShift(b, n_segment=seg, n_div=n_div)

net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
setattr(net, f'layer{j}', nn.Sequential(*(blocks)))

elif 'blockres' in place:
n_round = 1
if len(list(net.layer3.children())) >= 23:
n_round = 2
# print('=> Using n_round {} to insert temporal shift'.format(n_round))

def make_block_temporal(stage, this_segment):
blocks = list(stage.children())
for j, seg in enumerate(n_segment_list, 1):
blocks = list(getattr(net, f'layer{j}').children())
# print('=> Processing stage with {} blocks residual'.format(len(blocks)))
for i, b in enumerate(blocks):
if i % n_round == 0:
blocks[i].conv1 = TemporalShift(b.conv1,
n_segment=this_segment,
n_segment=seg,
n_div=n_div)
return nn.Sequential(*blocks)

net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
setattr(net, f'layer{j}', nn.Sequential(*(blocks)))
else:
raise NotImplementedError(place)

@@ -499,7 +494,7 @@ def create_model(num_class: int = 2,
# checkpoint
ckpt_path = 'checkpoints/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth'
pretrained = create_model(2, 8, 'resnet50', checkpoint=ckpt_path)
print(pretrained)
# print(pretrained)

state_dict = torch.load(ckpt_path).get('state_dict')
base_dict = OrderedDict(
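
The body of `TemporalShift.shift` is folded out of this diff. For reference, a self-contained sketch of the channel-shift operation from the TSM paper that this class wraps (the non-inplace variant; an assumption, not necessarily this file's exact code):

```python
import torch


def temporal_shift(x: torch.Tensor, n_segment: int, n_div: int = 8) -> torch.Tensor:
    """Shift 1/n_div of the channels one step back and 1/n_div one step
    forward along the temporal axis, zero-padding at the boundaries.

    x has shape (batch * n_segment, C, H, W), as in the TSM paper.
    """
    nt, c, h, w = x.size()
    x = x.view(nt // n_segment, n_segment, c, h, w)
    fold = c // n_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # shift toward the past
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift toward the future
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # remaining channels unchanged
    return out.view(nt, c, h, w)
```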