Add torch.matmul function. (pytorch#1780)
* Add torch.matmul function.

Includes test_torch, test_autograd and docs changes.

* Add __all__ to functional so module-level imports are not accidentally re-exported.

* Include unbind in __all__.

* Add matmul case for when one argument is 1-dimensional and the other is at least 3-dimensional.

* Add squeeze_ to Variable.

* Use squeeze_ instead of squeeze for matmul.
gchanan authored and soumith committed Jun 14, 2017
1 parent 9fd354e commit 4e35652
Showing 9 changed files with 215 additions and 81 deletions.
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -199,6 +199,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: masked_scatter_
.. automethod:: masked_fill_
.. automethod:: masked_select
.. automethod:: matmul
.. automethod:: max
.. automethod:: mean
.. automethod:: median
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -170,6 +170,7 @@ BLAS and LAPACK Operations
.. autofunction:: ger
.. autofunction:: gesv
.. autofunction:: inverse
.. autofunction:: matmul
.. autofunction:: mm
.. autofunction:: mv
.. autofunction:: orgqr
35 changes: 30 additions & 5 deletions test/test_autograd.py
@@ -1027,7 +1027,14 @@ def test_functional_blas(self):
def compare(fn, *args):
unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
for arg in args)
self.assertEqual(fn(*args).data, fn(*unpacked_args))
unpacked_result = fn(*unpacked_args)
packed_result = fn(*args).data
# if non-Variable torch function returns a scalar, compare to scalar
if not torch.is_tensor(unpacked_result):
assert packed_result.dim() == 1
assert packed_result.nelement() == 1
packed_result = packed_result[0]
self.assertEqual(packed_result, unpacked_result)

def test_blas_add(fn, x, y, z):
# Checks all signatures
@@ -1056,6 +1063,14 @@ def test_blas(fn, x, y):
Variable(torch.randn(6)))
test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
Variable(torch.randn(5)), Variable(torch.randn(6)))
test_blas(torch.matmul, Variable(torch.randn(6)), Variable(torch.randn(6)))
test_blas(torch.matmul, Variable(torch.randn(10, 4)), Variable(torch.randn(4)))
test_blas(torch.matmul, Variable(torch.randn(5)), Variable(torch.randn(5, 6)))
test_blas(torch.matmul, Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
test_blas(torch.matmul, Variable(torch.randn(5, 2, 10)), Variable(torch.randn(5, 10, 4)))
test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(3, 5, 10, 4)))
test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(10)))
test_blas(torch.matmul, Variable(torch.randn(10)), Variable(torch.randn(3, 5, 10, 4)))

def test_save_none_for_backward(self):
test_case = self
@@ -1427,7 +1442,7 @@ class dont_convert(tuple):
(Inverse, (), ((S, S),), '', (), [skipIfNoLapack]),
(Gesv, (), ((S, S), (S, S)), '', (), [skipIfNoLapack]),
(Clone, (), ((S, M, S),)),
(Squeeze, (), ((S, 1, M, 1),)),
(Squeeze, (), ((S, 1, M, 1), None)),
# TODO: enable neg dim checks
(Squeeze, (), ((S, 1, M, 1), 1), 'dim'),
(Unsqueeze, (), ((S, M, S), 0), '0'),
@@ -1543,6 +1558,17 @@ class dont_convert(tuple):
('addr', (S, M), ((S,), (M,)),),
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef'),
('dot', (L,), ((L,),),),
('mm', (S, M), ((M, S),)),
('bmm', (M, S, M), ((M, M, S),)),
('mv', (S, M), ((M,),)),
('ger', (S,), ((M,),)),
('matmul', (L,), ((L,),),),
('matmul', (S, M), ((M,),), "2d_1d"),
('matmul', (M, ), ((M, S),), "1d_2d"),
('matmul', (S, M), ((M, S),), "2d_2d"),
('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
('matmul', (S, S, M, M), ((M,),), "4d_1d"),
('matmul', (M,), ((S, S, M, S),), "1d_4d"),
('addcmul', (S, S), ((S, S), (S, S))),
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale'),
('addcdiv', (S, S), ((S, S), (S, S))),
@@ -1589,8 +1615,7 @@ class dont_convert(tuple):
('masked_scatter_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))),
]
# TODO: mm, bmm, mv, ger
# TODO: max, min with dim (problem with indices)
# TODO: mode, median, sort, kthvalue, topk (problem with indices)
# TODO: sort, topk (problem with indices)
# TODO: indexAdd, indexCopy, indexFill
# TODO: resize, resize_as (tensors only have resize_ and resize_as_)
# TODO: clamp with min/max
@@ -1686,7 +1711,7 @@ def apply_inplace_fn(*input):
if not isinstance(output, tuple):
output = (output,)
inplace_input = deepcopy(input)
inplace_input_copy = tuple(i + 0 for i in inplace_input)
inplace_input_copy = tuple(i + 0 if i is not None else None for i in inplace_input)
inplace_output = apply_inplace_fn(*inplace_input_copy)
if not isinstance(inplace_output, tuple):
inplace_output = (inplace_output,)
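The new 'matmul' entries in the method-tests table above are exercised through the test harness's gradient checks; a minimal standalone sketch of the same idea for one of the new cases, assuming the standard torch.autograd.gradcheck API and arbitrary example shapes:

import torch
from torch.autograd import Variable, gradcheck

# Sketch: gradcheck one of the new matmul cases directly
# (double precision, since gradcheck compares against numerical gradients).
x = Variable(torch.randn(3, 5, 2, 4).double(), requires_grad=True)  # 4-d lhs
y = Variable(torch.randn(4).double(), requires_grad=True)           # 1-d rhs
print(gradcheck(torch.matmul, (x, y), eps=1e-6, atol=1e-4))          # expect True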
66 changes: 46 additions & 20 deletions test/test_torch.py
@@ -1239,35 +1239,61 @@ def _test_broadcast_batched_matmul(self, cast):
full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
(batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)

def verify_batched_matmul(full_lhs):
lhs_dims = [n_dim, m_dim]
rhs_dims = [m_dim, p_dim]
full_mat_dims = lhs_dims if full_lhs else rhs_dims
small_mat_dims = rhs_dims if full_lhs else lhs_dims

small = cast(torch.randn(*(batch_dims_small + small_mat_dims)).float())
dim0 = cast(torch.randn(*(small_mat_dims)).float())
def verify_batched_matmul(full_lhs, one_dimensional):
if not one_dimensional:
lhs_dims = [n_dim, m_dim]
rhs_dims = [m_dim, p_dim]
result_dims = [n_dim, p_dim]
else:
lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
result_dims = [n_dim] if full_lhs else [p_dim]

lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
dim0_dims = rhs_dims if full_lhs else lhs_dims
small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)

small = cast(torch.randn(*(small_dims)).float())
dim0 = cast(torch.randn(*(dim0_dims)).float())
full = cast(torch.randn(*(full_batch_dims + full_mat_dims)).float())
(lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
if not one_dimensional:
(lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
else:
(lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))

def maybe_squeeze_result(l, r, result):
if len(lhs_dims) == 1 and l.dim() != 1:
return result.squeeze(-2)
elif len(rhs_dims) == 1 and r.dim() != 1:
return result.squeeze(-1)
else:
return result

for lhs in lhsTensors:
lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_dims)))
lhs_expanded_matmul_fn = getattr(lhs_expanded, "__matmul__")
lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
lhs_expanded_matmul_fn = getattr(lhs_expanded, "matmul")
for rhs in rhsTensors:
rhs_expanded = rhs.expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_dims)))
truth = lhs_expanded_matmul_fn(rhs_expanded)
rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
for l in (lhs, lhs_expanded):
for r in (rhs, rhs_expanded):
l_matmul_fn = getattr(l, "__matmul__")
result = l_matmul_fn(r)
l_matmul_fn = getattr(l, "matmul")
result = maybe_squeeze_result(l, r, l_matmul_fn(r))
self.assertEqual(truth, result)
# test torch.matmul function as well
torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
self.assertEqual(truth, torch_result)

# compare to bmm
bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, n_dim, m_dim),
rhs_expanded.contiguous().view(-1, m_dim, p_dim)))
self.assertEqual(truth.view(-1, n_dim, p_dim), bmm_result)
bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))

verify_batched_matmul(False)
verify_batched_matmul(True)
for indices in product((True, False), repeat=2):
verify_batched_matmul(*indices)

def test_broadcast_batched_matmul(self):
self._test_broadcast_batched_matmul(self, lambda t: t)
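As a standalone illustration of what the updated test verifies, a sketch (with arbitrarily chosen shapes) comparing a broadcasted matmul against manually expanding the batch dimensions and calling torch.bmm on the flattened batches:

import torch

# Sketch: a broadcasted matmul should agree with hand-expanded batches fed to bmm.
lhs = torch.randn(2, 1, 4, 5)   # batch dims (2, 1), matrix dims (4, 5)
rhs = torch.randn(3, 5, 6)      # batch dim  (3,),   matrix dims (5, 6)

via_matmul = torch.matmul(lhs, rhs)                          # shape (2, 3, 4, 6)
lhs_e = lhs.expand(2, 3, 4, 5).contiguous().view(-1, 4, 5)
rhs_e = rhs.expand(2, 3, 5, 6).contiguous().view(-1, 5, 6)
via_bmm = torch.bmm(lhs_e, rhs_e).view(2, 3, 4, 6)
print((via_matmul - via_bmm).abs().max())                    # ~0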
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -15,7 +15,7 @@
__all__ = [
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed',
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack',
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage',
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
24 changes: 16 additions & 8 deletions torch/autograd/_functions/tensor.py
@@ -331,22 +331,30 @@ def backward(ctx, grad_output):
return grad_output


class Squeeze(Function):
class Squeeze(InplaceFunction):

@staticmethod
def forward(ctx, input, dim=None):
def forward(ctx, input, dim=None, inplace=False):
ctx.dim = dim
ctx.input_size = input.size()
if dim is not None:
result = input.squeeze(dim)
if inplace:
ctx.mark_dirty(input)
if dim is not None:
return input.squeeze_(dim)
else:
return input.squeeze_()
else:
result = input.squeeze()
ctx.mark_shared_storage((input, result))
return result
if dim is not None:
result = input.squeeze(dim)
else:
result = input.squeeze()

ctx.mark_shared_storage((input, result))
return result

@staticmethod
def backward(ctx, grad_output):
return grad_output.contiguous().view(ctx.input_size), None
return grad_output.contiguous().view(ctx.input_size), None, None


class Unsqueeze(Function):
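A small sketch of what the in-place variant enables under autograd, assuming the squeeze_ method added to Variable later in this commit and applying it to a non-leaf so the in-place op is permitted; the backward still restores the pre-squeeze shape:

import torch
from torch.autograd import Variable

# Sketch: in-place squeeze on a Variable; gradients are reshaped back
# to the original (unsqueezed) size by Squeeze.backward.
v = Variable(torch.randn(1, 5, 1), requires_grad=True)
w = v.mul(2)            # non-leaf, so the in-place op below is allowed
w.squeeze_()            # shape becomes (5,)
w.sum().backward()
print(v.grad.size())    # torch.Size([1, 5, 1])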
22 changes: 8 additions & 14 deletions torch/autograd/variable.py
@@ -522,6 +522,9 @@ def renorm(self, p, dim, maxnorm):
flat_out = flat.mul(norms.expand_as(flat))
return flat_out.view(t.size()).transpose(dim, 0)

def matmul(self, other):
return torch.matmul(self, other)

@staticmethod
def _static_blas(cls, args, inplace):
num_args = len(args)
@@ -706,6 +709,9 @@ def chunk(self, num_chunks, dim=0):
def squeeze(self, dim=None):
return Squeeze.apply(self, dim)

def squeeze_(self, dim=None):
return Squeeze.apply(self, dim, True)

def unsqueeze(self, dim):
return Unsqueeze.apply(self, dim)

@@ -787,21 +793,9 @@ def __imul__(self, other):
return self.mul_(other)

def __matmul__(self, other):
dim_self = self.dim()
try:
dim_other = other.dim()
except AttributeError: # not a Variable
if not isinstance(other, Variable):
return NotImplemented
if dim_self == 1 and dim_other == 1:
return self.dot(other)
if dim_self == 2 and dim_other == 1:
return self.mv(other)
if dim_self == 1 and dim_other == 2:
return self.unsqueeze(0).mm(other).squeeze(0)
elif dim_self == 2 and dim_other == 2:
return self.mm(other)
raise ValueError("both arguments to __matmul__ need to be 1D or 2D, "
"but they are {}D and {}D".format(dim_self, dim_other))
return self.matmul(other)

def __div__(self, other):
return self.div(other)
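With __matmul__ now delegating to matmul, the @ operator on Variables covers the batched cases as well; a quick sketch (shapes are arbitrary examples, Python 3.5+ operator syntax):

import torch
from torch.autograd import Variable

# Sketch: `@` on Variables routes through torch.matmul, so batched
# (>2-d) operands work under autograd as well.
x = Variable(torch.randn(3, 5, 2, 10), requires_grad=True)
y = Variable(torch.randn(3, 5, 10, 4), requires_grad=True)
out = x @ y                 # same as x.matmul(y); shape (3, 5, 2, 4)
out.sum().backward()
print(x.grad.size())        # torch.Size([3, 5, 2, 10])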
104 changes: 104 additions & 0 deletions torch/functional.py
@@ -1,5 +1,11 @@
import torch
from ._utils import _range
from operator import mul
from functools import reduce

__all__ = [
'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul',
]


def split(tensor, split_size, dim=0):
@@ -113,3 +119,101 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
P = None

return P, L, U


def matmul(tensor1, tensor2, out=None):
"""Matrix product of two tensors.
The behavior depends on the dimensionality of the tensors as follows:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional,
a 1 is prepended to its dimension for the purpose of the matrix multiply.
After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional,
the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is
N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first
argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the
batched matrix multiply and removed after. If the second argument is 1-dimensional, a
1 is appended to its dimension for the purpose of the batched matrix multiply and removed after.
The non-matrix (i.e. batch) dimensions are :ref:`broadcasted <broadcasting-semantics>` (and thus
must be broadcastable). For example, if :attr:`tensor1` is a `j x 1 x n x m` Tensor
and :attr:`tensor2` is a `k x m x p` Tensor, :attr:`out` will be a `j x k x n x p` Tensor.
.. note::
The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.
Arguments:
tensor1 (Tensor): First tensor to be multiplied
tensor2 (Tensor): Second tensor to be multiplied
out (Tensor, optional): Output tensor
"""
dim_tensor1 = tensor1.dim()
dim_tensor2 = tensor2.dim()
if dim_tensor1 == 1 and dim_tensor2 == 1:
if out is None:
return torch.dot(tensor1, tensor2)
else:
raise ValueError("out must be None for 1-d tensor matmul, returns a scalar")
if dim_tensor1 == 2 and dim_tensor2 == 1:
if out is None:
return torch.mv(tensor1, tensor2)
else:
return torch.mv(tensor1, tensor2, out=out)
elif dim_tensor1 == 1 and dim_tensor2 == 2:
if out is None:
return torch.mm(tensor1.unsqueeze(0), tensor2).squeeze_(0)
else:
return torch.mm(tensor1.unsqueeze(0), tensor2, out=out).squeeze_(0)
elif dim_tensor1 == 2 and dim_tensor2 == 2:
if out is None:
return torch.mm(tensor1, tensor2)
else:
return torch.mm(tensor1, tensor2, out=out)
elif (dim_tensor1 >= 1 and dim_tensor2 >= 1) and (dim_tensor1 >= 3 or dim_tensor2 >= 3):
# ensure each tensor size is at least 3-dimensional
tensor1_exp_size = torch.Size((1,) * max(3 - tensor1.dim(), 0) + tensor1.size())
# rhs needs to be a separate case since we can't freely expand 1s on the rhs, but can on lhs
if dim_tensor2 == 1:
tensor2 = tensor2.unsqueeze(1)
tensor2_exp_size = torch.Size((1,) * max(3 - tensor2.dim(), 0) + tensor2.size())

# expand the batch portion (i.e. cut off matrix dimensions and expand rest)
expand_batch_portion = torch._C._infer_size(tensor1_exp_size[:-2], tensor2_exp_size[:-2])

# flatten expanded batches
tensor1_expanded = tensor1.expand(*(expand_batch_portion + tensor1_exp_size[-2:])) \
.contiguous().view(reduce(mul, expand_batch_portion), *tensor1_exp_size[-2:])
tensor2_expanded = tensor2.expand(*(expand_batch_portion + tensor2_exp_size[-2:])) \
.contiguous().view(reduce(mul, expand_batch_portion), *tensor2_exp_size[-2:])

# reshape batches back into result
total_expansion = expand_batch_portion + (tensor1_exp_size[-2], tensor2_exp_size[-1])

def maybeSqueeze(tensor):
if dim_tensor1 == 1:
return tensor.squeeze_(-2)
elif dim_tensor2 == 1:
return tensor.squeeze_(-1)
else:
return tensor

if out is None:
return maybeSqueeze(torch.bmm(tensor1_expanded, tensor2_expanded).view(*(total_expansion)))
else:
# We can only safely reshape the output if the output (after the torch.bmm call)
# is contiguous. This will happen only if:
# 1) We force it to be contiguous
# 2) The output came in as contiguous
# 3) The output came in as the wrong size (so was resized in the torch.bmm call).
#
# Even though 1) is inconsistent with other functions (e.g. torch.bmm) that will maintain
# output non-contiguity if the size is correct, we'll do it here for simplicity.
out = out.contiguous()
return (torch.bmm(tensor1_expanded, tensor2_expanded, out=out).
set_(maybeSqueeze(out.view(*(total_expansion)))))
raise ValueError("both arguments to __matmul__ need to be at least 1D, "
"but they are {}D and {}D".format(dim_tensor1, dim_tensor2))
(1 additional changed file not shown)
