Skip to content

Commit

Permalink
Merge branch 'lab' into bbox
Browse files Browse the repository at this point in the history
  • Loading branch information
iucario committed Jul 10, 2022
2 parents af6a5b8 + 1e7a8d3 commit 9882b39
Show file tree
Hide file tree
Showing 27 changed files with 2,107 additions and 194 deletions.
2 changes: 1 addition & 1 deletion app/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ def inference_video(video):

if __name__ == '__main__':

inference_video('/home/umi/projects/WorkoutDetector/example_videos/4-YmQKoHYmw.mp4')
inference_video('example_videos/4-YmQKoHYmw.mp4')
5 changes: 4 additions & 1 deletion docker/start.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
docker run -it \
--gpus=all \
--shm-size=16gb \
--shm-size=32gb \
-u $(id -u):$(id -g) \
-e PROJ_ROOT="/work" \
-e WANDB_API_KEY=$WANDB_API_KEY \
--volume="$PWD:/work" \
--volume="/home/$USER/data:/home/user/data:ro" \
-w /work \
Expand Down
19 changes: 19 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
import os.path as osp
from os.path import join as osj
from workoutdetector.datasets import build_dataset
from fvcore.common.config import CfgNode
from torch.utils.data import DataLoader


def test_ImageDataset():
Expand All @@ -15,3 +18,19 @@ def test_ImageDataset():
img, label = train_set[0]
assert len(train_set)
assert img.shape[0] == 3, f'{img.shape}'


def test_TDNDataset():
root = 'data'
anno = '/home/user/data/Binary/all-train.txt'
cfg = CfgNode(new_allowed=True)
cfg.merge_from_file('workoutdetector/configs/tdn.yaml')
batch = cfg.data.batch_size
num_seg = cfg.data.num_segments
num_frames = cfg.data.num_frames
ds = build_dataset(cfg.data, split='train')
assert len(ds), f'No data in {ds.root}'
loader = DataLoader(ds, batch_size=batch, shuffle=True, num_workers=4)
for _, (x, y) in zip(range(10), loader):
assert x.shape == (batch, num_seg * num_frames, 3, 224, 224), \
f'{x.shape} is not ({batch}, {num_seg} * {num_frames}, 3, 224, 224)'
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class Test_TSM:

model = create_model(4, 8, 'resnet18', checkpoint=None, device='cuda')
model.eval()
ckpt_path = 'checkpoints/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth'
k400_path = 'checkpoints/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth'
ckpt_path = 'checkpoints/finetune/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth'
k400_path = 'checkpoints/finetune/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth'

def test_train(self):
model = self.model
Expand Down
121 changes: 121 additions & 0 deletions tests/test_tdn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from collections import OrderedDict
import random
import sys
import torch
from torch.utils.data import DataLoader
from workoutdetector.models.tdn import create_model
from workoutdetector.datasets import DebugDataset, Pipeline, TDNDataset
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch import optim
from einops import rearrange
import pandas as pd
import os
from torchvision.io import read_video


class Test_TDN:

model = create_model(num_class=4,
num_segments=8,
base_model='resnet50',
checkpoint=None)
model.eval()
sthv2_path = 'checkpoints/finetune/tdn_sthv2_r50_8x1x1.pth'
k400_path = 'checkpoints/finetune/tdn_k400_r50_8x1x1.pth'

def test_train(self):
num_diff = 5
model = self.model
batch = 4
num_class = 4
epochs = 10
i = torch.randn(4 * num_diff * 8, 3, 224, 224)
y = model(i)
assert y.shape == (4, 4), y.shape

dataset = DebugDataset(num_class=num_class, num_segments=40, size=100)
loader = DataLoader(dataset, batch_size=batch, shuffle=True)

loss_fn = CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
model.cuda()
model.train()
for _ in range(epochs):
for x, y in loader:
x = rearrange(x, 'b t c h w -> (b t) c h w')
assert x.shape == (batch * num_diff * 8, 3, 224, 224)
y_pred = model(x.cuda())
loss = loss_fn(y_pred.cpu(), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item(), y_pred.argmax(dim=1))

model.eval()
correct = 0
for x, y in loader:
x = rearrange(x, 'b (t n) c h w -> (b t) n c h w', t=8, n=num_diff)
y_pred = model(x.cuda())
print(y_pred.argmax(dim=1), y)
correct += (y_pred.cpu().argmax(dim=1) == y).sum().item()

acc = correct / len(loader.dataset)
assert acc > 0.5, f"Accuracy {acc} is too low"

def test_finetune(self):
num_class = 2
batch = 4
num_diff = 5
pretrained = create_model(num_class, 8, 'resnet50', checkpoint=self.sthv2_path)
pretrained.eval()
x = torch.randn(batch * num_diff * 8, 3, 224, 224)
y = pretrained(x)
assert y.shape == (batch, num_class), \
f"y.shape = {y.shape}. Expected {(batch, num_class)}"

# check weights
state_dict = torch.load(self.sthv2_path,
map_location=torch.device('cpu')).get('state_dict')
base_dict = OrderedDict(
('.'.join(k.split('.')[1:]), v) for k, v in state_dict.items())
for k, v in pretrained.state_dict().items():
if k in base_dict:
assert torch.allclose(v, base_dict[k]), f"{k} not equal"
else:
sys.stderr.write(f"{k}, {v.shape}, {k} is not in base_dict\n")

