Skip to content

Commit

Permalink
add Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianrui-Li committed Aug 22, 2023
1 parent ca308e7 commit 1c76cef
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
17 changes: 10 additions & 7 deletions configs/transformer/j2sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,43 @@
type='ViViT2',
graph_cfg=dict(layout='nturgb+d', mode='spatial'),
max_position_embeddings_1=26, # 25*40+1=1001
max_position_embeddings_2=121,
# dropout=0.1,
max_position_embeddings_2=101,
# dropout=0.1
dim=256,
),
cls_head=dict(type='vit2Head', num_classes=60, in_channels=256))

dataset_type = 'PoseDataset'
ann_file = 'data/nturgbd/ntu60_3danno.pkl'
clip_len = 120
clip_len = 100
train_pipeline = [
dict(type='PreNormalize3D'),
dict(type='STTSample', clip_len=clip_len, p_interval=(0.5, 1)),
dict(type='STTSample', clip_len=120, p_interval=(0.5, 1)),
dict(type='RandomScale', scale=0.1),
dict(type='RandomRot'),
# dict(type='RandomRot', theta=0.2),
dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
dict(type='UniformSample', clip_len=clip_len),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', num_person=2),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['keypoint'])
]
val_pipeline = [
dict(type='PreNormalize3D'),
dict(type='STTSample', clip_len=120, p_interval=(0.5, 1)),
dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
dict(type='STTSample', clip_len=clip_len, p_interval=(0.5, 1)),
dict(type='UniformSample', clip_len=clip_len, num_clips=1),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', num_person=2),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['keypoint'])
]
test_pipeline = [
dict(type='PreNormalize3D'),
dict(type='STTSample', clip_len=120, p_interval=(0.5, 1)),
dict(type='GenSkeFeat', dataset='nturgb+d', feats=['j']),
dict(type='STTSample', clip_len=clip_len, p_interval=(0.5, 1)),
dict(type='UniformSample', clip_len=clip_len, num_clips=10),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', num_person=2),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
Expand Down Expand Up @@ -71,7 +74,7 @@

# runtime settings
log_level = 'INFO'
work_dir = './work_dirs/transformer/j2/8.21-tm2-3'
work_dir = './work_dirs/transformer/j2/8.21-tm2-4'

auto_resume = False
seed = 88
4 changes: 1 addition & 3 deletions pyskl/datasets/pipelines/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def __call__(self, results):
data_numpy = data_numpy.transpose(3, 1, 2, 0)
results['keypoint'] = data_numpy
results['frame_inds'] = self.inds.astype(int)
results['clip_len'] = self.clip_len
results['frame_interval'] = None
results['num_clips'] = self.num_clips
results['total_frames'] = self.clip_len
return results

def __repr__(self):
Expand Down
4 changes: 2 additions & 2 deletions tools/dist_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

export MASTER_PORT=$((12000 + $RANDOM % 20000))
# for server run
# export CUDA_VISIBLE_DEVICES=1,4,6
export CUDA_VISIBLE_DEVICES=4,5,6
export CUDA_VISIBLE_DEVICES=0,2,3
# export CUDA_VISIBLE_DEVICES=4,5,6
# export CUDA_VISIBLE_DEVICES=5,6

set -x
Expand Down

0 comments on commit 1c76cef

Please sign in to comment.