Skip to content

Commit

Permalink
Enable functional torch.where. (#4298)
Browse files Browse the repository at this point in the history
  • Loading branch information
gchanan authored Dec 21, 2017
1 parent 5f7c550 commit 41c9959
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/torch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Indexing, Slicing, Joining, Mutating Ops
.. autofunction:: transpose
.. autofunction:: unbind
.. autofunction:: unsqueeze
.. autofunction:: where


Random sampling
Expand Down
21 changes: 21 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,26 @@ def as_strided(x):
gradcheck(as_strided, [x], raise_exception=True)
gradgradcheck(as_strided, [x], [Variable(torch.randn(3, 3))])

def _test_where_functional(self, t):
    # Verify first- and second-order gradients of the functional torch.where,
    # first on same-shaped operands, then on broadcastable ones.
    x = Variable(t(torch.randn(5, 5)), requires_grad=True)
    y = Variable(t(torch.randn(5, 5)), requires_grad=True)
    cond = Variable(t(mask_not_all_zeros((5, 5))), requires_grad=False)

    def select(cond, x, y):
        return torch.where(cond, x, y)

    gradcheck(select, [cond, x, y], raise_exception=True)
    gradgradcheck(select, [cond, x, y], [Variable(t(torch.randn(5, 5)))])

    # Broadcasting case: operands expand to a common (5, 5, 5) shape.
    x = Variable(t(torch.randn(5, 1, 5)), requires_grad=True)
    y = Variable(t(torch.randn(5, 5, 1)), requires_grad=True)
    gradcheck(select, [cond, x, y], raise_exception=True)
    gradgradcheck(select, [cond, x, y], [Variable(t(torch.randn(5, 5, 5)))])

def test_where_functional(self):
    # Run the shared checks with an identity wrapper (CPU tensors).
    # TODO: .cuda() lambda
    self._test_where_functional(lambda tensor: tensor)

def test_inplace_view_backprop_base(self):
# modify view and back-prop through base
root = Variable(torch.randn(2, 2), requires_grad=True)
Expand Down Expand Up @@ -2398,6 +2418,7 @@ def maybe_non_contig(tensor):
'baddbmm',
'addmv',
'addr',
'where' # argument order
}
EXCLUDE_GRADCHECK = {
}
Expand Down
29 changes: 28 additions & 1 deletion torch/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

__all__ = [
'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul', 'det', 'stft',
'hann_window', 'hamming_window', 'bartlett_window',
'hann_window', 'hamming_window', 'bartlett_window', 'where',
]


Expand Down Expand Up @@ -440,3 +440,30 @@ def bartlett_window(window_length, periodic=True):
return window[:-1]
else:
return window


def where(condition, x, y):
    r"""Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`.

    The operation is defined as::

        out_i = x_i if condition_i else y_i

    .. note::
        This function only works with ``Variables``.

    .. note::
        The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be
        :ref:`broadcastable <broadcasting-semantics>`.

    Arguments:
        condition (ByteTensor): When True (nonzero), yield x, otherwise yield y.
        x (Tensor): values selected at indices where :attr:`condition` is True.
        y (Tensor): values selected at indices where :attr:`condition` is False.

    Returns:
        Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y`
    """
    # the parameter order is changed here; the functional order is the same as numpy; the
    # method follows the usual torch mask semantics of x.fn(mask, y)
    return torch._C._VariableBase.where(x, condition, y)

0 comments on commit 41c9959

Please sign in to comment.