Skip to content

Commit

Permalink
fix gaitedge config and add check
Browse files Browse the repository at this point in the history
for num of transform
  • Loading branch information
darkliang committed Nov 12, 2022
1 parent a71444b commit 1588fde
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions configs/gaitedge/phase2_e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ trainer_cfg:
scheduler_reset: true
sync_BN: true
restore_hint:
- /home/leeeung/workspace/OpenGait/output/CASIA-B_new/Segmentation/Segmentation/checkpoints/Segmentation-25000.pt
- /home/leeeung/OpenGait/output/CASIA-B_new/GaitGL/GaitGL/checkpoints/GaitGL-80000.pt
- Segmentation-25000.pt
- GaitGL-80000.pt
save_iter: 2000
save_name: GaitGL_E2E
total_iter: 20000
Expand Down
4 changes: 2 additions & 2 deletions configs/gaitedge/phase2_gaitedge.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ trainer_cfg:
optimizer_reset: true
scheduler_reset: true
sync_BN: true
restore_hint:
- Segmentation-30000.pt
restore_hint:
- Segmentation-25000.pt
- GaitGL-80000.pt
save_iter: 2000
save_name: GaitEdge
Expand Down
4 changes: 3 additions & 1 deletion opengait/modeling/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ def inputs_pretreament(self, inputs):
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
trf_cfgs = self.engine_cfg['transform']
seq_trfs = get_transform(trf_cfgs)

if len(seqs_batch) != len(seq_trfs):
raise ValueError(
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
requires_grad = bool(self.training)
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
for trf, seq in zip(seq_trfs, seqs_batch)]
Expand Down
2 changes: 1 addition & 1 deletion opengait/modeling/models/gaitedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def build_network(self, model_cfg):
self.is_edge = model_cfg['edge']
self.seg_lr = model_cfg['seg_lr']
self.kernel = torch.ones(
(model_cfg['kernel_size'], model_cfg['kernel_size'])).cuda()
(model_cfg['kernel_size'], model_cfg['kernel_size']))

def finetune_parameters(self):
fine_tune_params = list()
Expand Down

0 comments on commit 1588fde

Please sign in to comment.