Skip to content

Commit

Permalink
Merge pull request numpy#4977 from sotte/mdot
Browse files Browse the repository at this point in the history
ENH: add `mdot`: fast dot with multiple arguments.
  • Loading branch information
juliantaylor committed Nov 13, 2014
2 parents 25aff4d + 1b12c39 commit f14d5e1
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/release/1.10.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Highlights
==========
* numpy.distutils now supports parallel compilation via the --jobs/-j argument
passed to setup.py build
* Addition of `np.linalg.multi_dot`: compute the dot product of two or more
arrays in a single function call, while automatically selecting the fastest
evaluation order.


Dropped Support
Expand Down
187 changes: 185 additions & 2 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
__all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
'svd', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond', 'matrix_rank',
'LinAlgError']
'LinAlgError', 'multi_dot']

import warnings

Expand All @@ -23,7 +23,7 @@
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
broadcast
broadcast, atleast_2d, intp, asanyarray
)
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
Expand Down Expand Up @@ -2153,3 +2153,186 @@ def norm(x, ord=None, axis=None, keepdims=False):
return ret
else:
raise ValueError("Improper number of dimensions to norm.")


# multi_dot

def multi_dot(arrays):
"""
Compute the dot product of two or more arrays in a single function call,
while automatically selecting the fastest evaluation order.
`multi_dot` chains `numpy.dot` and uses an optimal parenthesizations
of the matrices [1]_ [2]_. Depending on the shape of the matrices
this can speed up the multiplication a lot.
The first and last argument can be 1-D and are treated respectively as
row and column vector. The other arguments must be 2-D.
Think of `multi_dot` as::
def multi_dot(arrays): return functools.reduce(np.dot, arrays)
Parameters
----------
arrays : sequence of array_like
First and last argument can be 1-D and are treated respectively as
row and column vector, the other arguments must be 2-D.
Returns
-------
output : ndarray
Returns the dot product of the supplied arrays.
See Also
--------
dot : dot multiplication with two arguments.
References
----------
.. [1] Cormen, "Introduction to Algorithms", Chapter 15.2, p. 370-378
.. [2] http://en.wikipedia.org/wiki/Matrix_chain_multiplication
Examples
--------
`multi_dot` allows you to write::
>>> import numpy as np
>>> # Prepare some data
>>> A = np.random.random(10000, 100)
>>> B = np.random.random(100, 1000)
>>> C = np.random.random(1000, 5)
>>> D = np.random.random(5, 333)
>>> # the actual dot multiplication
>>> multi_dot([A, B, C, D])
instead of::
>>> np.dot(np.dot(np.dot(A, B), C), D)
>>> # or
>>> A.dot(B).dot(C).dot(D)
Example: multiplication costs of different parenthesizations
------------------------------------------------------------
The cost for a matrix multiplication can be calculated with the
following function::
def cost(A, B): return A.shape[0] * A.shape[1] * B.shape[1]
Let's assume we have three matrices
:math:`A_{10x100}, B_{100x5}, C_{5x50}$`.
The costs for the two different parenthesizations are as follows::
cost((AB)C) = 10*100*5 + 10*5*50 = 5000 + 2500 = 7500
cost(A(BC)) = 10*100*50 + 100*5*50 = 50000 + 25000 = 75000
"""
n = len(arrays)
# optimization only makes sense for len(arrays) > 2
if n < 2:
raise ValueError("Expecting at least two arrays.")
elif n == 2:
return dot(arrays[0], arrays[1])

arrays = [asanyarray(a) for a in arrays]

# save original ndim to reshape the result array into the proper form later
ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
# Explicitly convert vectors to 2D arrays to keep the logic of the internal
# _multi_dot_* functions as simple as possible.
if arrays[0].ndim == 1:
arrays[0] = atleast_2d(arrays[0])
if arrays[-1].ndim == 1:
arrays[-1] = atleast_2d(arrays[-1]).T
_assertRank2(*arrays)

# _multi_dot_three is much faster than _multi_dot_matrix_chain_order
if n == 3:
result = _multi_dot_three(arrays[0], arrays[1], arrays[2])
else:
order = _multi_dot_matrix_chain_order(arrays)
result = _multi_dot(arrays, order, 0, n - 1)

# return proper shape
if ndim_first == 1 and ndim_last == 1:
return result[0, 0] # scalar
elif ndim_first == 1 or ndim_last == 1:
return result.ravel() # 1-D
else:
return result


def _multi_dot_three(A, B, C):
"""
Find best ordering for three arrays and do the multiplication.
Doing in manually instead of using dynamic programing is
approximately 15 times faster.
"""
# cost1 = cost((AB)C)
cost1 = (A.shape[0] * A.shape[1] * B.shape[1] + # (AB)
A.shape[0] * B.shape[1] * C.shape[1]) # (--)C
# cost2 = cost((AB)C)
cost2 = (B.shape[0] * B.shape[1] * C.shape[1] + # (BC)
A.shape[0] * A.shape[1] * C.shape[1]) # A(--)

