Add precision to jax.numpy functions that use lax.dot_general (jax-ml#1728)

* Add precision to jax.numpy functions that use lax.dot_general

* Test precision argument

* check default precision

* test with jaxprs

* Document precision
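
A minimal usage sketch of the new argument (illustrative only, not part of the diff below; it assumes the public jax.numpy and jax.lax imports):

import numpy as np
import jax.numpy as jnp
from jax import lax

a = np.ones((128, 128), np.float32)
b = np.ones((128, 128), np.float32)

# precision=None (the default) keeps the backend's default behaviour,
# e.g. lower-precision matmuls on TPU.
c = jnp.dot(a, b)

# Request full precision for the underlying lax.dot_general.
c_hi = jnp.dot(a, b, precision=lax.Precision.HIGHEST)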
shoyer authored Nov 21, 2019
1 parent eff7b45 commit 27aa76e
Showing 2 changed files with 98 additions and 27 deletions.
63 changes: 36 additions & 27 deletions jax/numpy/lax_numpy.py
@@ -2096,26 +2096,33 @@ def append(arr, values, axis=None):
### Tensor contraction operations


@_wraps(onp.dot)
def dot(a, b): # pylint: disable=missing-docstring
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. See :py:func:`jax.lax.dot` for details.
"""


@_wraps(onp.dot, lax_description=_PRECISION_DOC)
def dot(a, b, precision=None): # pylint: disable=missing-docstring
_check_arraylike("dot", a, b)
a, b = _promote_dtypes(a, b)
a_ndim, b_ndim = ndim(a), ndim(b)
if a_ndim == 0 or b_ndim == 0:
return lax.mul(a, b)
if _max(a_ndim, b_ndim) <= 2:
return lax.dot(a, b)
return lax.dot(a, b, precision=precision)

if b_ndim == 1:
contract_dims = ((a_ndim - 1,), (0,))
else:
contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
batch_dims = ((), ())
return lax.dot_general(a, b, (contract_dims, batch_dims))
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)


@_wraps(onp.matmul)
def matmul(a, b): # pylint: disable=missing-docstring
@_wraps(onp.matmul, lax_description=_PRECISION_DOC)
def matmul(a, b, precision=None): # pylint: disable=missing-docstring
_check_arraylike("matmul", a, b)
a_is_vec, b_is_vec = (ndim(a) == 1), (ndim(b) == 1)
a = lax.reshape(a, (1,) + shape(a)) if a_is_vec else a
@@ -2126,8 +2133,8 @@ def matmul(a, b): # pylint: disable=missing-docstring
a = broadcast_to(a, batch_shape + shape(a)[-2:])
b = broadcast_to(b, batch_shape + shape(b)[-2:])
batch_dims = tuple(range(len(batch_shape)))
result = lax.dot_general(a, b, (((ndim(a) - 1,), (ndim(b) - 2,)),
(batch_dims, batch_dims)))
dim_numbers = (((ndim(a) - 1,), (ndim(b) - 2,)), (batch_dims, batch_dims))
result = lax.dot_general(a, b, dim_numbers, precision)

if a_is_vec or b_is_vec:
m, n = shape(result)[-2:]
@@ -2138,15 +2145,15 @@ def matmul(a, b): # pylint: disable=missing-docstring
return result


@_wraps(onp.vdot)
def vdot(a, b):
@_wraps(onp.vdot, lax_description=_PRECISION_DOC)
def vdot(a, b, precision=None):
if issubdtype(_dtype(a), onp.complexfloating):
a = conj(a)
return dot(a.ravel(), b.ravel())
return dot(a.ravel(), b.ravel(), precision=precision)


@_wraps(onp.tensordot)
def tensordot(a, b, axes=2):
@_wraps(onp.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, precision=None):
_check_arraylike("tensordot", a, b)
if not (ndim(a) >= 1 and ndim(b) >= 1):
msg = "tensordot requires a.ndim and b.dim to be at least 1, got {} and {}."
@@ -2161,48 +2168,49 @@ def tensordot(a, b, axes=2):
a, b = _promote_dtypes(a, b)
a_reshape = lax.reshape(a, (_prod(a.shape[:-axes]), _prod(a.shape[-axes:])))
b_reshape = lax.reshape(b, (_prod(b.shape[:axes]), _prod(b.shape[axes:])))
out_reshape = lax.dot(a_reshape, b_reshape)
out_reshape = lax.dot(a_reshape, b_reshape, precision=precision)
return lax.reshape(out_reshape, a.shape[:-axes] + b.shape[axes:])
elif type(axes) in (list, tuple) and len(axes) == 2:
ax1, ax2 = axes
if type(ax1) == type(ax2) == int:
a_transposed = moveaxis(a, ax1, -1) if ax1 != a.ndim - 1 else a
b_transposed = moveaxis(b, ax2, 0) if ax2 != 0 else b
return tensordot(a_transposed, b_transposed, 1)
return tensordot(a_transposed, b_transposed, 1, precision)
elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
if len(ax1) != len(ax2):
msg = "tensordot requires axes lists to have equal length, got {} and {}."
raise TypeError(msg.format(ax1, ax2))
num_axes = len(ax1)
a_transposed = moveaxis(a, ax1, tuple(range(a.ndim - num_axes, a.ndim)))
b_transposed = moveaxis(b, ax2, tuple(range(num_axes)))
return tensordot(a_transposed, b_transposed, num_axes)
return tensordot(a_transposed, b_transposed, num_axes, precision)
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair of "
"lists/tuples of ints.")
raise TypeError(msg)


@_wraps(onp.einsum)
@_wraps(onp.einsum, lax_description=_PRECISION_DOC)
def einsum(*operands, **kwargs):
optimize = kwargs.pop('optimize', 'auto')
optimize = 'greedy' if optimize is True else optimize
precision = kwargs.pop('precision', None)
if kwargs:
msg = 'invalid keyword arguments for einsum: {}'
raise TypeError(msg.format(', '.join(kwargs)))
# using einsum_call=True here is an internal api for opt_einsum
operands, contractions = opt_einsum.contract_path(
*operands, einsum_call=True, use_blas=True, optimize=optimize)
contractions = tuple(data[:3] for data in contractions)
return _einsum(operands, contractions)
return _einsum(operands, contractions, precision)

@_wraps(onp.einsum_path)
def einsum_path(subscripts, *operands, **kwargs):
optimize = kwargs.pop('optimize', 'greedy')
# using einsum_call=True here is an internal api for opt_einsum
return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)

@partial(jit, static_argnums=(1,))
def _einsum(operands, contractions):
@partial(jit, static_argnums=(1, 2))
def _einsum(operands, contractions, precision):
operands = list(_promote_dtypes(*operands))
sum = lambda x, axes: lax.reduce(x, onp.array(0, x.dtype), lax.add, axes)

@@ -2292,7 +2300,8 @@ def sum_repeats(operand, names, counts, keep_names):
# contract using lax.dot_general
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
operand = _dot_general(lhs, rhs, lhs_cont, rhs_cont, len(batch_dims))
operand = _dot_general(lhs, rhs, lhs_cont, rhs_cont, len(batch_dims),
precision)
deleted_names = batch_names + ''.join(contracted_names)
names = (batch_names + removechars(lhs_names, deleted_names)
+ removechars(rhs_names, deleted_names))
@@ -2320,7 +2329,7 @@ def sum_repeats(operand, names, counts, keep_names):
return operands[0]


def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch):
def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch, precision):
"""Helper for einsum contractions."""
# lax.dot_general has some tight constraints on dimension_numbers that this
# wrapper loosens via transposes and reshapes
@@ -2332,7 +2341,7 @@ def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch):

if ncont == 1 and 0 <= lhs_ntensor <= 1 and 0 <= rhs_ntensor <= 1:
dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
return lax.dot_general(lhs, rhs, dimension_numbers)
return lax.dot_general(lhs, rhs, dimension_numbers, precision)
else:
# move contracting dimensions to the end. lax.dot_general only allows one
# contracting dimension, so if there's more than one we collapse them.
@@ -2360,7 +2369,7 @@ def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch):

lhs_cont, rhs_cont = [lhs.ndim - 1], [rhs.ndim - 1]
dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
result = lax.dot_general(lhs, rhs, dimension_numbers)
result = lax.dot_general(lhs, rhs, dimension_numbers, precision)
return lax.reshape(result, result_shape)


@@ -2372,11 +2381,11 @@ def _movechars(s, src, dst):
return ''.join(chars)


@_wraps(onp.inner)
def inner(a, b):
@_wraps(onp.inner, lax_description=_PRECISION_DOC)
def inner(a, b, precision=None):
if ndim(a) == 0 or ndim(b) == 0:
return a * b
return tensordot(a, b, (-1, -1))
return tensordot(a, b, (-1, -1), precision=precision)


@_wraps(onp.outer)
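
The same keyword is threaded through every function touched above; as an illustrative sketch (not part of the diff), einsum accepts it as a keyword argument popped from **kwargs, and tensordot forwards it through its recursive calls:

import numpy as np
import jax.numpy as jnp
from jax import lax

x = np.ones((4, 8), np.float32)
y = np.ones((8, 3), np.float32)

# einsum: precision is popped from **kwargs before the contraction path
# is computed, then passed to the jitted _einsum helper.
z1 = jnp.einsum('ij,jk->ik', x, y, precision=lax.Precision.HIGHEST)

# tensordot: precision is forwarded down to lax.dot / lax.dot_general.
z2 = jnp.tensordot(x, y, axes=1, precision=lax.Precision.HIGHEST)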
62 changes: 62 additions & 0 deletions tests/lax_numpy_test.py
@@ -2390,6 +2390,68 @@ def testBroadcastToIntIssue1548(self):
self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
check_dtypes=False)

def testPrecision(self):

def iter_eqns(jaxpr):
for eqn in jaxpr.eqns:
yield eqn
for subjaxpr, _, _ in eqn.bound_subjaxprs:
for sub_eqn in iter_eqns(subjaxpr):
yield sub_eqn

def assert_precision(expected, fun, *args):
jaxpr = jax.make_jaxpr(fun)(*args)
precision, = [eqn.params['precision'] for eqn in iter_eqns(jaxpr)
if eqn.primitive == lax.dot_general_p]
self.assertEqual(precision, expected)

ones_1d = onp.ones((2,))
ones_2d = onp.ones((2, 2))
ones_3d = onp.ones((2, 2, 2))
HIGHEST = lax.Precision.HIGHEST

assert_precision(None, lnp.dot, ones_1d, ones_1d)
assert_precision(
HIGHEST,
partial(lnp.dot, precision=HIGHEST),
ones_1d, ones_1d)
assert_precision(
HIGHEST,
partial(lnp.dot, precision=HIGHEST),
ones_3d, ones_3d)
assert_precision(
HIGHEST,
partial(lnp.matmul, precision=HIGHEST),
ones_2d, ones_2d)
assert_precision(
HIGHEST,
partial(lnp.vdot, precision=HIGHEST),
ones_1d, ones_1d)
assert_precision(
HIGHEST,
partial(lnp.tensordot, axes=2, precision=HIGHEST),
ones_2d, ones_2d)
assert_precision(
HIGHEST,
partial(lnp.tensordot, axes=(0, 0), precision=HIGHEST),
ones_1d, ones_1d)
assert_precision(
HIGHEST,
partial(lnp.tensordot, axes=((0,), (0,)), precision=HIGHEST),
ones_1d, ones_1d)
assert_precision(
HIGHEST,
partial(lnp.einsum, 'i,i', precision=HIGHEST),
ones_1d, ones_1d)
assert_precision(
HIGHEST,
partial(lnp.einsum, 'ij,ij', precision=HIGHEST),
ones_2d, ones_2d)
assert_precision(
HIGHEST,
partial(lnp.inner, precision=HIGHEST),
ones_1d, ones_1d)

# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.

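
The jaxpr-based check in testPrecision can also be reproduced outside the test harness; a minimal sketch, assuming only the public jax.make_jaxpr API:

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

ones_2d = np.ones((2, 2))
fun = partial(jnp.matmul, precision=lax.Precision.HIGHEST)

# The dot_general equation in the resulting jaxpr carries the precision
# parameter, which is what assert_precision inspects above.
jaxpr = jax.make_jaxpr(fun)(ones_2d, ones_2d)
print(jaxpr)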
