diff --git a/configs/transformer/ST_JT.py b/configs/transformer/ST_JT.py
index 6e54004..b003469 100644
--- a/configs/transformer/ST_JT.py
+++ b/configs/transformer/ST_JT.py
@@ -14,58 +14,23 @@
         depth=8,
         # stride1=3,
         # kernel_size1=5,
-        # graph_cfg=dict(layout='nturgb+d', mode='spatial'),
+        graph_cfg=dict(layout='nturgb+d', mode='spatial'),
     ),
     cls_head=dict(type='TRHead', num_classes=60, in_channels=512, dropout=0.))
 
 dataset_type = 'PoseDataset'
 ann_file = 'data/nturgbd/ntu60_3danno.pkl'
-# clip_len = 32
-# sample_rate = 3
-# mode = 'zero'
-# train_pipeline = [
-#     dict(type='PreNormalize3D'),
-#     dict(type='RandomScale', scale=0.1),
-#     dict(type='RandomRot', theta=0.3),
-#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-#     dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
-#          out_of_bound_opt='repeat_last', keep_tail_frames=True),
-#     dict(type='PoseDecode'),
-#     dict(type='FormatGCNInput', num_person=2, mode=mode),
-#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
-#     dict(type='ToTensor', keys=['keypoint'])
-# ]
-# val_pipeline = [
-#     dict(type='PreNormalize3D'),
-#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-#     dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
-#          out_of_bound_opt='repeat_last', keep_tail_frames=True),
-#     dict(type='PoseDecode'),
-#     dict(type='FormatGCNInput', num_person=2, mode=mode),
-#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
-#     dict(type='ToTensor', keys=['keypoint'])
-# ]
-# test_pipeline = [
-#     dict(type='PreNormalize3D'),
-#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-#     dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
-#          out_of_bound_opt='repeat_last', keep_tail_frames=True, num_clips=10),
-#     dict(type='PoseDecode'),
-#     dict(type='FormatGCNInput', num_person=2, mode=mode),
-#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
-#     dict(type='ToTensor', keys=['keypoint'])
-# ]
-
-clip_len = 64
+clip_len = 32
+sample_rate = 3
 mode = 'zero'
 train_pipeline = [
     dict(type='PreNormalize3D'),
-    # dict(type='RandomRot', theta=0.2),
     dict(type='RandomScale', scale=0.1),
-    dict(type='RandomRot'),
+    dict(type='RandomRot', theta=0.3),
     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-    dict(type='UniformSample', clip_len=clip_len),
+    dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
+         out_of_bound_opt='repeat_last', keep_tail_frames=True),
     dict(type='PoseDecode'),
     dict(type='FormatGCNInput', num_person=2, mode=mode),
     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
@@ -74,7 +39,8 @@
 val_pipeline = [
     dict(type='PreNormalize3D'),
     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-    dict(type='UniformSample', clip_len=clip_len, num_clips=1),
+    dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
+         out_of_bound_opt='repeat_last', keep_tail_frames=True),
     dict(type='PoseDecode'),
     dict(type='FormatGCNInput', num_person=2, mode=mode),
     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
@@ -83,14 +49,48 @@
 test_pipeline = [
     dict(type='PreNormalize3D'),
     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
-    dict(type='UniformSample', clip_len=clip_len, num_clips=10),
+    dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
+         out_of_bound_opt='repeat_last', keep_tail_frames=True, num_clips=10),
     dict(type='PoseDecode'),
     dict(type='FormatGCNInput', num_person=2, mode=mode),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
     dict(type='ToTensor', keys=['keypoint'])
 ]
