Skip to content

Commit

Permalink
Mod tsm and Test_TSM
Browse files Browse the repository at this point in the history
Remove not used code
  • Loading branch information
iucario committed Jul 7, 2022
1 parent fb99993 commit 9d21338
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 135 deletions.
133 changes: 101 additions & 32 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,114 @@
from collections import OrderedDict
import random
import sys
import torch
from torch.utils.data import DataLoader
from workoutdetector.models import TSM
from workoutdetector.datasets import DebugDataset
from workoutdetector.models.tsm import create_model
from workoutdetector.datasets import DebugDataset, Pipeline
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch import optim
from einops import rearrange
import pandas as pd
import os
from torchvision.io import read_video

def test_TSM():
model = TSM(2, 8, base_model='resnet18', img_feature_dim=512)

class Test_TSM:

model = create_model(4, 8, 'resnet18', checkpoint=None, device='cuda')
model.eval()
i = torch.randn(4 * 8, 3, 224, 224)
y = model(i)
assert y.shape == (4, 2), y.shape

dataset = DebugDataset(2, 8, 20)
loader = DataLoader(dataset, batch_size=2, shuffle=False)
EPOCHS = 3
loss_fn = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)
model.cuda()
model.train()
for _ in range(EPOCHS):
ckpt_path = 'checkpoints/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth'
k400_path = 'checkpoints/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth'

def test_train(self):
model = self.model
i = torch.randn(4 * 8, 3, 224, 224)
y = model(i.cuda())
assert y.shape == (4, 4), y.shape

dataset = DebugDataset(num_class=4, num_segments=8, size=100)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
EPOCHS = 10
loss_fn = CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
model.cuda()
model.train()
for _ in range(EPOCHS):
for x, y in loader:
x = rearrange(x, 'b t c h w -> (b t) c h w')
assert x.shape == (2 * 8, 3, 224, 224)
y_pred = model(x.cuda())
loss = loss_fn(y_pred.cpu(), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item(), y_pred.argmax(dim=1))

model.eval()
correct = 0
for x, y in loader:
x = rearrange(x, 'b t c h w -> (b t) c h w')
assert x.shape == (2*8, 3, 224, 224)
y_pred = model(x.cuda())
loss = loss_fn(y_pred.cpu(), y)
print(y_pred.argmax(dim=1), y)
correct += (y_pred.cpu().argmax(dim=1) == y).sum().item()

optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = correct / len(loader.dataset)
assert acc > 0.5, f"Accuracy {acc} is too low"

print(loss, y_pred.argmax(dim=1))
def test_finetune(self):
num_class = 2
batch = 4
pretrained = create_model(num_class,
8,
'resnet50',
checkpoint=self.ckpt_path,
device='cpu')
pretrained.eval()
x = torch.randn(batch * 8, 3, 224, 224)
y = pretrained(x)
assert y.shape == (batch, num_class), \
f"y.shape = {y.shape}. Expected {(batch, num_class)}"

model.eval()
correct = 0
for x, y in loader:
x = rearrange(x, 'b t c h w -> (b t) c h w')
y_pred = model(x.cuda())
correct += (y_pred.cpu().argmax(dim=1) == y).sum().item()

acc = correct / len(loader.dataset)
assert acc > 0.5, f"Accuracy {acc} is too low"
# check weights
state_dict = torch.load(self.ckpt_path,
map_location=torch.device('cpu')).get('state_dict')
base_dict = OrderedDict(
('.'.join(k.split('.')[1:]), v) for k, v in state_dict.items())
for k, v in pretrained.state_dict().items():
if k in base_dict:
assert torch.allclose(v, base_dict[k]), f"{k} not equal"
else:
sys.stderr.write(k, v.shape, f"{k} is not in base_dict\n")

@torch.no_grad()
def test_k400(self):
"""Test accuracy of trained model on Kinetics400 subset Countix"""

