Skip to content

Commit

Permalink
Merge pull request numpy#9773 from dgasmith/einsum_optimal_path_fix
Browse files Browse the repository at this point in the history
BUG: Fixes optimal einsum path for multi-term intermediates
  • Loading branch information
charris authored Sep 26, 2017
2 parents a1ec286 + dc53552 commit 088c565
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
10 changes: 8 additions & 2 deletions numpy/core/einsumfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,14 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
new_pos = positions + [con]
iter_results.append((new_cost, new_pos, new_input_sets))

# Update list to iterate over
full_results = iter_results
# Update combinatorial list, if we did not find anything return best
# path + remaining contractions
if iter_results:
full_results = iter_results
else:
path = min(full_results, key=lambda x: x[0])[1]
path += [tuple(range(len(input_sets) - iteration))]
return path

# If we have not found anything return single einsum contraction
if len(full_results) == 0:
Expand Down
14 changes: 12 additions & 2 deletions numpy/core/tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,13 +767,13 @@ def test_random_cases(self):


class TestEinSumPath(object):
def build_operands(self, string):
def build_operands(self, string, size_dict=global_size_dict):

# Builds views based off initial operands
operands = [string]
terms = string.split('->')[0].split(',')
for term in terms:
dims = [global_size_dict[x] for x in term]
dims = [size_dict[x] for x in term]
operands.append(np.random.rand(*dims))

return operands
Expand Down Expand Up @@ -863,6 +863,16 @@ def test_edge_paths(self):
path, path_str = np.einsum_path(*edge_test4, optimize='optimal')
self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)])

# Edge test5
edge_test4 = self.build_operands('a,ac,ab,ad,cd,bd,bc->',
size_dict={"a": 20, "b": 20, "c": 20, "d": 20})
path, path_str = np.einsum_path(*edge_test4, optimize='greedy')
self.assert_path_equal(path, ['einsum_path', (0, 1), (0, 1, 2, 3, 4, 5)])

path, path_str = np.einsum_path(*edge_test4, optimize='optimal')
self.assert_path_equal(path, ['einsum_path', (0, 1), (0, 1, 2, 3, 4, 5)])


def test_path_type_input(self):
# Test explicit path handeling
path_test = self.build_operands('dcc,fce,ea,dbf->ab')
Expand Down

0 comments on commit 088c565

Please sign in to comment.