Added sublist support for torch.einsum (pytorch#56625)
Summary:
This PR adds an alternative way of calling `torch.einsum`. Instead of specifying the subscripts as letters in the `equation` parameter, one can now specify them as lists of integers, as in `torch.einsum(operand1, subscripts1, operand2, subscripts2, ..., [subscripts_out])`. This is equivalent to `torch.einsum('<subscripts1>,<subscripts2>,...[-><subscripts_out>]', operand1, operand2, ...)`.
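For illustration, a minimal sketch of the two equivalent calling conventions (the tensors and shapes here are arbitrary, not taken from the PR):

    import torch

    x = torch.randn(5)
    y = torch.randn(7)

    # Equation-string form: letters name the subscripts.
    a = torch.einsum('i,j->ij', x, y)

    # Sublist form added by this PR: each operand is followed by a list of
    # integer subscripts, and an optional final list names the output subscripts.
    b = torch.einsum(x, [0], y, [1], [0, 1])

    assert torch.equal(a, b)  # both compute the outer product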

TODO
- [x] Update documentation
- [x] Add more error checking
- [x] Update tests

Pull Request resolved: pytorch#56625

Reviewed By: zou3519

Differential Revision: D28062616

Pulled By: heitorschueroff

fbshipit-source-id: ec50ad34f127210696e7c545e4c0675166f127dc
heitorschueroff authored and facebook-github-bot committed May 21, 2021
1 parent fc804b5 commit 72ae924
Showing 4 changed files with 172 additions and 87 deletions.
53 changes: 27 additions & 26 deletions aten/src/ATen/native/Linear.cpp
@@ -148,12 +148,12 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra
 
 namespace {
 
-bool einsum_check_label(char label) {
+bool einsum_check_label(unsigned char label) {
   return std::isalpha(label);
 }
 
-int einsum_label_to_index(char label) {
-  constexpr int NUM_OF_LETTERS = 'z' - 'a' + 1;
+uint8_t einsum_label_to_index(unsigned char label) {
+  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
   return std::islower(label) ? label - 'a' : NUM_OF_LETTERS + (label - 'A');
 }
 
@@ -169,21 +169,22 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
   checkDeviceType("einsum():", operands, operands[0].device().type());
 
   // Code used to identify ELLIPSIS ("...")
-  constexpr int ELLIPSIS = '.';
+  constexpr uint8_t ELLIPSIS = 52;
 
   // Find arrow (->) to split equation into lhs and rhs
   const auto arrow_pos = equation.find("->");
   const auto lhs = equation.substr(0, arrow_pos);
 
   const auto num_ops = operands.size();
 
-  // Convert labels for input operands into an index in [0, 25] and store
+  // Convert labels for input operands into an index in [0, 52) and store
   // them in op_labels for each operand along with ELLIPSIS if present.
-  std::vector<std::vector<int>> op_labels(num_ops);
+  std::vector<std::vector<uint8_t>> op_labels(num_ops);
   bool found_ell = false;
   std::size_t curr_op = 0;
   for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) {
-    switch (lhs[i]) {
+    const unsigned char label = lhs[i];
+    switch (label) {
       case ' ':
         // Ignore spaces
         break;
@@ -217,12 +218,11 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
       default:
         // Parse label
        TORCH_CHECK(
-            einsum_check_label(lhs[i]),
-            "einsum(): operand subscript must be in [a-zA-Z] but found ",
-            lhs[i],
-            " for operand ",
-            curr_op);
-        op_labels[curr_op].push_back(einsum_label_to_index(lhs[i]));
+            einsum_check_label(label),
+            "einsum(): invalid subscript given at index ",
+            i,
+            " in the equation string, subscripts must be in [a-zA-Z]");
+        op_labels[curr_op].push_back(einsum_label_to_index(label));
     }
   }
 
@@ -231,8 +231,8 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
       "einsum(): more operands were provided than specified in the equation");
 
   // Labels must be within [a-zA-Z].
-  constexpr int TOTAL_LABELS = 52;
-  std::vector<int> label_count(TOTAL_LABELS, 0);
+  constexpr uint8_t TOTAL_LABELS = 52;
+  std::vector<int64_t> label_count(TOTAL_LABELS, 0);
 
   // The maximum number of dimensions covered by any ellipsis, needed when
   // unsqueezing missing dimensions from operands to permute and broadcast
@@ -244,7 +244,7 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
   for(const auto i : c10::irange(num_ops)) {
     const auto operand = operands[i];
     const auto labels = op_labels[i];
-    const int64_t ndims = operand.dim();
+    const auto ndims = operand.dim();
     int64_t nlabels = labels.size();
     bool has_ellipsis = false;
 
@@ -295,7 +295,8 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
     // Parse explicit output
     const auto rhs = equation.substr(arrow_pos + 2);
     for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) {
-      switch (rhs[i]) {
+      const unsigned char label = rhs[i];
+      switch (label) {
         case ' ':
           // Ignore spaces
           break;
@@ -316,21 +317,21 @@ Tensor einsum(c10::string_view equation, TensorList operands) {
 
         default:
           TORCH_CHECK(
-              einsum_check_label(rhs[i]),
-              "einsum(): subscripts must be in [a-zA-Z] but found ",
-              rhs[i],
-              " for the output");
-          const auto label = einsum_label_to_index(rhs[i]);
+              einsum_check_label(label),
+              "einsum(): invalid subscript given at index ",
+              lhs.size() + 2 + i,
+              " in the equation string, subscripts must be in [a-zA-Z]");
+          const auto index = einsum_label_to_index(label);
           TORCH_CHECK(
               // Ensure label appeared at least once for some input operand and at
               // most once for the output
-              label_count[label] > 0 && label_perm_index[label] == -1,
+              label_count[index] > 0 && label_perm_index[index] == -1,
               "einsum(): output subscript ",
-              rhs[i],
-              label_perm_index[label] > -1
+              label,
+              label_perm_index[index] > -1
                   ? " appears more than once in the output"
                   : " does not appear in the equation for any input operand");
-          label_perm_index[label] = perm_index++;
+          label_perm_index[index] = perm_index++;
       }
     }
   }
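For reference, the label scheme introduced above maps subscript letters to indices in [0, 52), with 52 reserved as the ELLIPSIS sentinel. A minimal Python mirror of the C++ helper (illustrative only, not code from this commit):

    NUM_OF_LETTERS = ord('z') - ord('a') + 1  # 26
    TOTAL_LABELS = 2 * NUM_OF_LETTERS         # 52, so valid indices are [0, 52)
    ELLIPSIS = 52                             # sentinel just past the letter range

    def einsum_label_to_index(label: str) -> int:
        # 'a'..'z' map to 0..25 and 'A'..'Z' map to 26..51
        if label.islower():
            return ord(label) - ord('a')
        return NUM_OF_LETTERS + (ord(label) - ord('A'))

    assert einsum_label_to_index('a') == 0
    assert einsum_label_to_index('Z') == 51

Sublist integers are interpreted as these indices directly, which is why the new error tests below require them to lie in [0, 52).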
27 changes: 20 additions & 7 deletions test/test_jit.py
@@ -61,6 +61,7 @@
 from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import make_tensor
 import torch.autograd.profiler
 import torch.cuda
 import torch.jit
@@ -1580,15 +1581,27 @@ def fn_out(real, img, out):
         self.checkScript(fn_out, (real, img, out, ))
 
     def test_einsum(self):
-        def outer(x, y):
+        def check(fn, jitted, *args):
+            self.assertGraphContains(jitted.graph, kind='aten::einsum')
+            self.assertEqual(fn(*args), jitted(*args))
+
+        def equation_format(x, y):
             return torch.einsum('i,j->ij', (x, y))
 
-        traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
-        script = torch.jit.script(outer)
-        x, y = torch.randn(10), torch.randn(2)
-        for fn in [traced, script]:
-            self.assertGraphContains(fn.graph, kind='aten::einsum')
-            self.assertEqual(fn(x, y), outer(x, y))
+        def sublist_format(x, y):
+            return torch.einsum(x, [0], y, [1], [0, 1])
+
+        # Sublist format cannot be scripted because it is
+        # a NumPy API only feature
+        with self.assertRaises(RuntimeError):
+            torch.jit.script(sublist_format)
+
+        x = make_tensor((5,), 'cpu', torch.float32)
+        y = make_tensor((10,), 'cpu', torch.float32)
+
+        check(equation_format, torch.jit.script(equation_format), x, y)
+        check(equation_format, torch.jit.trace(equation_format, (x, y)), x, y)
+        check(sublist_format, torch.jit.trace(sublist_format, (x, y)), x, y)
 
     def test_python_ivalue(self):
         # Test if pure python object can be hold as IValue and conversion
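A condensed, standalone version of what the updated test exercises (shapes arbitrary): the sublist form can be traced but not scripted, since it is a NumPy-style, Python-only calling convention.

    import torch

    def sublist_format(x, y):
        return torch.einsum(x, [0], y, [1], [0, 1])

    x, y = torch.randn(5), torch.randn(10)

    # Tracing records the underlying aten::einsum call and works as usual.
    traced = torch.jit.trace(sublist_format, (x, y))
    assert torch.allclose(traced(x, y), sublist_format(x, y))

    # Scripting the sublist form raises a RuntimeError.
    try:
        torch.jit.script(sublist_format)
    except RuntimeError:
        pass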
138 changes: 87 additions & 51 deletions test/test_linalg.py
@@ -4453,23 +4453,24 @@ def test_qr_error_cases(self, device, dtype):
 
     @dtypes(torch.double, torch.cdouble)
     def test_einsum(self, device, dtype):
-        def check(equation, *operands):
-            ref = np.einsum(equation, *[operand.cpu().numpy() for operand in operands])
-            res = torch.einsum(equation, operands)
-            self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
+        def check(*args):
+            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
+            ref = np.einsum(*np_args)
+            res = torch.einsum(*args)
+            self.assertEqual(torch.from_numpy(np.array(ref)), res)
 
         # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
-        x = torch.rand(5, device=device, dtype=dtype)
-        y = torch.rand(7, device=device, dtype=dtype)
-        A = torch.randn(3, 5, device=device, dtype=dtype)
-        B = torch.randn(2, 5, device=device, dtype=dtype)
-        C = torch.randn(2, 3, 5, device=device, dtype=dtype)
-        D = torch.randn(2, 5, 7, device=device, dtype=dtype)
-        E = torch.randn(7, 9, device=device, dtype=dtype)
-        F = torch.randn(2, 3, 3, 5, device=device, dtype=dtype)
-        G = torch.randn(5, 4, 6, device=device, dtype=dtype)
-        H = torch.randn(4, 4, device=device, dtype=dtype)
-        I = torch.rand(2, 3, 2, device=device, dtype=dtype)
+        x = make_tensor((5,), device, dtype)
+        y = make_tensor((7,), device, dtype)
+        A = make_tensor((3, 5), device, dtype)
+        B = make_tensor((2, 5), device, dtype)
+        C = make_tensor((2, 3, 5), device, dtype)
+        D = make_tensor((2, 5, 7), device, dtype)
+        E = make_tensor((7, 9), device, dtype)
+        F = make_tensor((2, 3, 3, 5), device, dtype)
+        G = make_tensor((5, 4, 6), device, dtype)
+        H = make_tensor((4, 4), device, dtype)
+        I = make_tensor((2, 3, 2), device, dtype)
 
         # Vector operations
         check('i->', x)  # sum
@@ -4500,20 +4501,20 @@ def check(equation, *operands):
         check("ii", H)  # trace
         check("ii->i", H)  # diagonal
         check('iji->j', I)  # non-contiguous trace
-        check('ngrg...->nrg...', torch.rand((2, 1, 3, 1, 4), device=device, dtype=dtype))
+        check('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), device, dtype))
 
         # Test ellipsis
         check("i...->...", H)
         check("ki,...k->i...", A.t(), B)
         check("k...,jk->...", A.t(), B)
         check('...ik, ...j -> ...ij', C, x)
-        check('Bik,k...j->i...j', C, torch.rand(5, 3, device=device, dtype=dtype))
-        check('i...j, ij... -> ...ij', C, torch.rand(2, 5, 2, 3, device=device, dtype=dtype))
+        check('Bik,k...j->i...j', C, make_tensor((5, 3), device, dtype))
+        check('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), device, dtype))
 
         # torch.bilinear with noncontiguous tensors
-        l = torch.randn(10, 5, device=device, dtype=dtype).transpose(0, 1)
-        r = torch.randn(20, 5, device=device, dtype=dtype).transpose(0, 1)
-        w = torch.randn(15, 10, 20, device=device, dtype=dtype)
+        l = make_tensor((5, 10), device, dtype, noncontiguous=True)
+        r = make_tensor((5, 20), device, dtype, noncontiguous=True)
+        w = make_tensor((15, 10, 20), device, dtype)
         check("bn,anm,bm->ba", l, w, r)
 
         # with strided tensors
@@ -4553,7 +4554,7 @@ def check(equation, *operands):
                 labels.insert(ell_index, "...")
 
             equation += ''.join(labels) + ','
-            ops.append(torch.rand(sizes, device=device, dtype=dtype))
+            ops.append(make_tensor(sizes, device, dtype))
         equation = equation[:-1]
 
         # Test with implicit output
@@ -4571,8 +4572,8 @@
 
     def test_einsum_corner_cases(self, device):
         def check(equation, *operands, expected_output):
-            tensors = [torch.tensor(operand, dtype=torch.float32, device=device) if not isinstance(operand, tuple)
-                       else torch.rand(operand, dtype=torch.float32, device=device) for operand in operands]
+            tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
+                       else make_tensor(operand, device, torch.float32) for operand in operands]
             output = torch.einsum(equation, tensors)
             self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device))
 
@@ -4610,33 +4611,68 @@ def check(equation, *operands, expected_output):
         check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]])
 
     def test_einsum_error_cases(self, device):
-        def check(equation, operands, regex, exception=RuntimeError):
-            with self.assertRaisesRegex(exception, r'einsum\(\): ' + regex):
-                torch.einsum(equation, operands)
-
-        x = torch.rand(2)
-        y = torch.rand(2, 3)
-
-        check('', [], r'must provide at least one operand')
-        check('. ..', [x], r'found \'.\' for operand 0 that is not part of any ellipsis')
-        check('... ...', [x], r'found \'.\' for operand 0 for which an ellipsis was already found')
-        check('1', [x], r'operand subscript must be in \[a-zA-Z\] but found 1 for operand 0')
-        check(',', [x], r'fewer operands were provided than specified in the equation')
-        check('', [x, x], r'more operands were provided than specified in the equation')
-        check('', [x], r'the number of subscripts in the equation \(0\) does not match the number '
-                       r'of dimensions \(1\) for operand 0 and no ellipsis was given')
-        check('ai', [x], r'the number of subscripts in the equation \(2\) does not match the number '
-                         r'of dimensions \(1\) for operand 0 and no ellipsis was given')
-        check('ai...', [x], r'the number of subscripts in the equation \(2\) is more than the number '
-                            r'of dimensions \(1\) for operand 0')
-        check('a->... .', [x], r'found \'.\' for output but an ellipsis \(...\) was already found')
-        check('a->..', [x], r'found \'.\' for output that is not part of any ellipsis \(...\)')
-        check('a->1', [x], r'subscripts must be in \[a-zA-Z\] but found 1 for the output')
-        check('a->aa', [x], r'output subscript a appears more than once in the output')
-        check('a->i', [x], r'output subscript i does not appear in the equation for any input operand')
-        check('aa', [y], r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
-        check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: '
-                               r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]')
+        def check(*args, regex, exception=RuntimeError):
+            with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex):
+                torch.einsum(*args)
+
+        x = make_tensor((2,), device, torch.float32)
+        y = make_tensor((2, 3), device, torch.float32)
+
+        check('', [], regex=r'at least one operand', exception=ValueError)
+        check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis')
+        check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found')
+        check('1', [x], regex=r'invalid subscript given at index 0')
+        check(',', [x], regex=r'fewer operands were provided than specified in the equation')
+        check('', [x, x], regex=r'more operands were provided than specified in the equation')
+        check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number '
+                             r'of dimensions \(1\) for operand 0 and no ellipsis was given')
+        check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number '
+                               r'of dimensions \(1\) for operand 0 and no ellipsis was given')
+        check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number '
+                                  r'of dimensions \(1\) for operand 0')
+        check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found')
+        check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)')
+        check('a->1', [x], regex=r'invalid subscript given at index 3')
+        check('a->aa', [x], regex=r'output subscript a appears more than once in the output')
+        check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand')
+        check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
+        check('a, ba', [x, y], regex=r'operands do not broadcast with remapped shapes \[original->remapped\]: '
+                                     r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]')
+
+        check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
+        check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
+
+    @dtypes(torch.double, torch.cdouble)
+    def test_einsum_sublist_format(self, device, dtype):
+        def check(*args):
+            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
+            ref = np.einsum(*np_args)
+            res = torch.einsum(*args)
+            self.assertEqual(torch.from_numpy(np.array(ref)), res)
+
+        x = make_tensor((5,), device, dtype)
+        y = make_tensor((7,), device, dtype)
+        A = make_tensor((3, 5), device, dtype)
+        B = make_tensor((2, 5), device, dtype)
+        C = make_tensor((2, 1, 3, 1, 4), device, dtype)
+
+        check(x, [0])
+        check(x, [0], [])
+        check(x, [0], y, [1], [0, 1])
+        check(A, [0, 1], [1, 0])
+        check(A, [0, 1], x, [1], [0])
+        check(A, [0, 1], B, [2, 1])
+        check(A, [0, 1], B, [2, 1], [0, 2])
+        check(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis])
+        check(A.t(), [0, 1], B, [Ellipsis, 0])
+        check(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis])
+        check(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis])
+
+        # torch.bilinear with noncontiguous tensors
+        l = make_tensor((5, 10), device, dtype, noncontiguous=True)
+        r = make_tensor((5, 20), device, dtype, noncontiguous=True)
+        w = make_tensor((15, 10, 20), device, dtype)
+        check(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2])
 
     def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
                                      device, dtype):
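Taken together, the new error-case and sublist tests pin down the user-visible contract; a small illustrative sketch (not code from the commit itself):

    import torch

    x = torch.randn(2)

    # Sublist subscripts must be integers in [0, 52); Ellipsis may cover any
    # remaining broadcast dimensions, including zero of them.
    torch.einsum(x, [51])           # largest valid subscript
    torch.einsum(x, [Ellipsis, 0])  # ellipsis covering zero dimensions

    # Out-of-range subscripts raise ValueError rather than RuntimeError.
    try:
        torch.einsum(x, [52])
    except ValueError:
        pass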