num_samples = 50
model = create_model(400, 8, 'resnet50', checkpoint=self.k400_path)
model.eval()
label_df = pd.read_csv('datasets/kinetics400/kinetics_400_labels.csv')
data_root = 'data/Countix/videos/train'
data_df = pd.read_csv('datasets/Countix/countix_train.csv')
video_list = os.listdir(data_root)
video_ids = random.sample(video_list, num_samples)
P = Pipeline()
acc = 0
for video_id in video_ids:
gt_label = data_df.loc[data_df['video_id'] == video_id.split('.')[0],
'class'].values[0]
video = read_video(os.path.join(data_root, video_id))[0]
inp = P.transform_read_video(video)
out = model(inp.cuda()).cpu()
top5 = torch.topk(out, 5)[1].tolist()[0]
labels = [label_df.iloc[i, 1] for i in top5]
#softmax
label = labels[0]
assert out.shape == (1, 400), out.shape
if not label == gt_label:
sys.stderr.write(f"Prediction: {label} != {gt_label}\n")
acc += 1 if label == gt_label else 0
assert acc / num_samples > 0.5, f"Accuracy {acc} is too low"
149 changes: 46 additions & 103 deletions workoutdetector/models/tsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import nn
from torch.nn.init import constant_, normal_
import torchvision
from .build import MODEL_REGISTRY
# from .build import MODEL_REGISTRY


class TemporalShift(nn.Module):
Expand Down Expand Up @@ -191,7 +191,7 @@ def forward(self, input):
return SegmentConsensus(self.consensus_type, self.dim)(input)