if cost1 < cost2:
return dot(dot(A, B), C)
else:
return dot(A, dot(B, C))


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
"""
Return a np.array which encodes the opimal order of mutiplications.
The optimal order array is then used by `_multi_dot()` to do the
multiplication.
Also return the cost matrix if `return_costs` is `True`
The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices.
cost[i, j] = min([
cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
for k in range(i, j)])
"""
n = len(arrays)
# p stores the dimensions of the matrices
# Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
# m is a matrix of costs of the subproblems
# m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
m = zeros((n, n), dtype=double)
# s is the actual ordering
# s[i, j] is the value of k at which we split the product A_i..A_j
s = empty((n, n), dtype=intp)

for l in range(1, n):
for i in range(n - l):
j = i + l
m[i, j] = Inf
for k in range(i, j):
q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
if q < m[i, j]:
m[i, j] = q
s[i, j] = k # Note that Cormen uses 1-based index

return (s, m) if return_costs else s


def _multi_dot(arrays, order, i, j):
"""Actually do the multiplication with the given order."""
if i == j:
return arrays[i]
else:
return dot(_multi_dot(arrays, order, i, order[i, j]),
_multi_dot(arrays, order, order[i, j] + 1, j))
90 changes: 88 additions & 2 deletions numpy/linalg/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from numpy import array, single, double, csingle, cdouble, dot, identity
from numpy import multiply, atleast_2d, inf, asarray, matrix
from numpy import linalg
from numpy.linalg import matrix_power, norm, matrix_rank
from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot
from numpy.linalg.linalg import _multi_dot_matrix_chain_order
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_array_equal,
assert_almost_equal, assert_allclose, run_module_suite,
Expand Down Expand Up @@ -207,7 +208,7 @@ def __repr__(self):
for case in src:
if not isinstance(case.a, np.ndarray):
continue

a = np.array([case.a, 2*case.a, 3*case.a])
if case.b is None:
b = None
Expand Down Expand Up @@ -1192,5 +1193,90 @@ def test_xerbla_override():
raise SkipTest('Numpy xerbla not linked in.')


class TestMultiDot(object):
def test_basic_function_with_three_arguments(self):
# multi_dot with three arguments uses a fast hand coded algorithm to
# determine the optimal order. Therefore test it separately.
A = np.random.random((6, 2))
B = np.random.random((2, 6))
C = np.random.random((6, 2))

assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C))
assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C)))

def test_basic_function_with_dynamic_programing_optimization(self):
# multi_dot with four or more arguments uses the dynamic programing
# optimization and therefore deserve a separate
A = np.random.random((6, 2))
B = np.random.random((2, 6))
C = np.random.random((6, 2))
D = np.random.random((2, 1))
assert_almost_equal(multi_dot([A, B, C, D]), A.dot(B).dot(C).dot(D))

def test_vector_as_first_argument(self):
# The first argument can be 1-D
A1d = np.random.random(2) # 1-D
B = np.random.random((2, 6))
C = np.random.random((6, 2))
D = np.random.random((2, 2))

# the result should be 1-D
assert_equal(multi_dot([A1d, B, C, D]).shape, (2,))

def test_vector_as_last_argument(self):
# The last argument can be 1-D
A = np.random.random((6, 2))
B = np.random.random((2, 6))
C = np.random.random((6, 2))
D1d = np.random.random(2) # 1-D

# the result should be 1-D
assert_equal(multi_dot([A, B, C, D1d]).shape, (6,))

def test_vector_as_first_and_last_argument(self):
# The first and last arguments can be 1-D
A1d = np.random.random(2) # 1-D
B = np.random.random((2, 6))
C = np.random.random((6, 2))
D1d = np.random.random(2) # 1-D

# the result should be a scalar
assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())

def test_dynamic_programming_logic(self):
# Test for the dynamic programming part
# This test is directly taken from Cormen page 376.
arrays = [np.random.random((30, 35)),
np.random.random((35, 15)),
np.random.random((15, 5)),
np.random.random((5, 10)),
np.random.random((10, 20)),
np.random.random((20, 25))]
m_expected = np.array([[0., 15750., 7875., 9375., 11875., 15125.],
[0., 0., 2625., 4375., 7125., 10500.],
[0., 0., 0., 750., 2500., 5375.],
[0., 0., 0., 0., 1000., 3500.],
[0., 0., 0., 0., 0., 5000.],
[0., 0., 0., 0., 0., 0.]])
s_expected = np.array([[0, 1, 1, 3, 3, 3],
[0, 0, 2, 3, 3, 3],
[0, 0, 0, 3, 3, 3],
[0, 0, 0, 0, 4, 5],
[0, 0, 0, 0, 0, 5],
[0, 0, 0, 0, 0, 0]], dtype=np.int)
s_expected -= 1 # Cormen uses 1-based index, python does not.

s, m = _multi_dot_matrix_chain_order(arrays, return_costs=True)

# Only the upper triangular part (without the diagonal) is interesting.
assert_almost_equal(np.triu(s[:-1, 1:]),
np.triu(s_expected[:-1, 1:]))
assert_almost_equal(np.triu(m), np.triu(m_expected))

def test_too_few_input_arrays(self):
assert_raises(ValueError, multi_dot, [])
assert_raises(ValueError, multi_dot, [np.random.random((3, 3))])


if __name__ == "__main__":
run_module_suite()

0 comments on commit f14d5e1

Please sign in to comment.