diff --git a/README.md b/README.md index 35e8a82..71a59eb 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ To install PyTorch, please refer to https://github.com/pytorch/pytorch#installat To install the package containing the iABN layers: ```bash -pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12 +pip install inplace_abn ``` Note that some parts of InPlace-ABN have native C++/CUDA implementations, meaning that the command above will need to compile them. @@ -74,7 +74,7 @@ The last of the commands above will install some additional libraries required b ## Training on ImageNet-1k -Here you can find the results from our arXiv paper (top-1 / top-5 scores) with corresponding, trained models and md5 checksums, respectively. The model files provided below are made available under the [license attached to ImageNet](http://www.image-net.org/download-faq). +Here you can find the results from our arXiv paper (top-1 / top-5 scores) with corresponding, trained models and md5 checksums, respectively. The model files provided below are made available under the [license attached to ImageNet](http://www.image-net.org/download-faq). | Network | Batch | 224 | 224, 10-crops | 320 | Trained models (+md5) | |-----------------------------------|-------|----------------|----------------|---------------|----------------------------------| @@ -87,7 +87,7 @@ Here you can find the results from our arXiv paper (top-1 / top-5 scores) with c | [ResNet50v1, InPlace-ABN sync][13] | 512 | 75.53 / 92.59 | 77.04 / 93.57 | 76.60 / 93.49 | [`2522ca639f7fdfd7c0089ba1f5f6c2e8`][14] | | [ResNet34v1, InPlace-ABN sync][15] | 512 | 73.27 / 91.34 | 75.19 / 92.66 | 74.87 / 92.42 | [`61515c1484911c3cc753d405131e1dda`][16] | | [ResNet101v1, InPlace-ABN sync][17] | 512 | 77.07 / 93.45 | 78.58 / 94.40 | 78.25 / 94.19 | [`1552ae0f3d610108df702135f56bd27b`][18] | - + [1]: scripts/experiments/resnext101_stdbn_lr_256.json [2]: scripts/experiments/resnext101_ipabn_lr_512.json [3]: scripts/experiments/resnext152_ipabn_lr_256.json @@ -125,7 +125,7 @@ root/val/[class_id2]/__32_.{jpg,png,jpeg} Images can have any name, as long as the extension is that of a recognized image format. Class ids are also free-form, but they are expected to match between train and validation data. Note that the training data in the standard ImageNet distribution is already given in the required format, while -validation images need to be split into class sub-folders as described above. +validation images need to be split into class sub-folders as described above. ### Training @@ -167,7 +167,7 @@ We have successfully used InPlace-ABN with a DeepLab3 segmentation head that was model above. Due to InPlace-ABN, we can significantly increase the amount of input data to this model, which eventually allowed us to obtain #1 positions on [Cityscapes](https://www.cityscapes-dataset.com/benchmarks/#scene-labeling-task), -[Mapillary Vistas](https://eval-vistas.mapillary.com/featured-challenges/1/leaderboard/1), [AutoNUE](http://cvit.iiit.ac.in/scene-understanding-challenge-2018/benchmarks.php), +[Mapillary Vistas](https://eval-vistas.mapillary.com/featured-challenges/1/leaderboard/1), [AutoNUE](http://cvit.iiit.ac.in/scene-understanding-challenge-2018/benchmarks.php), [Kitti](http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015) and [ScanNet](http://dovahkiin.stanford.edu/adai/semantic_label) segmentation leaderboards. The training settings mostly follow the description in our [paper](https://arxiv.org/abs/1712.02616). @@ -196,7 +196,7 @@ The script will process all `.png`, `.jpg` and `.jpeg` images from the input fol output folder as `.png` images. For additional options, _e.g._ test time augmentation, please consult the script's help message. -The results on the test data written above were obtained by employing only scale 1.0 + flipping. +The results on the test data written above were obtained by employing only scale 1.0 + flipping. ## Changelog diff --git a/inplace_abn/abn.py b/inplace_abn/abn.py index fbdcf9c..a44b93c 100644 --- a/inplace_abn/abn.py +++ b/inplace_abn/abn.py @@ -1,9 +1,11 @@ +from typing import Optional + import torch import torch.distributed as distributed import torch.nn as nn import torch.nn.functional as functional -from .functions import * +from .functions import inplace_abn, inplace_abn_sync class ABN(nn.Module): @@ -11,144 +13,312 @@ class ABN(nn.Module): This gathers a BatchNorm and an activation function in a single module - Parameters - ---------- - num_features : int - Number of feature channels in the input and output. - eps : float - Small constant to prevent numerical issues. - momentum : float - Momentum factor applied to compute running statistics. - affine : bool - If `True` apply learned scale and shift transformation after normalization. - activation : str - Name of the activation functions, one of: `relu`, `leaky_relu`, `elu` or `identity`. - activation_param : float - Negative slope for the `leaky_relu` activation. + Args: + num_features: Number of feature channels in the input and output + eps: Small constant to prevent numerical issues + momentum: Momentum factor applied to compute running statistics with + exponential moving average, or `None` to compute running statistics + with cumulative moving average + affine: If `True` apply learned scale and shift transformation after normalization + track_running_stats: a boolean value that when set to `True`, this + module tracks the running mean and variance, and when set to `False`, + this module does not track such statistics and uses batch statistics instead + in both training and eval modes if the running mean and variance are `None` + activation: Name of the activation functions, one of: `relu`, `leaky_relu`, + `elu` or `identity` + activation_param: Negative slope for the `leaky_relu` activation or `alpha` + parameter for the `elu` activation """ - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", - activation_param=0.01): + _version = 2 + __constants__ = [ + "track_running_stats", + "momentum", + "eps", + "num_features", + "affine", + "activation", + "activation_param", + ] + num_features: int + eps: float + momentum: Optional[float] + affine: bool + track_running_stats: bool + activation: str + activation_param: float + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + activation: str = "leaky_relu", + activation_param: float = 0.01, + ): super(ABN, self).__init__() self.num_features = num_features - self.affine = affine self.eps = eps self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats self.activation = activation self.activation_param = activation_param if self.affine: - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) + self.weight = nn.Parameter(torch.Tensor(num_features)) + self.bias = nn.Parameter(torch.Tensor(num_features)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) else: - self.register_parameter('weight', None) - self.register_parameter('bias', None) - self.register_buffer('running_mean', torch.zeros(num_features)) - self.register_buffer('running_var', torch.ones(num_features)) + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) self.reset_parameters() - def reset_parameters(self): - nn.init.constant_(self.running_mean, 0) - nn.init.constant_(self.running_var, 1) + def reset_running_stats(self) -> None: + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self) -> None: + self.reset_running_stats() if self.affine: - nn.init.constant_(self.weight, 1) - nn.init.constant_(self.bias, 0) + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) - def forward(self, x): - x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, - self.training, self.momentum, self.eps) + def _get_momentum_and_training(self): + if self.momentum is None: + momentum = 0.0 + else: + momentum = self.momentum + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: + momentum = 1.0 / float(self.num_batches_tracked) + else: + momentum = self.momentum + + if self.training: + training = True + else: + training = (self.running_mean is None) and (self.running_var is None) + + return momentum, training + + def _get_running_stats(self): + running_mean = ( + self.running_mean if not self.training or self.track_running_stats else None + ) + running_var = ( + self.running_var if not self.training or self.track_running_stats else None + ) + return running_mean, running_var + + def forward(self, x: torch.Tensor) -> torch.Tensor: + momentum, training = self._get_momentum_and_training() + running_mean, running_var = self._get_running_stats() + + x = functional.batch_norm( + x, + running_mean, + running_var, + self.weight, + self.bias, + training, + momentum, + self.eps, + ) if self.activation == "relu": return functional.relu(x, inplace=True) elif self.activation == "leaky_relu": - return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True) + return functional.leaky_relu( + x, negative_slope=self.activation_param, inplace=True + ) elif self.activation == "elu": return functional.elu(x, alpha=self.activation_param, inplace=True) elif self.activation == "identity": return x else: - raise RuntimeError("Unknown activation function {}".format(self.activation)) + raise RuntimeError(f"Unknown activation function {self.activation}") + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - # Post-Pytorch 1.0 models using standard BatchNorm have a "num_batches_tracked" parameter that we need to ignore - num_batches_tracked_key = prefix + "num_batches_tracked" - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in state_dict: + state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) - super(ABN, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, - error_msgs, unexpected_keys) + super(ABN, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) def extra_repr(self): - rep = '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, activation={activation}' + rep = "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, activation={activation}" if self.activation in ["leaky_relu", "elu"]: - rep += '[{activation_param}]' + rep += "[{activation_param}]" return rep.format(**self.__dict__) class InPlaceABN(ABN): """InPlace Activated Batch Normalization - Parameters - ---------- - num_features : int - Number of feature channels in the input and output. - eps : float - Small constant to prevent numerical issues. - momentum : float - Momentum factor applied to compute running statistics. - affine : bool - If `True` apply learned scale and shift transformation after normalization. - activation : str - Name of the activation functions, one of: `leaky_relu`, `elu` or `identity`. - activation_param : float - Negative slope for the `leaky_relu` activation. + Args: + num_features: Number of feature channels in the input and output + eps: Small constant to prevent numerical issues + momentum: Momentum factor applied to compute running statistics with + exponential moving average, or `None` to compute running statistics + with cumulative moving average + affine: If `True` apply learned scale and shift transformation after normalization + track_running_stats: a boolean value that when set to `True`, this + module tracks the running mean and variance, and when set to `False`, + this module does not track such statistics and uses batch statistics instead + in both training and eval modes if the running mean and variance are `None` + activation: Name of the activation functions, one of: `relu`, `leaky_relu`, + `elu` or `identity` + activation_param: Negative slope for the `leaky_relu` activation or `alpha` + parameter for the `elu` activation """ - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", - activation_param=0.01): - super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, activation_param) + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + activation: str = "leaky_relu", + activation_param: float = 0.01, + ): + super(InPlaceABN, self).__init__( + num_features, + eps, + momentum, + affine, + track_running_stats, + activation, + activation_param, + ) def forward(self, x): - x, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, - self.training, self.momentum, self.eps, self.activation, self.activation_param) - return x + momentum, training = self._get_momentum_and_training() + running_mean, running_var = self._get_running_stats() + + return inplace_abn( + x, + self.weight, + self.bias, + running_mean, + running_var, + training, + momentum, + self.eps, + self.activation, + self.activation_param, + ) class InPlaceABNSync(ABN): - """InPlace Activated Batch Normalization with cross-GPU synchronization - - This assumes that it will be replicated across GPUs using the same mechanism as in - `nn.parallel.DistributedDataParallel`. - - Parameters - ---------- - num_features : int - Number of feature channels in the input and output. - eps : float - Small constant to prevent numerical issues. - momentum : float - Momentum factor applied to compute running statistics. - affine : bool - If `True` apply learned scale and shift transformation after normalization. - activation : str - Name of the activation functions, one of: `leaky_relu`, `elu` or `identity`. - activation_param : float - Negative slope for the `leaky_relu` activation. - group : distributed.group - Distributed group to synchronize with, default is WORLD + """InPlace Activated Batch Normalization with distributed synchronization + + This operates like `inplace_abn`, but assumes to be called by all replicas + in a given distributed group, and computes batch statistics across all of them. + Note that the input tensors can have different dimensions in each replica. + + Args: + num_features: Number of feature channels in the input and output + eps: Small constant to prevent numerical issues + momentum: Momentum factor applied to compute running statistics with + exponential moving average, or `None` to compute running statistics + with cumulative moving average + affine: If `True` apply learned scale and shift transformation after normalization + track_running_stats: a boolean value that when set to `True`, this + module tracks the running mean and variance, and when set to `False`, + this module does not track such statistics and uses batch statistics instead + in both training and eval modes if the running mean and variance are `None` + activation: Name of the activation functions, one of: `relu`, `leaky_relu`, + `elu` or `identity` + activation_param: Negative slope for the `leaky_relu` activation or `alpha` + parameter for the `elu` activation + group: Distributed group to synchronize with, default is WORLD """ - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", - activation_param=0.01, group=distributed.group.WORLD): - super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, activation_param) + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + activation: str = "leaky_relu", + activation_param: float = 0.01, + group=distributed.group.WORLD, + ): + super(InPlaceABNSync, self).__init__( + num_features, + eps, + momentum, + affine, + track_running_stats, + activation, + activation_param, + ) self.group = group def set_group(self, group): - """Set distributed group to synchronize with, should never be called between forward and backward""" + """Set distributed group to synchronize with + + This function should never be called between forward and backward + + Args: + group: The new distributed group to synchronize with + """ self.group = group def forward(self, x): - x, _, _ = inplace_abn_sync( - x, self.weight, self.bias, self.running_mean, self.running_var, self.training, self.momentum, self.eps, - self.activation, self.activation_param, self.group) - return x + momentum, training = self._get_momentum_and_training() + running_mean, running_var = self._get_running_stats() + + return inplace_abn_sync( + x, + self.weight, + self.bias, + running_mean, + running_var, + training, + momentum, + self.eps, + self.activation, + self.activation_param, + self.group, + ) diff --git a/inplace_abn/functions.py b/inplace_abn/functions.py index 3b836b1..f7d3cb1 100644 --- a/inplace_abn/functions.py +++ b/inplace_abn/functions.py @@ -1,3 +1,6 @@ +from typing import Optional + +import torch import torch.autograd as autograd import torch.distributed as distributed from torch.autograd.function import once_differentiable @@ -30,7 +33,9 @@ def _gather_values(*tensors, group, world_size): gathered, gather_ops = [], [] for t in tensors: t_all = t.new_empty(world_size, *t.shape) - t_op = distributed.all_gather(list(t_all.unbind(0)), t, group=group, async_op=True) + t_op = distributed.all_gather( + list(t_all.unbind(0)), t, group=group, async_op=True + ) gathered.append(t_all) gather_ops.append(t_op) @@ -45,19 +50,32 @@ def _gather_values(*tensors, group, world_size): @staticmethod def _reduce_forward(mean, var, count, group, world_size): all_mean, all_var, all_count = InPlaceABN._gather_values( - mean, var, count, group=group, world_size=world_size) + mean, var, count, group=group, world_size=world_size + ) return _backend.reduce_statistics(all_mean, all_var, all_count) @staticmethod def _reduce_backward(sum_dy, sum_xhat_dy, group, world_size): all_sum_dy, all_sum_xhat_dy = InPlaceABN._gather_values( - sum_dy, sum_xhat_dy, group=group, world_size=world_size) + sum_dy, sum_xhat_dy, group=group, world_size=world_size + ) return all_sum_dy.sum(dim=0), all_sum_xhat_dy.sum(dim=0) @staticmethod - def forward(ctx, x, weight, bias, running_mean, running_var, - training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01, - group=None): + def forward( + ctx, + x, + weight, + bias, + running_mean, + running_var, + training=True, + momentum=0.1, + eps=1e-05, + activation="leaky_relu", + activation_param=0.01, + group=None, + ): # Save context ctx.training = training ctx.momentum = momentum @@ -65,6 +83,7 @@ def forward(ctx, x, weight, bias, running_mean, running_var, ctx.activation = _activation_from_name(activation) ctx.activation_param = activation_param ctx.group = group + ctx.has_running_stats = running_mean is not None and running_mean is not None # Check if we really need to perform distributed operations if ctx.group is not None: @@ -79,43 +98,51 @@ def forward(ctx, x, weight, bias, running_mean, running_var, # Gather stats from all workers if needed if ctx.distributed: - mean, var, count = InPlaceABN._reduce_forward(mean, var, count, ctx.group, ctx.world_size) - - # Update running stats - count_ = count.to(dtype=var.dtype) - running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) - running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count_ / (count_ - 1)) - - # Mark in-place modified tensors - ctx.mark_dirty(x, running_mean, running_var) + mean, var, count = InPlaceABN._reduce_forward( + mean, var, count, ctx.group, ctx.world_size + ) + + # Update running stats if needed + if ctx.has_running_stats: + count_ = count.to(dtype=var.dtype) + running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) + running_var.mul_((1 - ctx.momentum)).add_( + ctx.momentum * var * count_ / (count_ - 1) + ) else: mean, var, count = running_mean, running_var, None - # Mark in-place modified tensors - ctx.mark_dirty(x) - # Transform x - _backend.forward(x, mean, var, weight, bias, ctx.eps, ctx.activation, ctx.activation_param) + _backend.forward( + x, mean, var, weight, bias, ctx.eps, ctx.activation, ctx.activation_param + ) - # Save for backward + # Save for backward and mark dirty tensors ctx.save_for_backward(x, var, count, weight, bias) - - ctx.mark_non_differentiable(running_mean, running_var) - return x, running_mean, running_var + ctx.mark_dirty(x) + return x @staticmethod @once_differentiable - def backward(ctx, dy_act, _drunning_mean, _drunning_var): + def backward(ctx, dy_act): y_act, var, count, weight, bias = ctx.saved_tensors # Call backward_reduce if we need to compute at least one of the gradients if any(ctx.needs_input_grad): xhat, dy, sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce( - y_act, dy_act, weight, bias, ctx.eps, ctx.activation, ctx.activation_param) + y_act, + dy_act, + weight, + bias, + ctx.eps, + ctx.activation, + ctx.activation_param, + ) if ctx.distributed: sum_dy, sum_xhat_dy = InPlaceABN._reduce_backward( - sum_dy_local, sum_xhat_dy_local, ctx.group, ctx.world_size) + sum_dy_local, sum_xhat_dy_local, ctx.group, ctx.world_size + ) else: sum_dy, sum_xhat_dy = sum_dy_local, sum_xhat_dy_local else: @@ -125,10 +152,12 @@ def backward(ctx, dy_act, _drunning_mean, _drunning_var): if ctx.needs_input_grad[0]: if ctx.training: # This overwrites dy with dx - _backend.backward_train(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps) + _backend.backward_train( + xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps + ) dx = dy else: - dx = _backend.backward_test(dy_act, var, weight, ctx.eps) + dx = _backend.backward_test(dy, var, weight, ctx.eps) else: dx = None @@ -148,17 +177,135 @@ def backward(ctx, dy_act, _drunning_mean, _drunning_var): return dx, dweight, dbias, None, None, None, None, None, None, None, None -def inplace_abn(x, weight, bias, running_mean, running_var, - training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): - return InPlaceABN.apply(x, weight, bias, running_mean, running_var, - training, momentum, eps, activation, activation_param, None) - - -def inplace_abn_sync(x, weight, bias, running_mean, running_var, - training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01, - group=distributed.group.WORLD): - return InPlaceABN.apply(x, weight, bias, running_mean, running_var, - training, momentum, eps, activation, activation_param, group) +def inplace_abn( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + training: bool = True, + momentum: float = 0.1, + eps: float = 1e-05, + activation: str = "leaky_relu", + activation_param: float = 0.01, +): + """InPlace Activated Batch Normalization + + This applies the following per-channel combined BatchNorm + activation operation: + + x_hat = (x - mu) / sqrt(sigma^2 + eps) + x <- act(x_hat, p) * (|weight| + eps) + bias + + where: + - mu is the per-channel batch mean, or `running_mean` if `training` is `False` + - sigma^2 is the per-channel batch variance, or `running_var` if `training` is `False` + - act(., p) is the activation function specified by `activation` + - p is `activation_param`, i.e. the negative slope of Leaky ReLU or alpha + parameter of ELU + - `weight` and `bias` are the optional affine parameters + - `eps` is a small positive number + + The running statistics, if given and if `training` is `True` are updated as follows: + + running_mean <- running_mean * momentum + (1 - momentum) * mu + running_var <- running_var * momentum + (1 - momentum) * unbiased_sigma^2 + + where unbiased_sigma^2 is the unbiased batch variance + + Args: + x: Input tensor with shape N x C or N x C x S_1 x ... x S_n, which will be + overwritten with the result + weight: Tensor of affine scale parameters with shape C, or `None` + bias: Tensor of affine bias parameters with shape C, or `None` + running_mean: Running mean tensor with shape C, or `None` + running_var: Running variance tensor with shape C, or `None` + training: If `True` compute, use and update batch statistics, otherwise use + running statistics + momentum: Momentum factor applied to compute running statistics + eps: Small constant to prevent numerical issues + activation: Name of the activation function, one of: `leaky_relu`, `elu` or `identity` + activation_param: Negative slope for the `leaky_relu` activation or `alpha` + parameter for the `elu` activation + """ + if training: + samples = _count_samples(x) + if samples <= 1: + raise ValueError( + "inplace_abn is trying to compute batch statistics, but the input " + "tensor only contains a single sample per channel" + ) + + return InPlaceABN.apply( + x, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + activation, + activation_param, + None, + ) + + +def inplace_abn_sync( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + training: bool = True, + momentum: float = 0.1, + eps: float = 1e-05, + activation: str = "leaky_relu", + activation_param: float = 0.01, + group=distributed.group.WORLD, +): + """InPlace Activated Batch Normalization with distributed synchronization + + This operates like `inplace_abn`, but assumes to be called by all replicas + in the given distributed group, and computes batch statistics across all of them. + Note that the input tensors can have different dimensions in each replica. + + Args: + x: Input tensor with shape N x C or N x C x S_1 x ... x S_n, which will be + overwritten with the result + weight: Tensor of affine scale parameters with shape C, or `None` + bias: Tensor of affine bias parameters with shape C, or `None` + running_mean: Running mean tensor with shape C, or `None` + running_var: Running variance tensor with shape C, or `None` + training: If `True` compute, use and update batch statistics, otherwise use + running statistics + momentum: Momentum factor applied to compute running statistics + eps: Small constant to prevent numerical issues + activation: Name of the activation function, one of: `leaky_relu`, `elu` or `identity` + activation_param: Negative slope for the `leaky_relu` activation or `alpha` + parameter for the `elu` activation + group: Distributed group to synchronize with, default is WORLD + """ + if training: + samples = _count_samples(x) + if samples <= 1: + raise ValueError( + "inplace_abn_sync is trying to compute batch statistics, but the input " + "tensor only contains a single sample per channel" + ) + + return InPlaceABN.apply( + x, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + activation, + activation_param, + group, + ) __all__ = ["inplace_abn", "inplace_abn_sync"] diff --git a/inplace_abn/group.py b/inplace_abn/group.py index f15aa7b..ea4faba 100644 --- a/inplace_abn/group.py +++ b/inplace_abn/group.py @@ -3,17 +3,15 @@ import torch.nn as nn -def active_group(active): +def active_group(active: bool): """Initialize a distributed group where each process can independently decide whether to participate or not - Parameters - ---------- - active : bool - Whether this process will be active in the group or not + Args: + active: Whether this process will be active in the group or not - Returns - ------- - A distributed group containing all processes that passed `active=True`, or `None` if all passed `False` + Returns: + group: A distributed group containing all processes that passed `active=True`, + or `None` if all passed `False` """ world_size = distributed.get_world_size() rank = distributed.get_rank() @@ -22,12 +20,16 @@ def active_group(active): if not hasattr(active_group, "__cache__"): active_group.__cache__ = { frozenset(range(world_size)): distributed.group.WORLD, - frozenset(): None + frozenset(): None, } # Gather active status from all workers - active = torch.tensor(rank if active else -1, dtype=torch.long, device=torch.cuda.current_device()) - active_workers = torch.empty(world_size, dtype=torch.long, device=torch.cuda.current_device()) + active = torch.tensor( + rank if active else -1, dtype=torch.long, device=torch.cuda.current_device() + ) + active_workers = torch.empty( + world_size, dtype=torch.long, device=torch.cuda.current_device() + ) distributed.all_gather(list(active_workers.unbind(0)), active) # Create and cache group if it doesn't exist yet