@torch.no_grad()
def test_k400(self):
"""Test accuracy of trained model on Kinetics400 subset Countix"""

num_samples = 50
model = create_model(400, 8, 'resnet50', checkpoint=self.k400_path)
model.eval()
model.to('cuda')
label_df = pd.read_csv('datasets/kinetics400/kinetics_400_labels.csv')
data_root = '/home/user/data/Countix/videos/train'
data_df = pd.read_csv('datasets/Countix/countix_train.csv')
video_list = os.listdir(data_root)
video_ids = random.sample(video_list, num_samples)
P = Pipeline()
acc = 0
for video_id in video_ids:
gt_label = data_df.loc[data_df['video_id'] == video_id.split('.')[0],
'class'].values[0]
video = read_video(os.path.join(data_root, video_id))[0]
inp = P.transform_read_video(video, samples=40)
inp = rearrange(inp, '(b t n) c h w -> b t n c h w', b=1, t=8, n=5)
# inp.view((-1, 15) + inp.shape[2:])
out = model(inp.cuda()).cpu()
top5 = torch.topk(out, 5)[1].tolist()[0]
labels = [label_df.iloc[i, 1] for i in top5]
#softmax
label = labels[0]
assert out.shape == (1, 400), out.shape
if not label == gt_label:
sys.stderr.write(f"Prediction: {label} != {gt_label}\n")
acc += 1 if label == gt_label else 0
assert acc / num_samples > 0.5, f"Accuracy {acc} is too low"
6 changes: 3 additions & 3 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,21 @@ def test_DataModule():
num_class = cfg.model.num_class

with TemporaryDirectory() as tmpdir:
cfg.trainer.defaut_root_dir = tmpdir
cfg.trainer.default_root_dir = tmpdir
cfg.log.output_dir = osp.join(tmpdir, 'logs')
datamodule = DataModule(cfg.data, is_train=True, num_class=num_class)
val_loader = datamodule.val_dataloader()
_check_data(val_loader)


def test_config():
config = 'workoutdetector/configs/repcount_12_tsm.yaml'
config = 'workoutdetector/configs/tdn.yaml'
cfg = CfgNode(new_allowed=True)
cfg.merge_from_file(config)
cfg.trainer.fast_dev_run = True
cfg.trainer.devices = 1
cfg.log.wandb.offline = True
with TemporaryDirectory() as tmpdir:
cfg.trainer.defaut_root_dir = tmpdir
cfg.trainer.default_root_dir = tmpdir
cfg.log.output_dir = osp.join(tmpdir, 'logs')
train(cfg)
11 changes: 11 additions & 0 deletions tools/dist_train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

python -m torch.distributed.launch \
--nnodes=1 \
--node_rank=0 \
--master_addr=localhost \
--nproc_per_node=8 \
--master_port=29500 \
workoutdetector/train_rep.py \
--cfg workoutdetector/configs/tpn.py

34 changes: 20 additions & 14 deletions workoutdetector/configs/defaults.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
trainer:
default_root_dir: exp/repcount-12-tsm
default_root_dir: exp/default # where to save logs and checkpoints
max_epochs: 50
enable_checkpointing: true
num_nodes: 1
Expand All @@ -23,79 +23,85 @@ trainer:
auto_scale_batch_size: null
prepare_data_per_node: null
fast_dev_run: false

optimizer:
method: SGD
lr: 0.0015
lr: 0.005
momentum: 0.9
weight_decay: 5.0e-4
eps: 1.0e-8
lr_scheduler:
policy: StepLR
gamma: 0.1
step: 7
step: 8

model:
model_type: TSM
num_class: 12
num_segments: 8

# Frames around the center to calculate difference. Used in Temportal Difference Network
num_frames: 1
base_model: resnet50
consensus_type: avg

# I don't think it is used
img_feature_dim: 256
is_shift: true
shift_div: 8
shift_place: blockres
fc_lr5: true
temporal_pool: false
non_local: false
checkpoint: checkpoints/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth
checkpoint: checkpoints/finetune/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth

data:
dataset_type: FrameDataset
data_root: /home/user/data
data_root: /home/root/data
num_segments: 8
filename_tmpl: 'img_{:05}.jpg'
filename_tmpl: "img_{:05}.jpg"
anno_col: 4
batch_size: 4
train:
anno: /home/user/data/Binary/all-train.txt
anno: /home/root/data/Binary/all-train.txt
data_prefix: null
transform:
person_crop: false
val:
anno: /home/user/data/Binary/all-val.txt
anno: /home/root/data/Binary/all-val.txt
data_prefix: null
transform:
person_crop: false
test:
anno: /home/user/data/Binary/all-test.txt
anno: /home/root/data/Binary/all-test.txt
data_prefix: null
transform:
person_crop: false
num_workers: 8

log:
output_dir: exp/repcount-12-tsm
name: repcount-12-tsm
output_dir: null # os.path.join(trainer.default_root_dir, timestamp)
log_every_n_steps: 20
csv:
enable: true
tensorboard:
enable: true
wandb:
enable: false
enable: true
offline: false
project: repcount-12-tsm
name: repcount-12-tsm

callbacks:
modelcheckpoint:
save_top_k: 1
save_weights_only: false
monitor: val/acc
mode: max
dirpath: null
dirpath: null # if None, defaults to log.output_dir
early_stopping:
enable: false
patience: 10

seed: 0
train: true
timestamp: null # Will be initialized in python file
Loading

0 comments on commit 9882b39

Please sign in to comment.