@MODEL_REGISTRY.register()
# @MODEL_REGISTRY.register()
class TSM(nn.Module):
"""TSN with temporal shift module
Input shape: (batch_size*num_segments, channel, height, width)
Expand Down Expand Up @@ -240,7 +240,6 @@ def __init__(self,
self.num_segments = num_segments
self.reshape = True
self.before_softmax = before_softmax
self.dropout = dropout
self.crop_num = crop_num
self.consensus_type = consensus_type
self.img_feature_dim = img_feature_dim
Expand All @@ -255,24 +254,17 @@ def __init__(self,

if not before_softmax and consensus_type != 'avg':
raise ValueError("Only avg consensus can be used after Softmax")

self.new_length = 1
if print_spec:
print(f"Initializing TSN with base model: {base_model}.",
"TSN Configurations:",
f"input_modality: {self.modality}",
f"num_segments: {self.num_segments}",
f"new_length: {self.new_length}",
f"consensus_module: {consensus_type}",
f"dropout_ratio: {self.dropout}",
f"img_feature_dim: {self.img_feature_dim}",
sep='\n')

std = 0.001
self._prepare_base_model(base_model)

feature_dim = self._prepare_tsn(num_class)

self.avgpool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(p=dropout)
feature_dim = self.base_model.fc.in_features
self.fc = nn.Linear(feature_dim, num_class)
normal_(self.fc.weight, 0, std)
constant_(self.fc.bias, 0)
self.consensus = ConsensusModule(consensus_type)
self.base_model = nn.Sequential(
OrderedDict(list(self.base_model.named_children())[:-2]))

if not self.before_softmax:
self.softmax = nn.Softmax()
Expand All @@ -281,29 +273,6 @@ def __init__(self,
if partial_bn:
self.partialBN(True)

def _prepare_tsn(self, num_class: int) -> int:
feature_dim = getattr(self.base_model,
self.base_model.last_layer_name).in_features
if self.dropout == 0:
setattr(self.base_model, self.base_model.last_layer_name,
nn.Linear(feature_dim, num_class))
self.new_fc = None
else:
setattr(self.base_model, self.base_model.last_layer_name,
nn.Dropout(p=self.dropout))
self.new_fc = nn.Linear(feature_dim, num_class)

std = 0.001
if self.new_fc is None:
normal_(
getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
else:
if hasattr(self.new_fc, 'weight'):
normal_(self.new_fc.weight, 0, std)
constant_(self.new_fc.bias, 0)
return feature_dim

def _prepare_base_model(self, base_model: str):
# print('=> base model: {}'.format(base_model))

Expand All @@ -322,8 +291,6 @@ def _prepare_base_model(self, base_model: str):
self.input_mean = [0.485, 0.456, 0.406]
self.input_std = [0.229, 0.224, 0.225]

self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)

else:
raise ValueError('Unknown base model: {}'.format(base_model))

Expand Down Expand Up @@ -451,41 +418,17 @@ def get_optim_policies(self):
},
]

def head(self, base_out):
if self.dropout > 0:
base_out = self.new_fc(base_out)

if not self.before_softmax:
base_out = self.softmax(base_out)
if self.reshape:
if self.is_shift and self.temporal_pool:
base_out = base_out.view((-1, self.num_segments // 2) +
base_out.size()[1:])
else:
base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
output = self.consensus(base_out)
return output.squeeze(1)

def forward(self, input_x, no_reshape=False):
base_out = self.forward_features(input_x, no_reshape)
return self.head(base_out)

def forward_features(self, input_x, reshape=False):
if reshape:
sample_len = (3 if self.modality == 'RGB' else 2) * self.new_length
base_out = self.base_model(
input_x.view((-1, sample_len) + input_x.size()[-2:]))
def forward(self, x):
o = self.base_model(x)
o = self.avgpool(o)
o = self.dropout(o)
o = self.fc(o.view(o.size(0), -1))
if self.is_shift and self.temporal_pool:
o = o.view((-1, self.num_segments // 2) + o.size()[1:])
else:
base_out = self.base_model(input_x)
return base_out

@property
def crop_size(self):
return self.input_size

@property
def scale_size(self):
return self.input_size * 256 // 224
o = o.view((-1, self.num_segments) + o.size()[1:])
output = self.consensus(o)
return output.squeeze(1)


def create_model(num_class: int = 2,
Expand Down Expand Up @@ -521,11 +464,15 @@ def create_model(num_class: int = 2,
if checkpoint is not None:
ckpt = torch.load(checkpoint, map_location=device)
state_dict = ckpt['state_dict']
dim_feature = state_dict['module.new_fc.weight'].shape
if dim_feature[0] != num_class:
state_dict['module.new_fc.weight'] = torch.zeros(num_class,
dim_feature[1]).cuda()
state_dict['module.new_fc.bias'] = torch.zeros(num_class).cuda()
fc_layer_weight = list(state_dict.keys())[-2]
fc_layer_bias = list(state_dict.keys())[-1]
dim_feature = state_dict[fc_layer_weight].shape
if dim_feature[0] == num_class:
state_dict['module.fc.weight'] = state_dict[fc_layer_weight]
state_dict['module.fc.bias'] = state_dict[fc_layer_bias]

del state_dict[fc_layer_weight]
del state_dict[fc_layer_bias]
base_dict = OrderedDict(
('.'.join(k.split('.')[1:]), v) for k, v in state_dict.items())
# replace_dict = {
Expand All @@ -536,7 +483,7 @@ def create_model(num_class: int = 2,
# if k in base_dict:
# base_dict[v] = base_dict.pop(k)

model.load_state_dict(base_dict)
model.load_state_dict(base_dict, strict=False)

model.to(device)
return model
Expand All @@ -558,22 +505,18 @@ def create_model(num_class: int = 2,
y = model(x.cuda())
print(y)

from workoutdetector.datasets import DebugDataset
dataset = DebugDataset(2, 8, size=100)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
model.cuda()

for _ in range(20):
for x, y in loader:
x = x.permute(0, 2, 1, 3, 4)
x = x.reshape(-1, 3, 224, 224)
y_pred = model(x.cuda())
loss = loss_fn(y_pred.cpu(), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
# checkpoint
ckpt_path = 'checkpoints/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth'
pretrained = create_model(2, 8, 'resnet50', checkpoint=ckpt_path)
print(pretrained)

state_dict = torch.load(ckpt_path).get('state_dict')
base_dict = OrderedDict(
('.'.join(k.split('.')[1:]), v) for k, v in state_dict.items())

# check weights
for k, v in pretrained.state_dict().items():
if k in base_dict:
assert torch.allclose(v, base_dict[k]), f"{k} not equal"
else:
print(k, v.shape, f"{k} is not in base_dict")

0 comments on commit 9d21338

Please sign in to comment.