Skip to content

Commit

Permalink
Follow-up on out scaling alpha (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
maljoras authored Mar 11, 2022
1 parent 6b7f5ba commit d295cf7
Show file tree
Hide file tree
Showing 15 changed files with 281 additions and 362 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ The format is based on [Keep a Changelog], and this project adheres to
* `Fixed` for any bug fixes.
* `Security` in case of vulnerabilities.

## Unreleased

### Added

* ``set_weights`` can be used to re-apply the weight scaling omega. (\#360)
* Out scaling factors can be learnt even if weight scaling omega was set to 0. (\#360)

### Fixed

* Legacy checkpoint load with alpha scaling. (\#360)
* Re-application of weight scaling omega when loading checkpoints. (\#360)

### Changed

* The ``set_alpha_scale`` and ``get_alpha_scale`` methods of the C++ tiles are removed. (\#360)


## [0.5.1] - 2022/01/27

### Added
Expand Down
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ option(USE_CUDA "Build with CUDA support" OFF)
option(RPU_DEBUG "Enable debug printing" OFF)
option(RPU_USE_FASTMOD "Use fast mod" ON)
option(RPU_USE_FASTRAND "Use fastrand" OFF)
option(USE_ABI_ZERO "Whether to set _GLIBCXX_USE_CXX11_ABI=0" ON)

set(RPU_BLAS "OpenBLAS" CACHE STRING "BLAS backend of choice (OpenBLAS, MKL)")
set(RPU_CUDA_ARCHITECTURES "60" CACHE STRING "Target CUDA architectures")

Expand Down Expand Up @@ -49,7 +51,10 @@ else()
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -ftree-vectorize")
endif()

add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
if (USE_ABI_ZERO)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
endif()

if (APPLE)
string(APPEND CMAKE_CXX_FLAGS " -fvisibility=hidden")
endif()
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

.PHONY: build_inplace clean clean-doc clang-format mypy pycodestyle pylint pytest build_inplace_mkl build_inplace_cuda build_cuda
.PHONY: build_inplace clean clean-doc clang-format mypy pycodestyle pylint pytest build_inplace_mkl build_inplace_cuda build_cuda build_inplace_cuda_abi

build_inplace:
python setup.py build_ext -j8 -DCMAKE_BUILD_TYPE=Debug -DCMAKE_EXPORT_COMPILE_COMMANDS=TRUE --inplace ${flags}
Expand All @@ -30,6 +30,9 @@ build_cuda:
build_inplace_cuda:
make build_inplace_mkl flags="-DUSE_CUDA=ON ${flags}"

build_inplace_cuda_abi:
make build_inplace_mkl flags="-DUSE_CUDA=ON -DUSE_ABI_ZERO=OFF ${flags}"

clean:
python setup.py clean
rm -rf _skbuild
Expand Down
4 changes: 2 additions & 2 deletions examples/20_mnist_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def train(model, train_set):
# Decay learning rate if needed.
scheduler.step()

dist.all_gather(total_time, torch.tensor(time()-time_init).to(device))
dist.all_gather(total_time, torch.Tensor(time()-time_init).to(device))

if rank == 0:
avg_train_time = torch.mean(torch.cat(total_time, 0))
Expand Down Expand Up @@ -231,7 +231,7 @@ def test_evaluation(model, val_set):
total_images += labels.size(0)
predicted_ok += (predicted == labels).sum().item()

dist.all_gather(acc_list, torch.tensor(predicted_ok/total_images).to(device))
dist.all_gather(acc_list, torch.Tensor(predicted_ok/total_images).to(device))

if rank == 0:
acc = torch.mean(torch.cat(acc_list, 0))
Expand Down
4 changes: 4 additions & 0 deletions src/aihwkit/cloud/converter/v1/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def get_field_value_to_proto(self, source: Any, field: str, default: Any = None)
"""Get the value of a field."""
if field == 'bias':
return getattr(source, 'bias', None) is not None

if field == 'weight_scaling_omega':
return list(source.analog_tiles())[0].rpu_config.mapping.weight_scaling_omega

if field == 'rpu_config':
preset_cls = type(source.analog_tile.rpu_config)
try:
Expand Down
89 changes: 51 additions & 38 deletions src/aihwkit/nn/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@
# that they have been altered from the originals.

"""Base class for analog Modules."""
import warnings

from typing import (
Any, Dict, List, Optional, Tuple, NamedTuple, Union,
Generator, TYPE_CHECKING
)

from torch import Tensor, no_grad
from torch import Tensor, no_grad, ones, float32
from torch.nn import Module, Parameter

from aihwkit.exceptions import ModuleError
Expand Down Expand Up @@ -64,9 +62,6 @@ class AnalogModuleBase(Module):
bias: whether to use a bias row on the analog tile or not.
realistic_read_write: whether to enable realistic read/write
for setting initial weights and during reading of the weights.
weight_scaling_omega: the weight value that the current max
weight value will be scaled to. If zero, no weight scaling will
be performed.
mapping: Configuration of the hardware architecture (e.g. tile size).
"""
# pylint: disable=abstract-method, too-many-instance-attributes
Expand All @@ -81,7 +76,6 @@ def __init__(
out_features: int,
bias: bool,
realistic_read_write: bool = False,
weight_scaling_omega: Optional[float] = None,
mapping: Optional[MappingParameter] = None,
) -> None:
# pylint: disable=super-init-not-called
Expand All @@ -95,20 +89,6 @@ def __init__(
self.use_bias = bias
self.digital_bias = bias and mapping.digital_bias
self.analog_bias = bias and not mapping.digital_bias
self.weight_scaling_omega = mapping.weight_scaling_omega if weight_scaling_omega is None \
else weight_scaling_omega
if weight_scaling_omega is not None:
warnings.warn(DeprecationWarning('\nSetting the weight_scaling_omega through the '
'layers input parameters will be deprecated in the '
'future. Please set it through the MappingParameter '
'of the rpu_config.\n'))

self.weight_scaling_omega_columnwise = mapping.weight_scaling_omega_columnwise
self.learn_out_scaling_alpha = mapping.learn_out_scaling_alpha

if self.learn_out_scaling_alpha and self.weight_scaling_omega == 0:
raise ValueError('out_scaling_alpha can only be learned if weight_scaling_omega > 0')

self.realistic_read_write = realistic_read_write
self.in_features = in_features
self.out_features = out_features
Expand Down Expand Up @@ -143,8 +123,11 @@ def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) ->
if par_name not in self._registered_helper_parameter:
self._registered_helper_parameter.append(par_name)

if self.learn_out_scaling_alpha:
if not isinstance(tile.out_scaling_alpha, Parameter):
mapping = tile.rpu_config.mapping
if mapping.learn_out_scaling_alpha:
if tile.out_scaling_alpha is None:
tile.out_scaling_alpha = Parameter(ones([1], device=tile.device, dtype=float32))
elif not isinstance(tile.out_scaling_alpha, Parameter):
tile.out_scaling_alpha = Parameter(tile.out_scaling_alpha)
par_name = self.ANALOG_OUT_SCALING_ALPHA_PREFIX + str(self._analog_tile_counter)
self.register_parameter(par_name, tile.out_scaling_alpha)
Expand Down Expand Up @@ -219,7 +202,9 @@ def set_weights(
self,
weight: Tensor,
bias: Optional[Tensor] = None,
force_exact: bool = False
force_exact: bool = False,
remap_weights: bool = True,
weight_scaling_omega: float = None
) -> None:
"""Set the weight (and bias) values with given tensors.
Expand All @@ -242,6 +227,17 @@ def set_weights(
weight: weight matrix
bias: bias vector
force_exact: forces an exact write to the analog tiles
remap_weights: Whether to rescale the given weight matrix
and populate the digital output scaling factors as
specified in the configuration
:class:`~aihwkit.configs.utils.MappingParameter`. A
new ``weight_scaling_omega`` can be given. Note that
this will overwrite the existing digital out scaling
factors.
weight_scaling_omega: The weight scaling omega factor (see
:class:`~aihwkit.configs.utils.MappingParameter`). If
given explicitly here, it will overwrite the value in
the mapping field.
Raises:
ModuleError: in case of multiple defined analog tiles in the module
Expand All @@ -257,13 +253,15 @@ def set_weights(
raise ModuleError("AnalogModuleBase.set_weights only supports a single tile.")
analog_tile = analog_tiles[0]

if self.weight_scaling_omega > 0.0:
if remap_weights:
omega = weight_scaling_omega
if omega is None:
omega = analog_tile.rpu_config.mapping.weight_scaling_omega

analog_tile.set_weights_scaled(
weight, bias if self.analog_bias else None,
realistic=realistic,
omega=self.weight_scaling_omega,
weight_scaling_omega_columnwise=self.weight_scaling_omega_columnwise,
learn_out_scaling_alpha=self.learn_out_scaling_alpha)
weight_scaling_omega=omega)
else:
analog_tile.set_weights(weight, bias if self.analog_bias else None,
realistic=realistic)
Expand All @@ -276,7 +274,8 @@ def set_weights(

def get_weights(
self,
force_exact: bool = False
force_exact: bool = False,
apply_out_scales: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Get the weight (and bias) tensors.
Expand All @@ -294,30 +293,46 @@ def get_weights(
analog tile library, for performance reasons.
Args:
force_exact: forces an exact read to the analog tiles
force_exact: Forces an exact read to the analog tiles
apply_out_scales: Whether to return the weights with the
(digital) output scaling factors applied. Note that the
"logical" weights of the layer, which the DNN is
effectively using, are those with the output scales
applied. If ``apply_out_scales`` is set to False, then
only the weight values that are programmed onto the
crossbar array are returned, without applying the
digital scales.
Returns:
tuple: weight matrix, bias vector
Raises:
ModuleError: in case of multiple defined analog tiles in the module
"""
analog_tiles = list(self.analog_tiles())
if len(analog_tiles) != 1:
raise ModuleError("AnalogModuleBase.get_weights only supports a single tile.")
analog_tile = analog_tiles[0]

realistic = self.realistic_read_write and not force_exact
if self.weight_scaling_omega > 0.0:
weight, bias = analog_tile.get_weights_scaled(
realistic=realistic,
weight_scaling_omega_columnwise=self.weight_scaling_omega_columnwise)
if apply_out_scales:
weight, analog_bias = analog_tile.get_weights_scaled(realistic=realistic)
else:
weight, bias = analog_tile.get_weights(realistic=realistic)
weight, analog_bias = analog_tile.get_weights(realistic=realistic)

digital_bias = None
if self.digital_bias:
with no_grad():
bias = self.bias.data.detach().cpu()
digital_bias = self.bias.data.clone().detach().cpu()

if (digital_bias is not None) and (analog_bias is not None):
bias = digital_bias + analog_bias
elif digital_bias is not None:
bias = digital_bias
else:
bias = analog_bias
return weight, bias

def _sync_weights_from_tile(self) -> None:
Expand Down Expand Up @@ -491,8 +506,6 @@ def extra_repr(self) -> str:
output = super().extra_repr()
if self.realistic_read_write:
output += ', realistic_read_write={}'.format(self.realistic_read_write)
if self.weight_scaling_omega > 0:
output += ', weight_scaling_omega={:.3f}'.format(self.weight_scaling_omega)
if self.analog_bias:
output += ', analog bias'
if self.digital_bias:
Expand Down
Loading

0 comments on commit d295cf7

Please sign in to comment.