From 72ae924fad548769c0ba2e83c37b00963ca6bcb7 Mon Sep 17 00:00:00 2001
From: Heitor Schueroff
Date: Fri, 21 May 2021 08:35:25 -0700
Subject: [PATCH] Added sublist support for torch.einsum (#56625)

Summary:
This PR adds an alternative way of calling `torch.einsum`. Instead of
specifying the subscripts as letters in the `equation` parameter, one can
now specify them as a list of integers for each operand, as in
`torch.einsum(operand1, subscripts1, operand2, subscripts2, ..., [subscripts_out])`.
This is equivalent to the equation-string call
`torch.einsum('subscripts1,subscripts2,...->subscripts_out', operand1, operand2, ...)`
with each integer subscript mapped to a letter (a usage sketch follows the
diff).

TODO
- [x] Update documentation
- [x] Add more error checking
- [x] Update tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56625

Reviewed By: zou3519

Differential Revision: D28062616

Pulled By: heitorschueroff

fbshipit-source-id: ec50ad34f127210696e7c545e4c0675166f127dc
---
 aten/src/ATen/native/Linear.cpp |  53 ++++++------
 test/test_jit.py                |  27 +++++-
 test/test_linalg.py             | 138 ++++++++++++++++++++------
 torch/functional.py             |  41 +++++++++-
 4 files changed, 172 insertions(+), 87 deletions(-)

diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp
index 28d58743bbb83a..7f9a110b7a5ce3 100644
--- a/aten/src/ATen/native/Linear.cpp
+++ b/aten/src/ATen/native/Linear.cpp
@@ -148,12 +148,12 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra
 
 namespace {
 
-bool einsum_check_label(char label) {
+bool einsum_check_label(unsigned char label) {
   return std::isalpha(label);
 }
 
-int einsum_label_to_index(char label) {
-  constexpr int NUM_OF_LETTERS = 'z' - 'a' + 1;
+uint8_t einsum_label_to_index(unsigned char label) {
+  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
   return std::islower(label) ? label - 'a' : NUM_OF_LETTERS + (label - 'A');
 }
 
@@ -169,7 +169,7 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
   checkDeviceType("einsum():", operands, operands[0].device().type());
 
   // Code used to identify ELLIPSIS ("...")
-  constexpr int ELLIPSIS = '.';
+  constexpr uint8_t ELLIPSIS = 52;
 
   // Find arrow (->) to split equation into lhs and rhs
   const auto arrow_pos = equation.find("->");
@@ -177,13 +177,14 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
 
   const auto num_ops = operands.size();
 
-  // Convert labels for input operands into an index in [0, 25] and store
+  // Convert labels for input operands into an index in [0, 52) and store
   // them in op_labels for each operand along with ELLIPSIS if present.
-  std::vector<std::vector<int>> op_labels(num_ops);
+  std::vector<std::vector<uint8_t>> op_labels(num_ops);
   bool found_ell = false;
   std::size_t curr_op = 0;
   for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) {
-    switch (lhs[i]) {
+    const unsigned char label = lhs[i];
+    switch (label) {
       case ' ':
         // Ignore spaces
         break;
@@ -217,12 +218,11 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
       default:
         // Parse label
         TORCH_CHECK(
-            einsum_check_label(lhs[i]),
-            "einsum(): operand subscript must be in [a-zA-Z] but found ",
-            lhs[i],
-            " for operand ",
-            curr_op);
-        op_labels[curr_op].push_back(einsum_label_to_index(lhs[i]));
+            einsum_check_label(label),
+            "einsum(): invalid subscript given at index ",
+            i,
+            " in the equation string, subscripts must be in [a-zA-Z]");
+        op_labels[curr_op].push_back(einsum_label_to_index(label));
     }
   }
 
@@ -229,10 +229,10 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
   TORCH_CHECK(
       curr_op == num_ops - 1,
       "einsum(): more operands were provided than specified in the equation");
 
   // Labels must be within [a-zA-Z].
-  constexpr int TOTAL_LABELS = 52;
-  std::vector<int> label_count(TOTAL_LABELS, 0);
+  constexpr uint8_t TOTAL_LABELS = 52;
+  std::vector<int64_t> label_count(TOTAL_LABELS, 0);
 
   // The maximum number of dimensions covered by any ellipsis, needed when
   // unsqueezing missing dimensions from operands to permute and broadcast
@@ -244,7 +244,7 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
   for(const auto i : c10::irange(num_ops)) {
     const auto operand = operands[i];
     const auto labels = op_labels[i];
-    const int64_t ndims = operand.dim();
+    const auto ndims = operand.dim();
     int64_t nlabels = labels.size();
     bool has_ellipsis = false;
 
@@ -295,7 +295,8 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
     // Parse explicit output
     const auto rhs = equation.substr(arrow_pos + 2);
    for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) {
-      switch (rhs[i]) {
+      const unsigned char label = rhs[i];
+      switch (label) {
        case ' ':
          // Ignore spaces
          break;
@@ -316,21 +317,21 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
 
         default:
           TORCH_CHECK(
-              einsum_check_label(rhs[i]),
-              "einsum(): subscripts must be in [a-zA-Z] but found ",
-              rhs[i],
-              " for the output");
-          const auto label = einsum_label_to_index(rhs[i]);
+              einsum_check_label(label),
+              "einsum(): invalid subscript given at index ",
+              lhs.size() + 2 + i,
+              " in the equation string, subscripts must be in [a-zA-Z]");
+          const auto index = einsum_label_to_index(label);
           TORCH_CHECK(
               // Ensure label appeared at least once for some input operand and at
               // most once for the output
-              label_count[label] > 0 && label_perm_index[label] == -1,
+              label_count[index] > 0 && label_perm_index[index] == -1,
               "einsum(): output subscript ",
-              rhs[i],
-              label_perm_index[label] > -1
+              label,
+              label_perm_index[index] > -1
                   ? " appears more than once in the output"
                   : " does not appear in the equation for any input operand");
-          label_perm_index[label] = perm_index++;
+          label_perm_index[index] = perm_index++;
       }
     }
   }
diff --git a/test/test_jit.py b/test/test_jit.py
index 9118726337f637..ba6ef2e34b3cc6 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -61,6 +61,7 @@
 from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import make_tensor
 import torch.autograd.profiler
 import torch.cuda
 import torch.jit
@@ -1580,15 +1581,27 @@ def fn_out(real, img, out):
         self.checkScript(fn_out, (real, img, out, ))
 
     def test_einsum(self):
-        def outer(x, y):
+        def check(fn, jitted, *args):
+            self.assertGraphContains(jitted.graph, kind='aten::einsum')
+            self.assertEqual(fn(*args), jitted(*args))
+
+        def equation_format(x, y):
             return torch.einsum('i,j->ij', (x, y))
 
-        traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
-        script = torch.jit.script(outer)
-        x, y = torch.randn(10), torch.randn(2)
-        for fn in [traced, script]:
-            self.assertGraphContains(fn.graph, kind='aten::einsum')
-            self.assertEqual(fn(x, y), outer(x, y))
+        def sublist_format(x, y):
+            return torch.einsum(x, [0], y, [1], [0, 1])
+
+        # Sublist format cannot be scripted because it is
+        # a NumPy API only feature
+        with self.assertRaises(RuntimeError):
+            torch.jit.script(sublist_format)
+
+        x = make_tensor((5,), 'cpu', torch.float32)
+        y = make_tensor((10,), 'cpu', torch.float32)
+
+        check(equation_format, torch.jit.script(equation_format), x, y)
+        check(equation_format, torch.jit.trace(equation_format, (x, y)), x, y)
+        check(sublist_format, torch.jit.trace(sublist_format, (x, y)), x, y)
 
     def test_python_ivalue(self):
         # Test if pure python object can be hold as IValue and conversion
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 3f0f68c4c56aab..0fdd82218be188 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -4453,23 +4453,24 @@ def test_qr_error_cases(self, device, dtype):
 
     @dtypes(torch.double, torch.cdouble)
     def test_einsum(self, device, dtype):
-        def check(equation, *operands):
-            ref = np.einsum(equation, *[operand.cpu().numpy() for operand in operands])
-            res = torch.einsum(equation, operands)
-            self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
+        def check(*args):
+            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
+            ref = np.einsum(*np_args)
+            res = torch.einsum(*args)
+            self.assertEqual(torch.from_numpy(np.array(ref)), res)
 
         # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
-        x = torch.rand(5, device=device, dtype=dtype)
-        y = torch.rand(7, device=device, dtype=dtype)
-        A = torch.randn(3, 5, device=device, dtype=dtype)
-        B = torch.randn(2, 5, device=device, dtype=dtype)
-        C = torch.randn(2, 3, 5, device=device, dtype=dtype)
-        D = torch.randn(2, 5, 7, device=device, dtype=dtype)
-        E = torch.randn(7, 9, device=device, dtype=dtype)
-        F = torch.randn(2, 3, 3, 5, device=device, dtype=dtype)
-        G = torch.randn(5, 4, 6, device=device, dtype=dtype)
-        H = torch.randn(4, 4, device=device, dtype=dtype)
-        I = torch.rand(2, 3, 2, device=device, dtype=dtype)
+        x = make_tensor((5,), device, dtype)
+        y = make_tensor((7,), device, dtype)
+        A = make_tensor((3, 5), device, dtype)
+        B = make_tensor((2, 5), device, dtype)
+        C = make_tensor((2, 3, 5), device, dtype)
+        D = make_tensor((2, 5, 7), device, dtype)
+        E = make_tensor((7, 9), device, dtype)
+        F = make_tensor((2, 3, 3, 5), device, dtype)
+        G = make_tensor((5, 4, 6), device, dtype)
+        H = make_tensor((4, 4), device, dtype)
+        I = make_tensor((2, 3, 2), device, dtype)
 
         # Vector operations
         check('i->', x)  # sum
@@ -4500,20 +4501,20 @@ def check(equation, *operands):
         check("ii", H)  # trace
         check("ii->i", H)  # diagonal
         check('iji->j', I)  # non-contiguous trace
-        check('ngrg...->nrg...', torch.rand((2, 1, 3, 1, 4), device=device, dtype=dtype))
+        check('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), device, dtype))
 
         # Test ellipsis
         check("i...->...", H)
         check("ki,...k->i...", A.t(), B)
         check("k...,jk->...", A.t(), B)
         check('...ik, ...j -> ...ij', C, x)
-        check('Bik,k...j->i...j', C, torch.rand(5, 3, device=device, dtype=dtype))
-        check('i...j, ij... -> ...ij', C, torch.rand(2, 5, 2, 3, device=device, dtype=dtype))
+        check('Bik,k...j->i...j', C, make_tensor((5, 3), device, dtype))
+        check('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), device, dtype))
 
         # torch.bilinear with noncontiguous tensors
-        l = torch.randn(10, 5, device=device, dtype=dtype).transpose(0, 1)
-        r = torch.randn(20, 5, device=device, dtype=dtype).transpose(0, 1)
-        w = torch.randn(15, 10, 20, device=device, dtype=dtype)
+        l = make_tensor((5, 10), device, dtype, noncontiguous=True)
+        r = make_tensor((5, 20), device, dtype, noncontiguous=True)
+        w = make_tensor((15, 10, 20), device, dtype)
         check("bn,anm,bm->ba", l, w, r)
 
         # with strided tensors
@@ -4553,7 +4554,7 @@ def check(equation, *operands):
                 labels.insert(ell_index, "...")
 
            equation += ''.join(labels) + ','
-            ops.append(torch.rand(sizes, device=device, dtype=dtype))
+            ops.append(make_tensor(sizes, device, dtype))
         equation = equation[:-1]
 
         # Test with implicit output
@@ -4571,8 +4572,8 @@ def check(equation, *operands):
 
     def test_einsum_corner_cases(self, device):
         def check(equation, *operands, expected_output):
-            tensors = [torch.tensor(operand, dtype=torch.float32, device=device) if not isinstance(operand, tuple)
-                       else torch.rand(operand, dtype=torch.float32, device=device) for operand in operands]
+            tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
+                       else make_tensor(operand, device, torch.float32) for operand in operands]
             output = torch.einsum(equation, tensors)
             self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device))
@@ -4610,33 +4611,68 @@ def check(equation, *operands, expected_output):
         check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]])
 
     def test_einsum_error_cases(self, device):
-        def check(equation, operands, regex, exception=RuntimeError):
-            with self.assertRaisesRegex(exception, r'einsum\(\): ' + regex):
-                torch.einsum(equation, operands)
-
-        x = torch.rand(2)
-        y = torch.rand(2, 3)
-
-        check('', [], r'must provide at least one operand')
-        check('. ..', [x], r'found \'.\' for operand 0 that is not part of any ellipsis')
-        check('... ...', [x], r'found \'.\' for operand 0 for which an ellipsis was already found')
-        check('1', [x], r'operand subscript must be in \[a-zA-Z\] but found 1 for operand 0')
-        check(',', [x], r'fewer operands were provided than specified in the equation')
-        check('', [x, x], r'more operands were provided than specified in the equation')
-        check('', [x], r'the number of subscripts in the equation \(0\) does not match the number '
-              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
-        check('ai', [x], r'the number of subscripts in the equation \(2\) does not match the number '
-              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
-        check('ai...', [x], r'the number of subscripts in the equation \(2\) is more than the number '
-              r'of dimensions \(1\) for operand 0')
-        check('a->... .', [x], r'found \'.\' for output but an ellipsis \(...\) was already found')
-        check('a->..', [x], r'found \'.\' for output that is not part of any ellipsis \(...\)')
-        check('a->1', [x], r'subscripts must be in \[a-zA-Z\] but found 1 for the output')
-        check('a->aa', [x], r'output subscript a appears more than once in the output')
-        check('a->i', [x], r'output subscript i does not appear in the equation for any input operand')
-        check('aa', [y], r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
-        check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: '
-              r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]')
+        def check(*args, regex, exception=RuntimeError):
+            with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex):
+                torch.einsum(*args)
+
+        x = make_tensor((2,), device, torch.float32)
+        y = make_tensor((2, 3), device, torch.float32)
+
+        check('', [], regex=r'at least one operand', exception=ValueError)
+        check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis')
+        check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found')
+        check('1', [x], regex=r'invalid subscript given at index 0')
+        check(',', [x], regex=r'fewer operands were provided than specified in the equation')
+        check('', [x, x], regex=r'more operands were provided than specified in the equation')
+        check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number '
+              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
+        check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number '
+              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
+        check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number '
+              r'of dimensions \(1\) for operand 0')
+        check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found')
+        check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)')
+        check('a->1', [x], regex=r'invalid subscript given at index 3')
+        check('a->aa', [x], regex=r'output subscript a appears more than once in the output')
+        check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand')
+        check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
+        check('a, ba', [x, y], regex=r'operands do not broadcast with remapped shapes \[original->remapped\]: '
+              r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]')
+
+        check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
+        check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
+
+    @dtypes(torch.double, torch.cdouble)
+    def test_einsum_sublist_format(self, device, dtype):
+        def check(*args):
+            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
+            ref = np.einsum(*np_args)
+            res = torch.einsum(*args)
+            self.assertEqual(torch.from_numpy(np.array(ref)), res)
+
+        x = make_tensor((5,), device, dtype)
+        y = make_tensor((7,), device, dtype)
+        A = make_tensor((3, 5), device, dtype)
+        B = make_tensor((2, 5), device, dtype)
+        C = make_tensor((2, 1, 3, 1, 4), device, dtype)
+
+        check(x, [0])
+        check(x, [0], [])
+        check(x, [0], y, [1], [0, 1])
+        check(A, [0, 1], [1, 0])
+        check(A, [0, 1], x, [1], [0])
+        check(A, [0, 1], B, [2, 1])
+        check(A, [0, 1], B, [2, 1], [0, 2])
+        check(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis])
+        check(A.t(), [0, 1], B, [Ellipsis, 0])
+        check(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis])
+        check(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis])
+
+        # torch.bilinear with noncontiguous tensors
+        l = make_tensor((5, 10), device, dtype, noncontiguous=True)
+        r = make_tensor((5, 20), device, dtype, noncontiguous=True)
+        w = make_tensor((15, 10, 20), device, dtype)
+        check(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2])
 
     def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
                                      device, dtype):
diff --git a/torch/functional.py b/torch/functional.py
index dc92caa60add76..0ca423e717b4cc 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -155,7 +155,7 @@ def split(tensor, split_size_or_sections, dim=0):
     return tensor.split(split_size_or_sections, dim)
 
 
-def einsum(equation, *operands):
+def einsum(*args):
     r"""einsum(equation, *operands) -> Tensor
 
     Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
@@ -171,7 +171,7 @@ def einsum(equation, *operands):
 
     Equation:
 
-        The :attr:`equation` string specifies the subscripts (lower case letters `['a', 'z']`) for each dimension of
+        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
         the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a comma (','),
         e.g. `'ij,jk'` specifies subscripts for two 2D operands. The dimensions labeled with the same subscript must be
         broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
@@ -211,7 +211,7 @@ def einsum(equation, *operands):
 
     Args:
         equation (string): The subscripts for the Einstein summation.
-        operands (Tensor): The operands to compute the Einstein sum of.
+        operands (List[Tensor]): The tensors to compute the Einstein summation of.
 
    Examples::
 
@@ -259,8 +259,43 @@ def einsum(equation, *operands):
                 tensor([[-0.3430, -5.2405,  0.4494],
                         [ 0.3311,  5.5201, -3.0356]])
     """
+    if len(args) < 2:
+        raise ValueError('einsum(): must specify the equation string and at least one operand, '
+                         'or at least one operand and its subscripts list')
+
+    equation = None
+    operands = None
+
+    if isinstance(args[0], torch.Tensor):
+        # Convert the subscript list format which is an interleaving of operand and its subscripts
+        # list with an optional output subscripts list at the end (see documentation for more details on this)
+        # to the equation string format by creating the equation string from the subscripts list and grouping the
+        # input operands into a tensorlist (List[Tensor]).
+        def parse_subscript(n: int) -> str:
+            if n == Ellipsis:
+                return '...'
+            if n >= 0 and n < 26:
+                return chr(n + ord('a'))
+            if n >= 26 and n < 52:
+                return chr(n - 26 + ord('A'))
+            raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
+
+        # Parse subscripts for input operands
+        equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
+
+        # Parse optional output subscripts (provided when the number of arguments is odd)
+        if len(args) % 2 == 1:
+            equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
+            operands = args[:-1:2]
+        else:
+            operands = args[::2]
+    else:
+        equation = args[0]
+        operands = args[1:]
+
     if has_torch_function(operands):
         return handle_torch_function(einsum, operands, equation, *operands)
+
     if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
         # the old interface of passing the operands as one list argument
         _operands = operands[0]
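
Reviewer note (not part of the patch): a minimal usage sketch of the new
sublist format, checked against the equivalent equation-string calls. The
tensor names and shapes below are arbitrary examples.

```python
import torch

x = torch.randn(5)
y = torch.randn(4)
A = torch.randn(3, 5)
B = torch.randn(2, 3, 5)

# Outer product: subscripts 0 and 1 map to 'a' and 'b', so this is 'a,b->ab'.
out = torch.einsum(x, [0], y, [1], [0, 1])
assert torch.allclose(out, torch.einsum('a,b->ab', x, y))

# Omitting the output sublist falls back to the implicit output, like 'ab,b'.
out = torch.einsum(A, [0, 1], x, [1])
assert torch.allclose(out, torch.einsum('ab,b', A, x))

# Broadcast dimensions are spelled with the Ellipsis object instead of '...'.
out = torch.einsum(B, [Ellipsis, 0], [Ellipsis])
assert torch.allclose(out, torch.einsum('...a->...', B))
```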
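
The integer-to-letter mapping applied by the shim in torch/functional.py
(the inverse of `einsum_label_to_index` in Linear.cpp) can be summarized in
a small standalone sketch; `to_letter` is a hypothetical helper for
illustration, not part of the patch.

```python
def to_letter(n):
    # Mirrors parse_subscript above: [0, 26) -> 'a'..'z', [26, 52) -> 'A'..'Z'.
    if n is Ellipsis:
        return '...'
    if 0 <= n < 26:
        return chr(n + ord('a'))
    if 26 <= n < 52:
        return chr(n - 26 + ord('A'))
    raise ValueError('subscript is not within the valid range [0, 52)')

assert to_letter(0) == 'a' and to_letter(25) == 'z'
assert to_letter(26) == 'A' and to_letter(51) == 'Z'
```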