Skip to content

Commit

Permalink
[sparsity] Fix GPU training for sparsity (pytorch#66412)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#66412

The GPU training was not supported in the sparsifier.
The reason was that when the sparsifier was created the masks would default to the CPU.
Attaching a GPU model to the sparsifier would throw an error.
The solution is to create the masks on the same device as the weight.

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D31590675

Pulled By: z-a-f

fbshipit-source-id: 98c2c1cedc7c60aecea4076e5254ef6b3443139e
  • Loading branch information
z-a-f authored and facebook-github-bot committed Nov 23, 2021
1 parent 0b06741 commit e7d8f09
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch/ao/sparsity/sparsifier/base_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _prepare(self, *args, **kwargs):
for config in self.module_groups:
module = config['module']
param = config.get('parametrization', FakeSparsity)
mask = config.get('mask', torch.ones(module.weight.shape))
mask = config.get('mask', torch.ones_like(module.weight))
self.state[config['fqn']]['mask'] = mask
parametrize.register_parametrization(module, 'weight', param(mask))

Expand Down

0 comments on commit e7d8f09

Please sign in to comment.