Skip to content

Commit

Permalink
lax.mul: accept boolean inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 14, 2024
1 parent 895b490 commit 4f7cd03
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2264,6 +2264,8 @@ def _add_inverse(r, x, y):
yr = r - x
return xr, yr

# Note: although XLA allows add(bool, bool) -> bool, we prohibit it in lax.add
# because it has ambiguous semantics (e.g. XLA uses XOR, numpy uses OR).
# TODO(slebedev): Why does mypy fail to infer the type here?
add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
Expand Down Expand Up @@ -2318,7 +2320,7 @@ def _mul_inverse(r, x, y):
yr = r / x
return xr, yr

mul_p = standard_naryop([_num, _num], 'mul')
mul_p = standard_naryop([_any, _any], 'mul')
ad.defjvp(mul_p,
lambda xdot, x, y: mul(xdot, y),
lambda ydot, x, y: mul(x, ydot))
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _arccosh(x: ArrayLike, /) -> Array:
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
multiply = _one_to_one_binop(np.multiply, lax.mul)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True)
Expand Down

0 comments on commit 4f7cd03

Please sign in to comment.