[sparsity] Base sparsity level scheduler class (pytorch#59770)
Summary:
Pull Request resolved: pytorch#59770

Implements the base scheduler class for changing the sparsity levels in the sparsifier.
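
A minimal usage sketch assembled from the tests added in this PR (the `ImplementedScheduler` subclass and the tiny model are illustrative, not part of the API):
```
import torch.nn as nn
from torch.ao.sparsity import BaseScheduler, WeightNormSparsifier

# Illustrative subclass (mirrors the one in the new tests): halve the
# sparsity level of every module group on each epoch after the first.
class ImplementedScheduler(BaseScheduler):
    def get_sl(self):
        if self.last_epoch > 0:
            return [group['sparsity_level'] * 0.5
                    for group in self.sparsifier.module_groups]
        else:
            return list(self.base_sl)

model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)

# The sparsifier must step before the scheduler; calling scheduler.step()
# first raises a UserWarning (see test_order_of_steps below).
sparsifier.step()
scheduler.step()  # sparsity_level: 0.5 -> 0.25
```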

Test Plan:
```
python test/test_ao_sparsity.py
```
Imported from OSS

Differential Revision: D29070603

Reviewed By: raghuramank100

Pulled By: z-a-f

fbshipit-source-id: 0b160e4eb0a2a303d2d19e6a3beb4784002b2cb7
z-a-f authored and facebook-github-bot committed Jul 3, 2021
1 parent ed63fb5 commit 37ebf2e
Showing 5 changed files with 225 additions and 0 deletions.
70 changes: 70 additions & 0 deletions test/ao/sparsity/test_scheduler.py
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
from torch import nn
from torch.ao.sparsity import WeightNormSparsifier
from torch.ao.sparsity import BaseScheduler

from torch.testing._internal.common_utils import TestCase

import warnings

class ImplementedScheduler(BaseScheduler):
def get_sl(self):
if self.last_epoch > 0:
return [group['sparsity_level'] * 0.5
for group in self.sparsifier.module_groups]
else:
return list(self.base_sl)


class TestBaseScheduler(TestCase):
def test_constructor(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)

assert scheduler.sparsifier is sparsifier
assert scheduler._step_count == 1
assert scheduler.base_sl == [sparsifier.module_groups[0]['sparsity_level']]

def test_order_of_steps(self):
"""Checks if the warning is thrown if the scheduler step is called
before the sparsifier step"""

model = nn.Sequential(
nn.Linear(16, 16)
)
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)

# Sparsifier step is not called
with self.assertWarns(UserWarning):
scheduler.step()

# Correct order has no warnings
# Note: This will trigger if other warnings are present.
with warnings.catch_warnings(record=True) as w:
sparsifier.step()
scheduler.step()
# Make sure there is no warning related to the base_scheduler
for warning in w:
fname = warning.filename
fname = '/'.join(fname.split('/')[-5:])
assert fname != 'torch/ao/sparsity/scheduler/base_scheduler.py'

def test_step(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.module_groups[0]['sparsity_level'] == 0.5
scheduler = ImplementedScheduler(sparsifier)
assert sparsifier.module_groups[0]['sparsity_level'] == 0.5

sparsifier.step()
scheduler.step()
assert sparsifier.module_groups[0]['sparsity_level'] == 0.25
3 changes: 3 additions & 0 deletions test/test_ao_sparsity.py
@@ -13,5 +13,8 @@
from ao.sparsity.test_sparsifier import TestBaseSparsifier # noqa: F401
from ao.sparsity.test_sparsifier import TestWeightNormSparsifier # noqa: F401

# Scheduler
from ao.sparsity.test_scheduler import TestBaseScheduler # noqa: F401

if __name__ == '__main__':
run_tests()
2 changes: 2 additions & 0 deletions torch/ao/sparsity/__init__.py
@@ -5,6 +5,8 @@
# Sparsifier
from .sparsifier.base_sparsifier import BaseSparsifier
from .sparsifier.weight_norm_sparsifier import WeightNormSparsifier
# Scheduler
from .scheduler.base_scheduler import BaseScheduler

# Parametrizations
from .sparsifier.utils import FakeSparsity
0 changes: 0 additions & 0 deletions torch/ao/sparsity/scheduler/__init__.py
Empty file.
150 changes: 150 additions & 0 deletions torch/ao/sparsity/scheduler/base_scheduler.py
@@ -0,0 +1,150 @@

from torch.ao.sparsity import BaseSparsifier

from functools import wraps
import warnings
import weakref

class BaseScheduler(object):

def __init__(self, sparsifier, last_epoch=-1, verbose=False):

# Attach sparsifier
if not isinstance(sparsifier, BaseSparsifier):
raise TypeError('{} is not an instance of torch.ao.sparsity.BaseSparsifier'.format(
type(sparsifier).__name__))
self.sparsifier = sparsifier

# Initialize epoch and base sparsity levels

self.base_sl = [group['sparsity_level'] for group in sparsifier.module_groups]
self.last_epoch = last_epoch

# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `scheduler.step()` is called after
# `sparsifier.step()`
def with_counter(method):
if getattr(method, '_with_counter', False):
# `sparsifier.step()` has already been replaced, return.
return method

# Keep a weak reference to the sparsifier instance to prevent
# cyclic references.
instance_ref = weakref.ref(method.__self__)
# Get the unbound method for the same purpose.
func = method.__func__
cls = instance_ref().__class__
del method

@wraps(func)
def wrapper(*args, **kwargs):
instance = instance_ref()
instance._step_count += 1 # type: ignore[union-attr]
wrapped = func.__get__(instance, cls)
return wrapped(*args, **kwargs)

# Note that the returned function here is no longer a bound method,
# so attributes like `__func__` and `__self__` no longer exist.
wrapper._with_counter = True # type: ignore[attr-defined]
return wrapper

self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment]
self.sparsifier._step_count = 0 # type: ignore[attr-defined]
self._step_count: int = 0
self.verbose = verbose

# Housekeeping
self._get_sl_called_within_step: bool = False

self.step()

def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the sparsifier.
"""
return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'}

def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)

def get_last_sl(self):
""" Return last computed sparsity level by current scheduler.
"""
return self._last_sl

def get_sl(self):
# Compute sparsity level using chainable form of the scheduler
# Note: This method is not intended to be called directly, and is only
# used by the ".step" method. Use .get_last_sl() instead.
if not self._get_sl_called_within_step:
warnings.warn(
"To get the last sparsity level computed by the scheduler, "
"please use `get_last_sl()`.")
raise NotImplementedError

def print_sl(self, is_verbose, group, sl, epoch=None):
"""Display the current sparsity level.
"""
if is_verbose:
if epoch is None:
print('Adjusting sparsity level'
' of group {} to {:.4e}.'.format(group, sl))
else:
print('Epoch {:5d}: adjusting sparsity level'
' of group {} to {:.4e}.'.format(epoch, group, sl))

def __repr__(self):
format_string = self.__class__.__name__ + ' ('
format_string += '\n'
format_string += 'Sparsifier {0}\n'.format(self.sparsifier)
format_string += ' {0}: {1}\n'.format('base_sl', self.base_sl)
format_string += ')'
return format_string

def step(self, epoch=None):
# Raise warning if trying to call scheduler step before the sparsifier.
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.sparsifier.step, "_with_counter"):
warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
"initialization. Please, make sure to call `sparsifier.step()` before "
"`scheduler.step()`.", UserWarning)

            # Warn if scheduler.step() has been called more than once before any sparsifier.step()
elif self.sparsifier._step_count < 1: # type: ignore[attr-defined]
warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "
"You have to make sure you run the sparsifier.step() BEFORE any "
"calls to the scheduer.step().", UserWarning)
self._step_count += 1

class _enable_get_sl_call:
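            # Sets a flag while step() computes new sparsity levels, so that
            # get_sl() can warn when it is called directly by the user.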

def __init__(self, o):
self.o = o

def __enter__(self):
self.o._get_sl_called_within_step = True
return self

def __exit__(self, type, value, traceback):
self.o._get_sl_called_within_step = False

with _enable_get_sl_call(self):
self.last_epoch += 1
values = self.get_sl()

for i, data in enumerate(zip(self.sparsifier.module_groups, values)):
param_group, sl = data
param_group['sparsity_level'] = sl
self.print_sl(self.verbose, i, sl, epoch)

self._last_sl = [group['sparsity_level'] for group in self.sparsifier.module_groups]
self.sparsifier.enable_mask_update = True
