Skip to content

Commit

Permalink
[sparsity] Fix for accumulation bug in WeightNormSparsifier (pytorch#…
Browse files Browse the repository at this point in the history
…65293)

Summary:
Pull Request resolved: pytorch#65293

This fixes a bug in the WeightNormSparsifier, where the mask is being multiplied by the newly computed mask.
Because the mask elements are binary 0/1, this accumulates the mask over every iteration, eventually collapsing the mask to zero.
This bug accidentally bled through from old versions.

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D31186829

Pulled By: z-a-f

fbshipit-source-id: 3f5b2c833148ab0bd8084e7410ce398f1252e65e
  • Loading branch information
z-a-f authored and facebook-github-bot committed Sep 28, 2021
1 parent a90912e commit 92ee5cc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions test/ao/sparsity/test_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ def test_step(self):
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
# Test if the mask collapses to all zeros if the weights are randomized
iters_before_collapse = 1000
for _ in range(iters_before_collapse):
model.linear.weight.data = torch.randn(model.linear.weight.shape)
sparsifier.step()
for g in sparsifier.module_groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse

def test_prepare(self):
model = Model()
Expand Down
2 changes: 1 addition & 1 deletion torch/ao/sparsity/sparsifier/weight_norm_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def update_mask(self, layer, sparsity_level, sparse_block_shape,
for row, col in zip(rows, cols):
new_mask[row:row + sparse_block_shape[0],
col:col + sparse_block_shape[1]] = 0
mask.data *= new_mask
mask.data = new_mask

0 comments on commit 92ee5cc

Please sign in to comment.