Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobo-y authored Mar 5, 2024
1 parent c73bcfe commit cbaa57a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# DP mode
if opt.sparsity:
ASP.init_model_for_pruning(model, mask_calculator='m4n2_1d', verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2D],
allow_recompute_mask=False, disallowed_layer_names=opt.sparsity_ignore_names, allow_permutation=False)
allow_recompute_mask=True, disallowed_layer_names=opt.sparsity_ignore_names, allow_permutation=False)
ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks()
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
Expand Down

0 comments on commit cbaa57a

Please sign in to comment.