[sparsity] Base sparsity level scheduler class (pytorch#59770)
Summary:
Pull Request resolved: pytorch#59770

Implements the base scheduler class for changing the sparsity levels in the sparsifier.

Test Plan:
```
python test/test_ao_sparsity.py
```

Imported from OSS

Differential Revision: D29070603

Reviewed By: raghuramank100

Pulled By: z-a-f

fbshipit-source-id: 0b160e4eb0a2a303d2d19e6a3beb4784002b2cb7
1 parent ed63fb5 · commit 37ebf2e · showing 5 changed files with 225 additions and 0 deletions.
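To illustrate how the new class is meant to be used, here is a minimal sketch of a concrete scheduler built on `BaseScheduler`. The `StepSL` subclass and its decay rule are hypothetical illustrations; `BaseScheduler`, `WeightNormSparsifier`, `prepare()`, and the `sparsifier.step()`-before-`scheduler.step()` ordering all come from the code in this commit.

```python
from torch import nn
from torch.ao.sparsity import BaseScheduler, WeightNormSparsifier

class StepSL(BaseScheduler):
    """Hypothetical scheduler: multiply the sparsity level by `gamma`
    every `step_size` epochs (mirrors torch.optim.lr_scheduler.StepLR)."""
    def __init__(self, sparsifier, step_size, gamma=0.5, last_epoch=-1, verbose=False):
        # Subclass state must be set before super().__init__(),
        # because the base constructor already calls self.step().
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        # Chainable form: return the unchanged levels except on epochs
        # that are a multiple of `step_size`.
        if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
            return [group['sparsity_level'] for group in self.sparsifier.module_groups]
        return [group['sparsity_level'] * self.gamma
                for group in self.sparsifier.module_groups]

model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = StepSL(sparsifier, step_size=2)

for epoch in range(4):
    # ... one epoch of training ...
    sparsifier.step()  # update the masks first
    scheduler.step()   # then adjust the sparsity level for the next epoch
```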
@@ -0,0 +1,70 @@
```python
# -*- coding: utf-8 -*-
from torch import nn
from torch.ao.sparsity import WeightNormSparsifier
from torch.ao.sparsity import BaseScheduler

from torch.testing._internal.common_utils import TestCase

import warnings


class ImplementedScheduler(BaseScheduler):
    def get_sl(self):
        if self.last_epoch > 0:
            return [group['sparsity_level'] * 0.5
                    for group in self.sparsifier.module_groups]
        else:
            return list(self.base_sl)


class TestBaseScheduler(TestCase):
    def test_constructor(self):
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        assert scheduler.sparsifier is sparsifier
        assert scheduler._step_count == 1
        assert scheduler.base_sl == [sparsifier.module_groups[0]['sparsity_level']]

    def test_order_of_steps(self):
        """Checks if the warning is thrown if the scheduler step is called
        before the sparsifier step"""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: This will trigger if other warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
            # Make sure there is no warning related to the base_scheduler
            for warning in w:
                fname = warning.filename
                fname = '/'.join(fname.split('/')[-5:])
                assert fname != 'torch/ao/sparsity/scheduler/base_scheduler.py'

    def test_step(self):
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.module_groups[0]['sparsity_level'] == 0.5
        scheduler = ImplementedScheduler(sparsifier)
        assert sparsifier.module_groups[0]['sparsity_level'] == 0.5

        sparsifier.step()
        scheduler.step()
        assert sparsifier.module_groups[0]['sparsity_level'] == 0.25
```
Empty file.
@@ -0,0 +1,150 @@
```python
from torch.ao.sparsity import BaseSparsifier

from functools import wraps
import warnings
import weakref


class BaseScheduler(object):

    def __init__(self, sparsifier, last_epoch=-1, verbose=False):

        # Attach sparsifier
        if not isinstance(sparsifier, BaseSparsifier):
            raise TypeError('{} is not an instance of torch.ao.sparsity.BaseSparsifier'.format(
                type(sparsifier).__name__))
        self.sparsifier = sparsifier

        # Initialize epoch and base sparsity levels
        self.base_sl = [group['sparsity_level'] for group in sparsifier.module_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `sparsifier.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper

        self.sparsifier.step = with_counter(self.sparsifier.step)  # type: ignore[assignment]
        self.sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping
        self._get_sl_called_within_step: bool = False

        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sparsifier.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_sl(self):
        """Return the last sparsity level computed by the current scheduler.
        """
        return self._last_sl

    def get_sl(self):
        # Compute sparsity level using chainable form of the scheduler
        # Note: This method is not intended to be called directly, and is only
        # used by the ".step" method. Use .get_last_sl() instead.
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.")
        raise NotImplementedError

    def print_sl(self, is_verbose, group, sl, epoch=None):
        """Display the current sparsity level.
        """
        if is_verbose:
            if epoch is None:
                print('Adjusting sparsity level'
                      ' of group {} to {:.4e}.'.format(group, sl))
            else:
                print('Epoch {:5d}: adjusting sparsity level'
                      ' of group {} to {:.4e}.'.format(epoch, group, sl))

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        format_string += '\n'
        format_string += 'Sparsifier {0}\n'.format(self.sparsifier)
        format_string += '    {0}: {1}\n'.format('base_sl', self.base_sl)
        format_string += ')'
        return format_string

    def step(self, epoch=None):
        # Raise a warning if trying to call scheduler step before the sparsifier.
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.sparsifier.step, "_with_counter"):
                warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
                              "initialization. Please, make sure to call `sparsifier.step()` before "
                              "`scheduler.step()`.", UserWarning)

            # Just check if there were two first scheduler.step() calls before sparsifier.step()
            elif self.sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "
                              "You have to make sure you run the sparsifier.step() BEFORE any "
                              "calls to the scheduler.step().", UserWarning)
        self._step_count += 1

        class _enable_get_sl_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_sl_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_sl_called_within_step = False

        with _enable_get_sl_call(self):
            self.last_epoch += 1
            values = self.get_sl()

        for i, data in enumerate(zip(self.sparsifier.module_groups, values)):
            param_group, sl = data
            param_group['sparsity_level'] = sl
            self.print_sl(self.verbose, i, sl, epoch)

        self._last_sl = [group['sparsity_level'] for group in self.sparsifier.module_groups]
        self.sparsifier.enable_mask_update = True
```
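Two implementation details are worth calling out. The `with_counter` wrapper is the same weak-reference pattern `torch.optim.lr_scheduler` uses to warn about a wrong `step()` ordering without creating a reference cycle between scheduler and sparsifier. And `state_dict()` deliberately excludes the `sparsifier` entry, so scheduler state can be checkpointed and restored independently. A minimal sketch of that round trip, reusing the `ImplementedScheduler` class from the test above:

```python
# Save everything except the sparsifier reference ...
state = scheduler.state_dict()

# ... and later restore it into a freshly constructed scheduler
# attached to a (possibly re-created) sparsifier.
restored = ImplementedScheduler(sparsifier)
restored.load_state_dict(state)
assert restored.get_last_sl() == scheduler.get_last_sl()
```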