
Commit eb47dfa

1.9_ST_JT_new_attention_1
Add new attention terms, alpha and outer; I implemented this via repeat.
Tianrui-Li committed Jan 10, 2024
1 parent abcf7f1 commit eb47dfa
Showing 2 changed files with 110 additions and 95 deletions.
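In brief: the commit adds two terms to the attention block, a learnable scalar `alpha` (initialized to zero) and a learnable bias `outer` that is initialized from the skeleton adjacency and tiled over time via `repeat`, so the attention output becomes `(alpha * attn + outer) @ v`. A minimal sketch of the `outer` construction, assuming the dim=256 branch (`num_heads=8`, `group_size=(16, 25)`) and a random stand-in for the NTU RGB+D adjacency:

    import torch

    # Sketch: collapse the 3 adjacency subsets, then tile over the temporal axis
    # so the bias matches the (N, N) attention map, N = T * V joint-time tokens.
    num_heads, T, V = 8, 16, 25            # dim=256 branch: group_size = (16, 25)
    A = torch.rand(3, V, V)                # stand-in for Graph(layout='nturgb+d').A
    A = A.sum(0)                           # (25, 25)

    outer = A.unsqueeze(0).expand(num_heads, -1, -1).repeat(1, T, T)
    print(outer.shape)                     # torch.Size([8, 400, 400]) = (heads, N, N)
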
90 changes: 45 additions & 45 deletions configs/transformer/ST_JT.py
@@ -14,58 +14,23 @@
         depth=8,
         # stride1=3,
         # kernel_size1=5,
-        # graph_cfg=dict(layout='nturgb+d', mode='spatial'),
+        graph_cfg=dict(layout='nturgb+d', mode='spatial'),
     ),
     cls_head=dict(type='TRHead', num_classes=60, in_channels=512, dropout=0.))
 
 dataset_type = 'PoseDataset'
 ann_file = 'data/nturgbd/ntu60_3danno.pkl'
 
-# clip_len = 32
-# sample_rate = 3
-# mode = 'zero'
-# train_pipeline = [
-#     dict(type='PreNormalize3D'),
-#     dict(type='RandomScale', scale=0.1),
-#     dict(type='RandomRot', theta=0.3),
-#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-#     dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
-#          out_of_bound_opt='repeat_last', keep_tail_frames=True),
-#     dict(type='PoseDecode'),
-#     dict(type='FormatGCNInput', num_person=2, mode=mode),
-#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
-#     dict(type='ToTensor', keys=['keypoint'])
-# ]
-# val_pipeline = [
-#     dict(type='PreNormalize3D'),
-#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-#     dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
-#          out_of_bound_opt='repeat_last', keep_tail_frames=True),
-#     dict(type='PoseDecode'),
-#     dict(type='FormatGCNInput', num_person=2, mode=mode),
-#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
-#     dict(type='ToTensor', keys=['keypoint'])
-# ]
-# test_pipeline = [
-#     dict(type='PreNormalize3D'),
-#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-#     dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
-#          out_of_bound_opt='repeat_last', keep_tail_frames=True, num_clips=10),
-#     dict(type='PoseDecode'),
-#     dict(type='FormatGCNInput', num_person=2, mode=mode),
-#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
-#     dict(type='ToTensor', keys=['keypoint'])
-# ]
 
-clip_len = 64
+clip_len = 32
 sample_rate = 3
 mode = 'zero'
 train_pipeline = [
     dict(type='PreNormalize3D'),
-    # dict(type='RandomRot', theta=0.2),
     dict(type='RandomScale', scale=0.1),
-    dict(type='RandomRot'),
+    dict(type='RandomRot', theta=0.3),
     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-    dict(type='UniformSample', clip_len=clip_len),
+    dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
+         out_of_bound_opt='repeat_last', keep_tail_frames=True),
     dict(type='PoseDecode'),
     dict(type='FormatGCNInput', num_person=2, mode=mode),
     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
@@ -74,7 +39,8 @@
 val_pipeline = [
     dict(type='PreNormalize3D'),
     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-    dict(type='UniformSample', clip_len=clip_len, num_clips=1),
+    dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
+         out_of_bound_opt='repeat_last', keep_tail_frames=True),
    dict(type='PoseDecode'),
     dict(type='FormatGCNInput', num_person=2, mode=mode),
     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
@@ -83,14 +49,48 @@
 test_pipeline = [
     dict(type='PreNormalize3D'),
     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-    dict(type='UniformSample', clip_len=clip_len, num_clips=10),
+    dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
+         out_of_bound_opt='repeat_last', keep_tail_frames=True, num_clips=10),
     dict(type='PoseDecode'),
     dict(type='FormatGCNInput', num_person=2, mode=mode),
     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
     dict(type='ToTensor', keys=['keypoint'])
 ]
+
+# clip_len = 64
+# mode = 'zero'
+# train_pipeline = [
+#     dict(type='PreNormalize3D'),
+#     # dict(type='RandomRot', theta=0.2),
+#     dict(type='RandomScale', scale=0.1),
+#     dict(type='RandomRot'),
+#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
+#     dict(type='UniformSample', clip_len=clip_len),
+#     dict(type='PoseDecode'),
+#     dict(type='FormatGCNInput', num_person=2, mode=mode),
+#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
+#     dict(type='ToTensor', keys=['keypoint'])
+# ]
+# val_pipeline = [
+#     dict(type='PreNormalize3D'),
+#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
+#     dict(type='UniformSample', clip_len=clip_len, num_clips=1),
+#     dict(type='PoseDecode'),
+#     dict(type='FormatGCNInput', num_person=2, mode=mode),
+#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
+#     dict(type='ToTensor', keys=['keypoint'])
+# ]
+# test_pipeline = [
+#     dict(type='PreNormalize3D'),
+#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
+#     dict(type='UniformSample', clip_len=clip_len, num_clips=10),
+#     dict(type='PoseDecode'),
+#     dict(type='FormatGCNInput', num_person=2, mode=mode),
+#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
+#     dict(type='ToTensor', keys=['keypoint'])
+# ]
 data = dict(
-    videos_per_gpu=8,
+    videos_per_gpu=32,
     workers_per_gpu=8,
     test_dataloader=dict(videos_per_gpu=8),
     train=dict(type=dataset_type, ann_file=ann_file, pipeline=train_pipeline, split='xsub_train'),
Expand Down Expand Up @@ -121,7 +121,7 @@

# runtime settings
log_level = 'INFO'
work_dir = './work_dirs/lst/ntu60_xsub_3dkp/j_vanilla_variable_dim/1.9_ST_JT_87config_1'
work_dir = './work_dirs/lst/ntu60_xsub_3dkp/j_vanilla_variable_dim/1.9_ST_JT_new_attention_1'
find_unused_parameters = False
auto_resume = False
seed = 88
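
Aside from the new attention, the config swaps UniformSample (64 frames) for fixed-stride SampleFrames: each clip now draws 32 frames at interval 3 (a span of roughly 96 source frames, with out-of-range indices repeating the last frame), and videos_per_gpu rises from 8 to 32. A rough sketch of this style of fixed-stride sampling, not pyskl's exact implementation:

    import numpy as np

    # Rough sketch of fixed-stride sampling with 'repeat_last' clamping
    # (illustrative only; pyskl's SampleFrames has more options and edge cases).
    def sample_indices(total_frames, clip_len=32, frame_interval=3, rng=np.random):
        span = clip_len * frame_interval                  # ~96 source frames
        start = rng.randint(max(total_frames - span + 1, 1))
        inds = start + np.arange(clip_len) * frame_interval
        return np.minimum(inds, total_frames - 1)         # repeat the last frame

    print(sample_indices(80))  # video shorter than the span: tail clamps to 79
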
115 changes: 65 additions & 50 deletions pyskl/models/transformer/ST_JT.py
@@ -5,7 +5,7 @@
 from ..builder import BACKBONES
 from ..gcns import unit_tcn
 import torch.nn.functional as F
-# from ...utils import Graph
+from ...utils import Graph
 
 
 class DynamicPosBias(nn.Module):
@@ -60,33 +60,33 @@ class Attention(nn.Module):
         attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
         proj_drop (float, optional): Dropout ratio of output. Default: 0.0
     """
-    def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+    def __init__(self, A, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                  position_bias=True):
         super().__init__()
         self.dim = dim
 
-        # # group_size and num_heads vary with dim
-        # if dim == 64:
-        #     self.group_size = (64, 25)
-        #     self.num_heads = 2
-        # elif dim == 128:
-        #     self.group_size = (32, 25)
-        #     self.num_heads = 4
-        # elif dim == 256:
-        #     self.group_size = (16, 25)
-        #     self.num_heads = 8
-        # elif dim == 512:
-        #     self.group_size = (8, 25)
-        #     self.num_heads = 16
-
-        if dim == 128:
+        # group_size and num_heads vary with dim
+        if dim == 64:
             self.group_size = (64, 25)
-        elif dim == 256:
+            self.num_heads = 2
+        elif dim == 128:
             self.group_size = (32, 25)
-        elif dim == 512:
+            self.num_heads = 4
+        elif dim == 256:
             self.group_size = (16, 25)
+            self.num_heads = 8
+        elif dim == 512:
+            self.group_size = (8, 25)
+            self.num_heads = 16
 
-        self.num_heads = num_heads
+        # if dim == 128:
+        #     self.group_size = (64, 25)
+        # elif dim == 256:
+        #     self.group_size = (32, 25)
+        # elif dim == 512:
+        #     self.group_size = (16, 25)
+        #
+        # self.num_heads = num_heads
 
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim ** -0.5
@@ -122,6 +122,16 @@ def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., p
 
         self.softmax = nn.Softmax(dim=-1)
 
+        self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+        # A = A.sum(0)
+        # A[:, :] = 0
+        # self.outer = nn.Parameter(torch.stack([torch.eye(A.shape[-1]) for _ in range(self.num_heads)], dim=0),
+        #                           requires_grad=True)
+
+        A = A.sum(0)
+        self.outer = nn.Parameter(A.unsqueeze(0).expand(self.num_heads, -1, -1).repeat(1, self.group_size[0], self.group_size[0]), requires_grad=True)
+
     def forward(self, x, mask=None):
         """
         Args:
@@ -153,7 +163,10 @@ def forward(self, x, mask=None):
 
         attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        # x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+
+        x = ((self.alpha * attn + self.outer) @ v).transpose(1, 2).reshape(B_, N, C)
+
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
Expand Down Expand Up @@ -183,12 +196,12 @@ class TransformerEncoderLayer(nn.Module):
Inspired by torch.nn.TransformerEncoderLayer and
rwightman's timm package.
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
def __init__(self, A, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super().__init__()

self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
self.self_attn = Attention(A=A, dim=d_model, num_heads=nhead,
attn_drop=attention_dropout, proj_drop=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout1 = nn.Dropout(dropout)
@@ -253,7 +266,7 @@ class ST_JT(nn.Module):
     """
     def __init__(
             self,
-            # graph_cfg,
+            graph_cfg,
             in_channels=3,
             hidden_dim=64,
             dim_mul_layers=(4, 7),
@@ -275,8 +288,10 @@ def __init__(
         super().__init__()
 
         # # Batch_normalization
-        # graph = Graph(**graph_cfg)
-        # A = torch.tensor(graph.A, dtype=torch.float32, requires_grad=False)
+        graph = Graph(**graph_cfg)
+        # A.size(): (3, 25, 25)
+        A = torch.tensor(graph.A, dtype=torch.float32, requires_grad=False)
+
         # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
 
         self.embd_layer = nn.Linear(in_channels, hidden_dim)
@@ -314,7 +329,7 @@ def __init__(
                     TemporalPooling(
                         dim_in, dim_out, kernel_size1, stride1,
                         pooling=temporal_pooling, with_cls=use_cls),
-                    TransformerEncoderLayer(d_model=dim_out, nhead=num_heads,
+                    TransformerEncoderLayer(A=A, d_model=dim_out, nhead=num_heads,
                                             dim_feedforward=dim_out * mlp_ratio, dropout=dropout_rate,
                                             attention_dropout=attention_dropout, drop_path_rate=next(dpr_iter))
                 ]))
@@ -324,30 +339,30 @@
 
         self.init_weights()
 
-    def init_weights(self):
-        for p in self.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
-
     # def init_weights(self):
-    #     for m in self.modules():
-    #         if isinstance(m, nn.Conv2d):
-    #             if hasattr(m, 'weight'):
-    #                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
-    #             if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor):
-    #                 nn.init.constant_(m.bias, 0)
-    #         elif isinstance(m, nn.BatchNorm2d):
-    #             if hasattr(m, 'weight') and m.weight is not None:
-    #                 m.weight.data.normal_(1.0, 0.02)
-    #             if hasattr(m, 'bias') and m.bias is not None:
-    #                 m.bias.data.fill_(0)
-    #         elif isinstance(m, nn.Linear):
-    #             nn.init.trunc_normal_(m.weight, std=.02)
-    #             if isinstance(m, nn.Linear) and m.bias is not None:
-    #                 nn.init.constant_(m.bias, 0)
-    #         elif isinstance(m, nn.LayerNorm):
-    #             nn.init.constant_(m.bias, 0)
-    #             nn.init.constant_(m.weight, 1.0)
+    #     for p in self.parameters():
+    #         if p.dim() > 1:
+    #             nn.init.xavier_uniform_(p)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if hasattr(m, 'weight'):
+                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor):
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                if hasattr(m, 'weight') and m.weight is not None:
+                    m.weight.data.normal_(1.0, 0.02)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    m.bias.data.fill_(0)
+            elif isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
 
     def forward(self, x):
         N, M, T, V, C = x.size()
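
One property of the new formulation worth noting: because `alpha` starts at zero, the softmax attention map is gated out at initialization, so each block initially performs a fixed, adjacency-initialized aggregation through `outer`. A quick check with hypothetical sizes:

    import torch
    import torch.nn as nn

    # At init, alpha == 0, so (alpha * attn + outer) @ v reduces to outer @ v.
    B, heads, T, V, head_dim = 2, 8, 16, 25, 32   # hypothetical sizes
    N = T * V                                      # 400 joint-time tokens

    alpha = nn.Parameter(torch.zeros(1))
    outer = torch.rand(heads, N, N)                # as built from the tiled adjacency
    attn = torch.softmax(torch.rand(B, heads, N, N), dim=-1)
    v = torch.rand(B, heads, N, head_dim)

    x = (alpha * attn + outer) @ v                 # (B, heads, N, head_dim)
    assert torch.allclose(x, outer @ v)            # attention is gated out at init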