Skip to content

Commit

Permalink
fixes tests for complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed May 4, 2019
1 parent e742a26 commit f259e91
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/lax_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

k = s.shape[-1]
Ut, V = U.T, Vt.T
Ut, V = np.conj(U).T, np.conj(Vt).T
s_dim = s[..., None, :]
dS = Ut.dot(dA).dot(V)
ds = np.diag(dS)
Expand All @@ -490,7 +490,7 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
if m > n:
dU = dU + (np.eye(m) - U.dot(Ut)).dot(dA).dot(V) / s_dim
if n > m:
dV = dV + (np.eye(n) - V.dot(Vt)).dot(dA.T).dot(U) / s_dim
dV = dV + (np.eye(n) - V.dot(Vt)).dot(np.conj(dA).T).dot(U) / s_dim
return core.pack((s, U, Vt)), core.pack((ds, dU, dV.T))

def svd_cpu_translation_rule(c, operand, full_matrices, compute_uv):
Expand Down

0 comments on commit f259e91

Please sign in to comment.