Skip to content

Commit

Permalink
Fix incorrect type for eigenvectors in abstract evaluation rule for e…
Browse files Browse the repository at this point in the history
…igh.
  • Loading branch information
hawkinsp committed May 4, 2019
1 parent 9198656 commit 2d6fcf3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jax/lax_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def eigh_abstract_eval(operand, lower):
batch_dims = operand.shape[:-2]
n = operand.shape[-1]
v = ShapedArray(batch_dims + (n, n), operand.dtype)
w = ShapedArray(batch_dims + (n,), operand.dtype)
w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
else:
v, w = operand, operand
return core.AbstractTuple((v, w))
Expand Down
12 changes: 11 additions & 1 deletion tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from absl.testing import absltest
from absl.testing import parameterized

from jax import jvp, vmap
from jax import jit, grad, jvp, vmap
from jax import numpy as np
from jax import scipy as jsp
from jax import test_util as jtu
Expand Down Expand Up @@ -395,6 +395,16 @@ def args_maker():
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)

# Regression test for incorrect type for eigenvalues of a complex matrix.
def testIssue669(self):
def test(x):
val, vec = np.linalg.eigh(x)
return np.real(np.sum(val))

grad_test_jc = jit(grad(jit(test)))
xc = onp.eye(3, dtype=onp.complex)
self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)


class ScipyLinalgTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 2d6fcf3

Please sign in to comment.