diff --git a/jax/lax_linalg.py b/jax/lax_linalg.py index b7328825005e..853a5a04c5dd 100644 --- a/jax/lax_linalg.py +++ b/jax/lax_linalg.py @@ -212,7 +212,7 @@ def eigh_impl(operand, lower): def eigh_translation_rule(c, operand, lower): raise NotImplementedError( - "Symmetric eigendecomposition is only implemented on the CPU backend") + "Symmetric eigendecomposition is only implemented on the CPU and GPU backends") def eigh_abstract_eval(operand, lower): if isinstance(operand, ShapedArray):