Skip to content

Commit

Permalink
add basic ST_ST
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianrui-Li committed Dec 9, 2023
1 parent 5fe32c6 commit e516057
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 1 deletion.
110 changes: 110 additions & 0 deletions configs/transformer/ST_ST.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import wandb
wandb.init(project='ViViT2')

model = dict(
type='RecognizerGCN',
backbone=dict(
type='ST_JT',
),
cls_head=dict(type='TRHead', num_classes=60, in_channels=512, dropout=0.))

dataset_type = 'PoseDataset'
ann_file = 'data/nturgbd/ntu60_3danno.pkl'

clip_len = 64
sample_rate = 3
mode = 'zero'
train_pipeline = [
dict(type='PreNormalize3D'),
dict(type='RandomScale', scale=0.1),
dict(type='RandomRot', theta=0.3),
dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
out_of_bound_opt='repeat_last', keep_tail_frames=True),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', num_person=2, mode=mode),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['keypoint'])
]
val_pipeline = [
dict(type='PreNormalize3D'),
dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
out_of_bound_opt='repeat_last', keep_tail_frames=True),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', num_person=2, mode=mode),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['keypoint'])
]
test_pipeline = [
dict(type='PreNormalize3D'),
dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
dict(type='SampleFrames', clip_len=clip_len, frame_interval=sample_rate,
out_of_bound_opt='repeat_last', keep_tail_frames=True, num_clips=10),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', num_person=2, mode=mode),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['keypoint'])
]
# mode = 'zero'
# train_pipeline = [
# dict(type='PreNormalize3D'),
# # dict(type='RandomRot', theta=0.2),
# dict(type='RandomScale', scale=0.1),
# dict(type='RandomRot'),
# dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
# dict(type='UniformSample', clip_len=clip_len),
# dict(type='PoseDecode'),
# dict(type='FormatGCNInput', num_person=2, mode=mode),
# dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
# dict(type='ToTensor', keys=['keypoint'])
# ]
# val_pipeline = [
# dict(type='PreNormalize3D'),
# dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
# dict(type='UniformSample', clip_len=clip_len, num_clips=1),
# dict(type='PoseDecode'),
# dict(type='FormatGCNInput', num_person=2, mode=mode),
# dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
# dict(type='ToTensor', keys=['keypoint'])
# ]
# test_pipeline = [
# dict(type='PreNormalize3D'),
# dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
# dict(type='UniformSample', clip_len=clip_len, num_clips=10),
# dict(type='PoseDecode'),
# dict(type='FormatGCNInput', num_person=2, mode=mode),
# dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
# dict(type='ToTensor', keys=['keypoint'])
# ]
data = dict(
videos_per_gpu=2,
workers_per_gpu=8,
test_dataloader=dict(videos_per_gpu=8),
train=dict(type=dataset_type, ann_file=ann_file, pipeline=train_pipeline, split='xsub_train'),
val=dict(type=dataset_type, ann_file=ann_file, pipeline=val_pipeline, split='xsub_val'),
test=dict(type=dataset_type, ann_file=ann_file, pipeline=test_pipeline, split='xsub_val'))

# optimizer
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.01)
optimizer_config = dict(grad_clip=dict(max_norm=3.0, norm_type=2))
# learning policy
# lr_config = dict(policy='CosineAnnealing', min_lr=0, by_epoch=False)
lr_config = dict(
policy='CosineAnnealing',
warmup='linear',
warmup_by_epoch=True,
warmup_iters=10,
warmup_ratio=1.0 / 100,
min_lr_ratio=1e-6)
total_epochs = 100
checkpoint_config = dict(interval=1)
evaluation = dict(interval=4, metrics=['top_k_accuracy'])
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook'), dict(type='WandbLoggerHook')])

# runtime settings
log_level = 'INFO'
work_dir = './work_dirs/lst/ntu60_xsub_3dkp/j_vanilla_variable_dim/12.8-ST_ST-basic1'
find_unused_parameters = False
auto_resume = False
seed = 88
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ mediapipe~=0.9.0.1
pyskl~=0.1.0
einops~=0.6.1
requests~=2.28.2
setuptools~=67.6.0
setuptools~=67.6.0
wandb~=0.16.1

0 comments on commit e516057

Please sign in to comment.