From 7bba3b9a92ad40bce74753d20e4223168c3c030b Mon Sep 17 00:00:00 2001 From: Feng Shi Date: Sat, 20 Feb 2021 08:51:56 -0800 Subject: [PATCH] Update base.py --- nfnets/base.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/nfnets/base.py b/nfnets/base.py index 9b05f86..34d9039 100644 --- a/nfnets/base.py +++ b/nfnets/base.py @@ -7,6 +7,121 @@ from typing import Optional, List, Tuple +class WSConv1d(nn.Conv1d): + r"""Applies a 1D convolution over an input signal composed of several input + planes. + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. + This module supports :ref:`TensorFloat32`. + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a one-element tuple. + * :attr:`padding` controls the amount of implicit zero-paddings on both sides + for :attr:`padding` number of points. + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + * :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels, + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters, + of size + :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + Note: + Depending of the size of your kernel, several (of the last) + columns of the input might be lost, because it is a valid + `cross-correlation`_, and not a full `cross-correlation`_. + It is up to the user to add proper padding. + Note: + When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also termed in + literature as depthwise convolution. + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments + :math:`(C_\text{in}=C_{in}, C_\text{out}=C_{in} \times K, ..., \text{groups}=C_{in})`. + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + Shape: + - Input: :math:`(N, C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` where + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, + \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + Examples:: + >>> m = nn.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): + super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode) + + nn.init.kaiming_normal_(self.weight) + self.gain = nn.Parameter(torch.ones(self.weight.size()[0], requires_grad=True)) + + def standardize_weight(self, eps): + var, mean = torch.var_mean(self.weight, dim=(1, 2), keepdims=True) + fan_in = torch.prod(torch.tensor(self.weight.shape)) + + scale = torch.rsqrt(torch.max( + var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device) + shift = mean * scale + return self.weight * scale - shift + + def forward(self, input, eps=1e-4): + weight = self.standardize_weight(eps) + return F.conv1d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + class WSConv2d(nn.Conv2d): """Applies a 2D convolution over an input signal composed of several input planes after weight normalization/standardization.