
Commit

Update base.py
shi27feng authored Feb 20, 2021
1 parent 9740054 commit 7bba3b9
Showing 1 changed file with 115 additions and 0 deletions.
115 changes: 115 additions & 0 deletions nfnets/base.py
@@ -7,6 +7,121 @@
from typing import Optional, List, Tuple


class WSConv1d(nn.Conv1d):
r"""Applies a 1D convolution over an input signal composed of several input
planes.
In the simplest case, the output value of the layer with input size
:math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
precisely described as:
.. math::
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
\sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
\star \text{input}(N_i, k)
where :math:`\star` is the valid `cross-correlation`_ operator,
:math:`N` is a batch size, :math:`C` denotes a number of channels,
:math:`L` is a length of signal sequence.
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation, a single
number or a one-element tuple.
* :attr:`padding` controls the amount of implicit zero padding added to both
sides of the input, :attr:`padding` points on each side.
* :attr:`dilation` controls the spacing between the kernel points; also
known as the à trous algorithm. It is harder to describe, but this `link`_
has a nice visualization of what :attr:`dilation` does.
* :attr:`groups` controls the connections between inputs and outputs.
:attr:`in_channels` and :attr:`out_channels` must both be divisible by
:attr:`groups`. For example,
* At groups=1, all inputs are convolved to all outputs.
* At groups=2, the operation becomes equivalent to having two conv
layers side by side, each seeing half the input channels,
and producing half the output channels, and both subsequently
concatenated.
* At groups= :attr:`in_channels`, each input channel is convolved with
its own set of filters,
of size
:math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.
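Concretely, ``WSConv1d(16, 32, kernel_size=3, groups=2)`` splits the 16 input
channels into two groups of 8; each group is convolved with its own set of 16
filters, and the two results are concatenated into 32 output channels.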
Note:
Depending on the size of your kernel, several (of the last)
columns of the input might be lost, because it is a valid
`cross-correlation`_, and not a full `cross-correlation`_.
It is up to the user to add proper padding.
Note:
When `groups == in_channels` and `out_channels == K * in_channels`,
where `K` is a positive integer, this operation is also known in the
literature as a depthwise convolution.
In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
:math:`(C_\text{in}=C_{in}, C_\text{out}=C_{in} \times K, ..., \text{groups}=C_{in})`.
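For instance, ``WSConv1d(16, 32, kernel_size=3, groups=16)`` is a depthwise
convolution with depthwise multiplier :math:`K = 2`: each of the 16 input
channels is convolved with its own 2 filters.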
Note:
In some circumstances when using the CUDA backend with CuDNN, this operator
may select a nondeterministic algorithm to increase performance. If this is
undesirable, you can try to make the operation deterministic (potentially at
a performance cost) by setting ``torch.backends.cudnn.deterministic =
True``.
Please see the notes on :doc:`/notes/randomness` for background.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 0
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
dilation (int or tuple, optional): Spacing between kernel
elements. Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the
output. Default: ``True``
Shape:
- Input: :math:`(N, C_{in}, L_{in})`
- Output: :math:`(N, C_{out}, L_{out})` where
.. math::
L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
\times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
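For instance, with the defaults :math:`\text{padding} = 0` and
:math:`\text{dilation} = 1`, an input of length :math:`L_{in} = 50` with
:math:`\text{kernel\_size} = 3` and :math:`\text{stride} = 2` gives
:math:`L_{out} = \left\lfloor\frac{50 - 2 - 1}{2} + 1\right\rfloor = 24`,
matching the example below.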
Attributes:
weight (Tensor): the learnable weights of the module of shape
:math:`(\text{out\_channels},
\frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
In this weight-standardized variant the weights are re-initialized with
:func:`torch.nn.init.kaiming_normal_` in :meth:`__init__` (rather than the
default :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` of :class:`~torch.nn.Conv1d`)
bias (Tensor): the learnable bias of the module of shape
(out_channels). If :attr:`bias` is ``True``, then the values of these weights are
sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
Examples::
>>> m = WSConv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 50)
>>> output = m(input)
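The standardized weight used by :meth:`forward` has, per output channel, mean
close to zero and variance close to :math:`\text{gain}^2 / \text{fan\_in}`
(illustrative check, using the same ``eps`` default as :meth:`forward`):
>>> output.shape   # torch.Size([20, 33, 24])
>>> w_hat = m.standardize_weight(eps=1e-4)
>>> w_hat.mean(dim=(1, 2))   # ~0 for every output channel
>>> w_hat.var(dim=(1, 2))    # ~gain**2 / fan_in for every output channel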
.. _cross-correlation:
https://en.wikipedia.org/wiki/Cross-correlation
.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)

# Re-initialize the convolution weight with Kaiming/He initialization and add a
# learnable per-output-channel gain used by standardize_weight() below.
nn.init.kaiming_normal_(self.weight)
self.gain = nn.Parameter(torch.ones(self.weight.size(0)))

def standardize_weight(self, eps):
# Scaled Weight Standardization: per output channel, return
# gain * (weight - mean) / sqrt(fan_in * var), with fan_in * var clamped below by eps.
var, mean = torch.var_mean(self.weight, dim=(1, 2), keepdim=True)
# fan_in: number of weight elements per output channel
# (in_channels / groups * kernel_size); the output-channel dimension is excluded.
fan_in = torch.prod(torch.tensor(self.weight.shape[1:]))

scale = torch.rsqrt(torch.max(
var * fan_in, torch.tensor(eps, device=var.device))) * self.gain.view_as(var).to(var.device)
shift = mean * scale
return self.weight * scale - shift

def forward(self, input, eps=1e-4):
# Standardize the weight on every forward pass, then apply the usual 1D convolution.
weight = self.standardize_weight(eps)
return F.conv1d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class WSConv2d(nn.Conv2d):
"""Applies a 2D convolution over an input signal composed of several input
planes after weight normalization/standardization.
