Commit

Add F.normalize (pytorch#1467)
szagoruyko authored and apaszke committed May 7, 2017
1 parent 23b556e commit 6d693fe
Showing 2 changed files with 27 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/test_nn.py
@@ -823,6 +823,11 @@ def test_pad(self):
        inputs = Variable(torch.randn(1, 2, 3, 4, 4), requires_grad=True)
        self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,)))

    def test_normalize(self):
        inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=2, dim=-2), (inputs,)))

    def _test_maxpool_indices(self, num_dim, type=torch.FloatTensor):
        def expected_indices(dim):
            if dim == 1:
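As an aside (not part of the diff), gradcheck compares a function's analytic gradient against finite differences. A minimal standalone run of the new test's check, assuming this era's API where the function lives in torch.autograd.gradcheck, might look like:

import torch
from torch.autograd import Variable
from torch.autograd.gradcheck import gradcheck
import torch.nn.functional as F

# Double precision keeps the finite-difference comparison stable.
inputs = Variable(torch.randn(2, 3).double(), requires_grad=True)
# Returns True when the analytic backward of F.normalize matches the
# numerical gradients within tolerance.
print(gradcheck(lambda x: F.normalize(x, p=2, dim=1), (inputs,)))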
22 changes: 22 additions & 0 deletions torch/nn/functional.py
@@ -744,3 +744,25 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
    dist_hinge = torch.clamp(margin + d_p - d_n, min=0.0)
    loss = torch.mean(dist_hinge)
    return loss


def normalize(input, p=2, dim=1, eps=1e-12):
    r"""Performs :math:`L_p` normalization of inputs over specified dimension.

    Does:

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}

    for each subtensor :math:`v` over dimension ``dim`` of ``input``. Each
    subtensor is flattened into a vector, i.e. :math:`\lVert v \rVert_p` is
    not a matrix norm.

    With default arguments normalizes over the second dimension with the
    Euclidean norm.

    Args:
        input: input tensor of any shape
        p (float): the exponent value in the norm formulation
        dim (int): the dimension to reduce
        eps (float): small value to avoid division by zero
    """
    return input / input.norm(p, dim).clamp(min=eps).expand_as(input)
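For readers trying the new function, here is a short usage sketch (illustrative, not part of the commit). It assumes the pre-0.2 reduction semantics, where norm(p, dim) keeps the reduced dimension as size 1, which the expand_as call in the implementation above relies on:

import torch
from torch.autograd import Variable
import torch.nn.functional as F

x = Variable(torch.randn(2, 3))
# Normalize each row to unit L2 norm.
y = F.normalize(x, p=2, dim=1)
print(y.norm(2, 1))  # each entry should be close to 1.0

# Manual check against the docstring formula v / max(||v||_p, eps).
manual = x / x.norm(2, 1).clamp(min=1e-12).expand_as(x)
print((y - manual).abs().max())  # expected: 0.0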
