Skip to content

Commit

Permalink
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (jax-ml#2901)
Browse files Browse the repository at this point in the history

At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
  • Loading branch information
tomhennigan authored May 1, 2020
1 parent 279a077 commit 0736679
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
4 changes: 4 additions & 0 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,10 @@ def broadcast_in_dim(operand: Array, shape: Shape,
operand, shape=tuple(shape),
broadcast_dimensions=tuple(broadcast_dimensions))

def broadcast_to_rank(x: Array, rank: int) -> Array:
  """Pads ``x`` with leading size-1 dimensions up to ``rank`` dimensions.

  Args:
    x: array to pad; only its ``ndim`` attribute is read here.
    rank: desired number of dimensions of the result.

  Returns:
    The result of ``broadcast(x, sizes)``, where ``sizes`` holds one ``1`` for
    each of the ``rank - x.ndim`` missing dimensions.
  """
  # Tuple repetition with a count <= 0 yields ``()``, so no extra sizes are
  # requested when ``x`` already has ``rank`` or more dimensions.
  leading_ones = (1,) * (rank - x.ndim)
  return broadcast(x, leading_ones)

def reshape(operand: Array, new_sizes: Shape,
dimensions: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `Reshape
Expand Down
5 changes: 3 additions & 2 deletions jax/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,9 @@ def one_hot(x, num_classes, *, dtype=np.float64):
"""
dtype = dtypes.canonicalize_dtype(dtype)
x = np.asarray(x)
return np.array(x[..., np.newaxis] == np.arange(num_classes, dtype=x.dtype),
dtype=dtype)
lhs = x[..., np.newaxis]
rhs = lax.broadcast_to_rank(np.arange(num_classes, dtype=x.dtype), lhs.ndim)
return np.array(lhs == rhs, dtype=dtype)

def relu6(x):
r"""Rectified Linear Unit 6 activation function.
Expand Down
3 changes: 2 additions & 1 deletion jax/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def init(key, shape, dtype=dtype):
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
A = random.normal(key, matrix_shape, dtype)
Q, R = np.linalg.qr(A)
Q *= np.sign(np.diag(R)) # needed for a uniform distribution
diag_sign = lax.broadcast_to_rank(np.sign(np.diag(R)), rank=Q.ndim)
Q *= diag_sign # needed for a uniform distribution
if n_rows < n_cols: Q = Q.T
Q = np.reshape(Q, tuple(onp.delete(shape, column_axis)) + (shape[column_axis],))
Q = np.moveaxis(Q, -1, column_axis)
Expand Down
17 changes: 17 additions & 0 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,17 @@
from jax.config import config
config.parse_flags_with_absl()


class NNFunctionsTest(jtu.JaxTestCase):

# Test fixture hook (presumably run by the test framework before each test):
# after the base-class setup, switch `jax_numpy_rank_promotion` to "raise" so
# that implicit rank promotion raises an error instead of passing silently.
def setUp(self):
super().setUp()
config.update("jax_numpy_rank_promotion", "raise")

# Counterpart to setUp: after the base-class teardown, set
# `jax_numpy_rank_promotion` to "warn".
# NOTE(review): this writes "warn" unconditionally rather than restoring the
# value that was active before setUp — confirm "warn" is the intended baseline.
def tearDown(self):
super().tearDown()
config.update("jax_numpy_rank_promotion", "warn")

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
check_grads(nn.softplus, (1e-8,), order=4,
Expand Down Expand Up @@ -161,6 +170,14 @@ def initializer_record(name, initializer, min_dims=2, max_dims=4):

class NNInitializersTest(jtu.JaxTestCase):

# Test fixture hook (presumably run by the test framework before each test):
# after the base-class setup, switch `jax_numpy_rank_promotion` to "raise" so
# that the initializer tests fail on any implicit rank promotion.
def setUp(self):
super().setUp()
config.update("jax_numpy_rank_promotion", "raise")

# Counterpart to setUp: after the base-class teardown, set
# `jax_numpy_rank_promotion` to "warn".
# NOTE(review): this writes "warn" unconditionally rather than restoring the
# value that was active before setUp — confirm "warn" is the intended baseline.
def tearDown(self):
super().tearDown()
config.update("jax_numpy_rank_promotion", "warn")

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_{}_{}".format(
Expand Down

0 comments on commit 0736679

Please sign in to comment.