Skip to content

Commit

Permalink
Ensure LU decomposition cache hits in op-by-op
Browse files Browse the repository at this point in the history
  • Loading branch information
j-towns committed Sep 26, 2019
1 parent d2d0576 commit b24d6ca
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
25 changes: 16 additions & 9 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ class FixedPointError(Exception): pass

### fori_loop and while_loop

@cache()
def _make_fori_cond(upper):
def while_cond_fun(loop_carry):
i, _ = loop_carry
return lax.lt(i, upper)
return while_cond_fun

@cache()
def _make_fori_body(body_fun):
def while_body_fun(loop_carry):
i, x = loop_carry
return lax.add(i, lax._const(i, 1)), body_fun(i, x)
return while_body_fun

def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.
Expand Down Expand Up @@ -108,15 +122,8 @@ def fori_loop(lower, upper, body_fun, init_val):
Returns:
Loop value from the final iteration, of type ``a``.
"""
def while_cond_fun(loop_carry):
i, _ = loop_carry
return lax.lt(i, upper)

def while_body_fun(loop_carry):
i, x = loop_carry
return lax.add(i, lax._const(i, 1)), body_fun(i, x)

_, result = while_loop(while_cond_fun, while_body_fun, (lower, init_val))
_, result = while_loop(_make_fori_cond(int(upper)),
_make_fori_body(body_fun), (lower, init_val))
return result


Expand Down
25 changes: 15 additions & 10 deletions jax/lax_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,17 @@ def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
_lu_cpu_gpu_translation_rule, cusolver.getrf)


# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
permutation, swaps = permutation_and_swaps
batch_dims = swaps.shape[:-1]
j = swaps[..., i]
iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
x = permutation[..., i]
y = permutation[iotas + (j,)]
permutation = ops.index_update(permutation, ops.index[..., i], y)
return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps

def lu_pivots_to_permutation(swaps, m):
"""Converts the pivots (row swaps) returned by LU to a permutation.
Expand All @@ -609,18 +620,12 @@ def lu_pivots_to_permutation(swaps, m):
batch_dims = swaps.shape[:-1]
k = swaps.shape[-1]

def body_fn(i, permutation):
j = swaps[..., i]
iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
x = permutation[..., i]
y = permutation[iotas + (j,)]
permutation = ops.index_update(permutation, ops.index[..., i], y)
return ops.index_update(permutation, ops.index[iotas + (j,)], x)

permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,),
len(batch_dims))
return lax.fori_loop(
onp.array(0, onp.int32), onp.array(k, onp.int32), body_fn, permutation)
result, _ = lax.fori_loop(
onp.array(0, onp.int32), onp.array(k, onp.int32), _lu_pivots_body_fn,
(permutation, swaps))
return result


# QR decomposition
Expand Down

0 comments on commit b24d6ca

Please sign in to comment.