Implement ellipses ('...') and diagonals (e.g. 'ii->i') in einsum. (#7173)

This brings the two most important missing numpy einsum features
to torch.einsum.
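
A quick illustration of the two new capabilities (this snippet is not part of the commit; it follows the calling convention used in the updated test below, and the tensor names are arbitrary):

import torch

H = torch.randn(4, 4)       # square matrix
I = torch.randn(3, 4, 4)    # batch of square matrices

torch.einsum('ii->i', (H,))        # diagonal of H, shape (4,)
torch.einsum('ii', (H,))           # trace of H (sum of the diagonal)
torch.einsum('...ii->...i', (I,))  # batch diagonal, shape (3, 4)
torch.einsum('i...->...', (H,))    # ellipsis: sums over 'i', leaving shape (4,)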
t-vi authored and ezyang committed May 13, 2018
1 parent 7edd451 commit cfc1d92
Showing 3 changed files with 200 additions and 89 deletions.
250 changes: 171 additions & 79 deletions aten/src/ATen/native/Linear.cpp
@@ -2,7 +2,6 @@
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"


namespace at { namespace native {


@@ -109,11 +108,28 @@ Tensor einsum(std::string eqn, TensorList tensors) {
constexpr size_t number_of_letters = 26;
std::string in_eqn;
size_t pos;
// we need a number of mappings, indexed by letter, for analysing the equation. The index runs from 0='a' through 25='z'.
std::array<std::int64_t, number_of_letters> number_of_occurrences; // number of occurrences of this index in the equation
number_of_occurrences.fill(0);
std::array<std::int64_t, number_of_letters> last_occurrence; // the last operator (left to right) using this index
last_occurrence.fill(-1);
// The equation is given in terms of single lowercase letters ('a'..'z') and potentially an ellipsis.
// Internally, we represent it using indices from 0 to num_total_dimensions, with each letter
// mapped to an index and the ellipsis ('...') being mapped to a number of consecutive indices.
// The mapping of letters to internal indices is given in letter_mapping. A value of -1 means that
// the letter has not been assigned an index yet (because it has not been seen).
// The ellipsis is defined by first_ell_idx (the first index) and num_ell_idxes (the number of indices).
// A value of -1 for num_ell_idxes specifies that we have not seen an ellipsis yet.
// Note: The internal indices are NOT the dimensions used internally. There is a mapping to them below.
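// Example: for the equation '...ii->...i' applied to a 3-dimensional tensor, the ellipsis is seen first,
// so first_ell_idx = 0 and num_ell_idxes = 1 (tensor dim 3 minus the two letter dimensions); the letter 'i'
// is then assigned internal index 1, so there are two internal indices in total.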

std::array<std::int64_t, number_of_letters> letter_mapping; // map letter to internal (numerical) label
letter_mapping.fill(-1);
int64_t num_ell_idxes = -1;
int64_t first_ell_idx = 0;

// The internal representation of the left hand side of the equation (with ellipsis expanded) is stored in input_op_idxes.
// For each operand, we have a vector mapping each dimension to an internal index.
// We also keep track of the number of occurrences for each letter (to infer a right hand side if not given) and
// of the last occurrence of each index.
std::vector<std::vector<int64_t>> input_op_idxes; // the parsed operand indices
std::array<std::int64_t, number_of_letters> num_letter_occurrences; // number of occurrences of this letter in the equation
num_letter_occurrences.fill(0);
std::vector<std::int64_t> last_idx_occurrence; // the last operator (left to right) using this index
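// Continuing the '...ii->...i' example: input_op_idxes[0] = {0, 1, 1} (the ellipsis dimension maps to
// internal index 0 and both 'i' dimensions map to index 1), num_letter_occurrences['i'-'a'] = 2, and
// last_idx_occurrence = {0, 0} since everything last occurs in operand 0.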

if ((pos = eqn.find("->")) != std::string::npos) { // check whether we have a right hand side. in_eqn is the left hand side
in_eqn = eqn.substr(0, pos);
@@ -125,120 +141,196 @@ Tensor einsum(std::string eqn, TensorList tensors) {
int64_t operand = 0;
std::stringstream eqn_stream(in_eqn);
std::string term;
int64_t num_total_idxes = 0;
while (! eqn_stream.eof()) {
std::getline(eqn_stream, term, ','); // term = string with indices of current term
int64_t dims_in_operand = 0;
for (auto &c : term) { // c = character with a single index
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t index_num = c-'a'; // index_num = index to be used in the vectors above
number_of_occurrences[index_num]++;
// when there are two occurrences we need to take a diagonal with respect to the dimensions
// occurring multiple times before continuing the processing.
// e.g. einsum('ii->i', [A]) should return the diagonal
// This waits for the general diagonal handling discussed in #6479
// for now, we error out here
AT_CHECK(last_occurrence[index_num] < operand, "diagonals (multiple occurrences of the same index for one tensor) not implemented yet")
last_occurrence[index_num] = operand;
dims_in_operand++;
AT_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we use the dimension

int64_t ell_char_count = 0; // handling of ellipsis '...' is a bit tedious, we count the '.'
// if there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions
int64_t candidate_num_ell_idxes = tensors[operand].dim() - term.size() + 3;
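// e.g. for the term '...ij' on a 5-dimensional operand, term.size() is 5 (three dots plus two letters),
// so the ellipsis stands for 5 - 5 + 3 = 3 dimensions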
int64_t dims_in_term = 0; // dimensions we have seen
std::vector<int64_t> current_op_idxes; // mapping of operand dimensions to indices for current term
for (auto &c : term) { // c = character with a single letter or '.'
if (c == '.') {
ell_char_count++;
AT_CHECK(ell_char_count <= 3, "can only have '.' in one ellipsis '...' in term ", operand, " of the equation");
if (ell_char_count == 3) { // this completes the ellipsis
if (num_ell_idxes == -1) { // if we have not seen an ellipsis before, keep track of indices and size
first_ell_idx = num_total_idxes;
num_ell_idxes = candidate_num_ell_idxes;
num_total_idxes += num_ell_idxes;
}
else { // we have seen an ellipsis before, so we check compatibility
AT_CHECK(candidate_num_ell_idxes == num_ell_idxes,
"ellipsis must represent ", num_ell_idxes, " dimensions in all terms");
}
for (int64_t i = 0; i < num_ell_idxes; ++i) { // map ellipsis dimensions in operand to indices
current_op_idxes.push_back(first_ell_idx + i);
last_idx_occurrence.push_back(operand);
}
dims_in_term += num_ell_idxes; // keep track of dimensions
}
} else { // a letter (hopefully)
AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand);
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t letter_num = c-'a'; // letter_num = position in letter_mapping
if (letter_mapping[letter_num] == -1) { // new letter, add internal index and mapping
letter_mapping[letter_num] = num_total_idxes;
num_total_idxes++;
last_idx_occurrence.push_back(operand);
} else { // letter we have already seen
last_idx_occurrence[letter_mapping[letter_num]] = operand;
}
num_letter_occurrences[letter_num]++;
current_op_idxes.push_back(letter_mapping[letter_num]);
dims_in_term++;
}
}
AT_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we check the dimensions
AT_CHECK(dims_in_operand == tensors[operand].dim(),
"dimension mismatch for operand ", operand, ": equation ", dims_in_operand, ", tensor ", tensors[operand].dim());
AT_CHECK(dims_in_term == tensors[operand].dim(), "dimension mismatch for operand ", operand, ": equation ", dims_in_term, " tensor ", tensors[operand].dim());
input_op_idxes.push_back(std::move(current_op_idxes));
operand++;
}
AT_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation"); // we need ==, but > is captured above, so the error message can be specific that it is <.
// in the check below, we need ==, but > is captured above, so the error message can be specific that it is <.
AT_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation");

// the following parses or infers output (right hand side)
// it also assigns the sorted_positions ((letter) index -> dimension in Tensors) and position_labels (dimensions in Tensors -> index)
// for the output indices
std::array<std::int64_t, number_of_letters> sorted_position; // the position of the index in the tensor dimensions
sorted_position.fill(-1);
// it also assigns the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors)
// for the output indices. -1 means that the index has not been assigned a dimension yet
std::vector<int64_t> idxes_to_preprocessed_dims(num_total_idxes, -1); // the position of the index in the tensor dimensions
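// After the parsing below, the '...ii->...i' example ends up with idxes_to_preprocessed_dims = {0, 1}:
// the ellipsis index becomes output dimension 0 and 'i' becomes output dimension 1.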
int64_t num_output_dims = 0;
std::vector<int64_t> position_labels;
if (pos != std::string::npos) { // parse the user provided right hand side
int64_t ell_char_count = 0;
for (auto &c : eqn.substr(pos+2)) {
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t index_num = c-'a';
AT_CHECK(sorted_position[index_num] == -1, "index ", c, " occurs twice in output");
sorted_position[index_num] = num_output_dims;
position_labels.push_back(index_num);
num_output_dims++;
if (c == '.') { // '.' as part of ellipsis
ell_char_count++;
AT_CHECK(ell_char_count <= 3, "can only have '.' in one ellipsis '...' in right hand side of the equation");
if (ell_char_count == 3) { // ellipsis complete
AT_CHECK(num_ell_idxes >= 0, "ellipsis '...' may only appear in right hand side if it does in left hand side");
for (int64_t i = 0; i < num_ell_idxes; ++i) {
idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims;
num_output_dims++;
}
}
} else { // letter (hopefully)
AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side");
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t letter_num = c-'a';
AT_CHECK(idxes_to_preprocessed_dims[letter_mapping[letter_num]] == -1, "index ", c, " occurs twice in output");
idxes_to_preprocessed_dims[letter_mapping[letter_num]] = num_output_dims;
num_output_dims++;
}
}
} else { // create an inferred right hand side
// the ellipsis (if in the lhs) comes first
if (num_ell_idxes >= 0) {
for (int64_t i = 0; i < num_ell_idxes; ++i) {
idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims;
num_output_dims++;
}
}
} else { // create a right hand side: the indices that occur exactly once in alphabetic order
// then the indices that occur exactly once in alphabetic order
for (size_t idx = 0; idx < number_of_letters; idx++) {
if (number_of_occurrences[idx] == 1) {
sorted_position[idx] = num_output_dims;
position_labels.push_back(idx);
num_output_dims++;
if (num_letter_occurrences[idx] == 1) {
idxes_to_preprocessed_dims[letter_mapping[idx]] = num_output_dims;
num_output_dims++;
}
}
}
// now we assign the sorted_positions ((letter) index -> dimension in Tensors) and position_labels (dimensions in Tensors -> index)
// now we assign the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors)
// for the non-output indices - those that are eventually summed over
int64_t position = num_output_dims; // we now determine the order of the remaining indices (insofar as they are in the equation)
for (size_t idx = 0; idx < number_of_letters; idx++) {
if ((number_of_occurrences[idx] > 0) && (sorted_position[idx]==-1)) {
sorted_position[idx] = position;
position_labels.push_back(idx);
int64_t position = num_output_dims;
for (int64_t i = 0; i < num_total_idxes; i++) {
if (idxes_to_preprocessed_dims[i]==-1) {
idxes_to_preprocessed_dims[i] = position;
position++;
}
}
// we now "homogenize the dimensions", i.e. create all dimensions in each tensor and sort the dimensions according to the mapping in
// sorted_position / position_labels

// we now "homogenize the dimensions", i.e.
// - take diagonals for duplicated indices
// - permute the dimensions to match the order given by idxes_to_preprocessed_dims
// - unsqueeze to create all dimensions for each index in each tensor where they are missing
// we also check that sizes match
// after this, all operands will have compatible shapes (i.e. all dimensions are aligned and broadcastable)
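// e.g. for 'ij,j->i' the first operand keeps its shape (i, j) while the second is unsqueezed to (1, j),
// so the two can be broadcast against each other in the contraction below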
std::vector<Tensor> permuted_ops;
eqn_stream.clear();
eqn_stream.seekg(0, std::ios_base::beg);
std::vector<Tensor> preprocessed_operands;
std::vector<std::int64_t> size_of_dims(num_total_idxes, -1); // keep track of sizes for each index, -1 means we have not seen a size yet
for (int64_t op = 0; op < (int64_t) tensors.size(); op++) {
std::array<int64_t, number_of_letters> axes; // the dimension which the letter refers to in the permuted tensor
axes.fill(-1);
std::vector<int64_t> permutation; // permutation for this tensor
std::getline(eqn_stream, term, ',');
int64_t dim = 0;
for (auto &c : term) {
int64_t index_num = c-'a';
axes[index_num] = dim;
dim++;
auto preprocessed_op = tensors[op];
std::vector<int64_t> idx_to_dim(num_total_idxes, -1); // the dimension which the index refers to in the original tensor, -1 means it does not appear
std::vector<int64_t>& current_op_input_idxes = input_op_idxes[op];
int64_t dim = 0; // there are two dimension indices: dim is after taking diagonals, i is in input
for (size_t i = 0; i < current_op_input_idxes.size(); i++) {
auto idx = current_op_input_idxes[i];
auto dim_out = idxes_to_preprocessed_dims[idx];
if (idx_to_dim[dim_out] == -1) { // first appearance
idx_to_dim[dim_out] = dim;
if (size_of_dims[idx] == -1) { // keep track of sizes
size_of_dims[idx] = preprocessed_op.size(dim);
}
else {
AT_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i);
}
dim++;
} else { // duplicate dimension in tensor --> take diagonal of idx_to_dim[dim_out] and dim and put the diagonal dimension to idx_to_dim[dim_out]
AT_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i);
preprocessed_op = preprocessed_op.diagonal(0, idx_to_dim[dim_out], dim);
// diagonal moves the diagonal dimension to the back
// now we permute the last dim back to idx_to_dim[dim_out]
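// e.g. for a term 'aibi' on a 4-dimensional operand: at the second 'i', diagonal(0, 1, 3) turns
// dims (a, i, b, i) into (a, b, i), and perm = {0, 2, 1} moves the diagonal back, giving (a, i, b)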
std::vector<int64_t> perm(preprocessed_op.dim(), 0);
for (int64_t d = 0; d < preprocessed_op.dim(); d++) {
if (d == idx_to_dim[dim_out]) {
perm[d] = preprocessed_op.dim() - 1;
} else {
perm[d] = d - (d > idx_to_dim[dim_out]);
}
}
preprocessed_op = preprocessed_op.permute(perm);
}
}
for (auto &c : position_labels) {
if (axes[c] > -1) {
permutation.push_back(axes[c]);
// now we permute the dimensions in the right order
std::vector<int64_t> permutation; // permutation for this tensor
for (auto &d : idx_to_dim) {
if (d > -1) {
permutation.push_back(d);
}
}
permuted_ops.push_back(tensors[op].permute(permutation));
for (int64_t dim = 0; dim < (int64_t) position_labels.size(); dim++) {
auto c = position_labels[dim];
if (axes[c] == -1) {
permuted_ops.back().unsqueeze_(dim);
preprocessed_op = preprocessed_op.permute(permutation);
// finally, we insert dimensions for idxes not in the operand
for (size_t dim = 0; dim < idx_to_dim.size(); dim++) {
if (idx_to_dim[dim] == -1) {
preprocessed_op.unsqueeze_(dim);
}
}
preprocessed_operands.push_back(preprocessed_op);
}

// now we reduce the indices from left to right
// numpy allows optimizing the contraction path using various
// algorithms (see einsum_path in the numpy docs)
// we start with the leftmost operator and reduce indices that
// appear only there
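// e.g. for 'ij,jk->k', the index 'i' occurs only in the first operand, so it is summed out of
// preprocessed_operands[0] here, before any pairwise products are formed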
Tensor result = permuted_ops[0];
for (size_t idx = 0; idx < number_of_letters; idx++) {
if ((last_occurrence[idx] == 0)
&& (sorted_position[idx]>=num_output_dims)) {
result = result.sum(sorted_position[idx], true);
Tensor result = preprocessed_operands[0];
for (int64_t idx = 0; idx < num_total_idxes; idx++) {
if ((last_idx_occurrence[idx] == 0)
&& (idxes_to_preprocessed_dims[idx]>=num_output_dims)) {
result = result.sum(idxes_to_preprocessed_dims[idx], true);
}
}

// now we process each tensor using sumproduct_pair
for (int64_t i = 1; i < (int64_t) permuted_ops.size(); i++) {
for (int64_t i = 1; i < (int64_t) preprocessed_operands.size(); i++) {
std::vector<int64_t> sum_dims;
for (size_t idx = 0; idx < number_of_letters; idx++) {
if ((last_occurrence[idx] == i)
&& (sorted_position[idx]>=num_output_dims)) {
sum_dims.push_back(sorted_position[idx]);
for (int64_t idx = 0; idx < num_total_idxes; idx++) {
if ((last_idx_occurrence[idx] == i)
&& (idxes_to_preprocessed_dims[idx]>=num_output_dims)) {
sum_dims.push_back(idxes_to_preprocessed_dims[idx]);
}
}
result = at::native::sumproduct_pair(result, permuted_ops[i], sum_dims, true);
result = at::native::sumproduct_pair(result, preprocessed_operands[i], sum_dims, true);
}
// finally, we squeeze out all non-result dimensions
for (int64_t dim = position_labels.size()-1; dim >= num_output_dims; dim--)
for (int64_t dim = num_total_idxes-1; dim >= num_output_dims; dim--)
result.squeeze_(dim);
return result;
}
14 changes: 12 additions & 2 deletions test/test_torch.py
@@ -1435,6 +1435,8 @@ def test_einsum(self):
E = torch.randn(7, 9)
F = torch.randn(2, 3, 5, 7)
G = torch.randn(7, 11, 13)
H = torch.randn(4, 4)
I = torch.randn(3, 4, 4)
l = torch.randn(5, 10)
r = torch.randn(5, 20)
w = torch.randn(30, 10, 20)
@@ -1461,14 +1463,22 @@ def test_einsum(self):
("ijk,jk->ij", C, A), # tensor matrix contraction with double indices
("ijk,ik->j", C, B), # non contiguous
("ijk,ik->jk", C, B), # non contiguous with double indices
# -- Diagonal
("ii", H), # trace
("ii->i", H), # diagonal
# -- Ellipsis
("i...->...", H),
("ki,...k->i...", A.t(), B),
("k...,jk", A.t(), B),
("...ii->...i", I), # batch diagonal
# -- Other
("bn,anm,bm->ba", l, w, r), # as torch.bilinear
]
for test in test_list:
actual = torch.einsum(test[0], test[1:])
expected = np.einsum(test[0], *[t.numpy() for t in test[1:]])
self.assertEqual(expected.shape, actual.shape)
self.assertTrue(np.allclose(expected, actual.numpy()))
self.assertEqual(expected.shape, actual.shape, test[0])
self.assertTrue(np.allclose(expected, actual.numpy()), test[0])

def do_einsum(*args):
return torch.einsum(test[0], args)