+
+# clip_len = 64
+# mode = 'zero'
+# train_pipeline = [
+#     dict(type='PreNormalize3D'),
+#     # dict(type='RandomRot', theta=0.2),
+#     dict(type='RandomScale', scale=0.1),
+#     dict(type='RandomRot'),
+#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
+#     dict(type='UniformSample', clip_len=clip_len),
+#     dict(type='PoseDecode'),
+#     dict(type='FormatGCNInput', num_person=2, mode=mode),
+#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
+#     dict(type='ToTensor', keys=['keypoint'])
+# ]
+# val_pipeline = [
+#     dict(type='PreNormalize3D'),
+#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
+#     dict(type='UniformSample', clip_len=clip_len, num_clips=1),
+#     dict(type='PoseDecode'),
+#     dict(type='FormatGCNInput', num_person=2, mode=mode),
+#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
+#     dict(type='ToTensor', keys=['keypoint'])
+# ]
+# test_pipeline = [
+#     dict(type='PreNormalize3D'),
+#     dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
+#     dict(type='UniformSample', clip_len=clip_len, num_clips=10),
+#     dict(type='PoseDecode'),
+#     dict(type='FormatGCNInput', num_person=2, mode=mode),
+#     dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
+#     dict(type='ToTensor', keys=['keypoint'])
+# ]
 data = dict(
-    videos_per_gpu=8,
+    videos_per_gpu=32,
     workers_per_gpu=8,
     test_dataloader=dict(videos_per_gpu=8),
     train=dict(type=dataset_type, ann_file=ann_file, pipeline=train_pipeline, split='xsub_train'),
@@ -121,7 +121,7 @@
 # runtime settings
 log_level = 'INFO'
-work_dir = './work_dirs/lst/ntu60_xsub_3dkp/j_vanilla_variable_dim/1.9_ST_JT_87config_1'
+work_dir = './work_dirs/lst/ntu60_xsub_3dkp/j_vanilla_variable_dim/1.9_ST_JT_new_attention_1'
 find_unused_parameters = False
 auto_resume = False
 seed = 88
diff --git a/pyskl/models/transformer/ST_JT.py b/pyskl/models/transformer/ST_JT.py
index 9eb4820..8b42565 100644
--- a/pyskl/models/transformer/ST_JT.py
+++ b/pyskl/models/transformer/ST_JT.py
@@ -5,7 +5,7 @@
 from ..builder import BACKBONES
 from ..gcns import unit_tcn
 import torch.nn.functional as F
-# from ...utils import Graph
+from ...utils import Graph
 
 
 class DynamicPosBias(nn.Module):
@@ -60,33 +60,33 @@ class Attention(nn.Module):
         attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
         proj_drop (float, optional): Dropout ratio of output. Default: 0.0
     """
 
-    def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+    def __init__(self, A, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                  position_bias=True):
         super().__init__()
         self.dim = dim
-        # # changes with dim
-        # if dim == 64:
-        #     self.group_size = (64, 25)
-        #     self.num_heads = 2
-        # elif dim == 128:
-        #     self.group_size = (32, 25)
-        #     self.num_heads = 4
-        # elif dim == 256:
-        #     self.group_size = (16, 25)
-        #     self.num_heads = 8
-        # elif dim == 512:
-        #     self.group_size = (8, 25)
-        #     self.num_heads = 16
-
-        if dim == 128:
+        # changes with dim
+        if dim == 64:
             self.group_size = (64, 25)
-        elif dim == 256:
+            self.num_heads = 2
+        elif dim == 128:
             self.group_size = (32, 25)
-        elif dim == 512:
+            self.num_heads = 4
+        elif dim == 256:
             self.group_size = (16, 25)
+            self.num_heads = 8
+        elif dim == 512:
+            self.group_size = (8, 25)
+            self.num_heads = 16
 
-        self.num_heads = num_heads
+        # if dim == 128:
+        #     self.group_size = (64, 25)
+        # elif dim == 256:
+        #     self.group_size = (32, 25)
+        # elif dim == 512:
+        #     self.group_size = (16, 25)
+        #
+        # self.num_heads = num_heads
 
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim ** -0.5
@@ -122,6 +122,16 @@ def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
         self.softmax = nn.Softmax(dim=-1)
 
+        self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+        # A = A.sum(0)
+        # A[:, :] = 0
+        # self.outer = nn.Parameter(torch.stack([torch.eye(A.shape[-1]) for _ in range(self.num_heads)], dim=0),
+        #                           requires_grad=True)
+
+        A = A.sum(0)
+        self.outer = nn.Parameter(A.unsqueeze(0).expand(self.num_heads, -1, -1).repeat(1, self.group_size[0], self.group_size[0]), requires_grad=True)
+
     def forward(self, x, mask=None):
         """
         Args:
@@ -153,7 +163,10 @@ def forward(self, x, mask=None):
         attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        # x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+
+        x = ((self.alpha * attn + self.outer) @ v).transpose(1, 2).reshape(B_, N, C)
+
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -183,12 +196,12 @@ class TransformerEncoderLayer(nn.Module):
     Inspired by torch.nn.TransformerEncoderLayer and rwightman's timm package.
""" - def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + def __init__(self, A, d_model, nhead, dim_feedforward=2048, dropout=0.1, attention_dropout=0.1, drop_path_rate=0.1): super().__init__() self.pre_norm = nn.LayerNorm(d_model) - self.self_attn = Attention(dim=d_model, num_heads=nhead, + self.self_attn = Attention(A=A, dim=d_model, num_heads=nhead, attn_drop=attention_dropout, proj_drop=dropout) self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout1 = nn.Dropout(dropout) @@ -253,7 +266,7 @@ class ST_JT(nn.Module): """ def __init__( self, - # graph_cfg, + graph_cfg, in_channels=3, hidden_dim=64, dim_mul_layers=(4, 7), @@ -275,8 +288,10 @@ def __init__( super().__init__() # # Batch_normalization - # graph = Graph(**graph_cfg) - # A = torch.tensor(graph.A, dtype=torch.float32, requires_grad=False) + graph = Graph(**graph_cfg) + # A.size() 3,25,25 + A = torch.tensor(graph.A, dtype=torch.float32, requires_grad=False) + # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) self.embd_layer = nn.Linear(in_channels, hidden_dim) @@ -314,7 +329,7 @@ def __init__( TemporalPooling( dim_in, dim_out, kernel_size1, stride1, pooling=temporal_pooling, with_cls=use_cls), - TransformerEncoderLayer(d_model=dim_out, nhead=num_heads, + TransformerEncoderLayer(A=A, d_model=dim_out, nhead=num_heads, dim_feedforward=dim_out * mlp_ratio, dropout=dropout_rate, attention_dropout=attention_dropout, drop_path_rate=next(dpr_iter)) ])) @@ -324,30 +339,30 @@ def __init__( self.init_weights() - def init_weights(self): - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - # def init_weights(self): - # for m in self.modules(): - # if isinstance(m, nn.Conv2d): - # if hasattr(m, 'weight'): - # nn.init.kaiming_normal_(m.weight, mode='fan_out') - # if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): - # nn.init.constant_(m.bias, 0) - # elif isinstance(m, nn.BatchNorm2d): - # if hasattr(m, 'weight') and m.weight is not None: - # m.weight.data.normal_(1.0, 0.02) - # if hasattr(m, 'bias') and m.bias is not None: - # m.bias.data.fill_(0) - # elif isinstance(m, nn.Linear): - # nn.init.trunc_normal_(m.weight, std=.02) - # if isinstance(m, nn.Linear) and m.bias is not None: - # nn.init.constant_(m.bias, 0) - # elif isinstance(m, nn.LayerNorm): - # nn.init.constant_(m.bias, 0) - # nn.init.constant_(m.weight, 1.0) + # for p in self.parameters(): + # if p.dim() > 1: + # nn.init.xavier_uniform_(p) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + if hasattr(m, 'weight') and m.weight is not None: + m.weight.data.normal_(1.0, 0.02) + if hasattr(m, 'bias') and m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) def forward(self, x): N, M, T, V, C = x.size()