Skip to content

Commit

Permalink
Merge pull request jax-ml#15154 from mattjj:pjit-typecheck
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 518717095
  • Loading branch information
jax authors committed Mar 23, 2023
2 parents 484eb26 + 268456e commit e39578c
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 18 deletions.
5 changes: 3 additions & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,7 +2627,7 @@ class JaxprTypeError(TypeError): pass

custom_typechecks: Dict[Primitive, Callable] = {}

def _check_closed_call(*in_atoms, call_jaxpr):
def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
if list(in_avals) != list(call_jaxpr.in_avals):
raise JaxprTypeError("Closed call in_avals mismatch")
Expand Down Expand Up @@ -2726,7 +2726,8 @@ def write(v: Var, a: AbstractValue) -> None:

# Compute the type of the primitive application.
if prim in custom_typechecks:
out_type, eqn_effects = custom_typechecks[prim](*in_atoms, **eqn.params)
out_type, eqn_effects = custom_typechecks[prim](
ctx_factory, *in_atoms, **eqn.params)
elif prim.call_primitive:
out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
eqn.params)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,8 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts,
symbolic_zeros):
def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
num_consts, symbolic_zeros):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/custom_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_bind_params(self, params):


# TODO(frostig,mattjj): reinstate checks
def custom_transpose_typecheck(*in_atoms, out_types, **params):
def custom_transpose_typecheck(_, *in_atoms, out_types, **params):
del in_atoms, params
return out_types, core.no_effects

Expand Down
8 changes: 5 additions & 3 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,9 @@ def _cond_axis_substitution(params, subst, traverse):
branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches'])
return dict(params, branches=branches)

def _cond_typecheck(*in_atoms, branches, linear):
def _cond_typecheck(bind_time, *in_atoms, branches, linear):
if not bind_time:
_, *in_atoms = in_atoms
avals = [x.aval for x in in_atoms]
tc = partial(_typecheck_param, 'cond')
tc(branches, 'branches', 'tuple of ClosedJaxpr',
Expand Down Expand Up @@ -794,7 +796,7 @@ def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
avals = map(core.get_aval, args)
in_atoms = [core.Var(0, '', a) for a in avals] # dummies
_cond_typecheck(*in_atoms, branches=branches, linear=linear)
_cond_typecheck(True, *in_atoms, branches=branches, linear=linear)
for jaxpr in branches:
core.check_jaxpr(jaxpr.jaxpr)
return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)
Expand All @@ -810,7 +812,7 @@ def cond_bind(*args, branches, linear):
batching.spmd_axis_primitive_batchers[cond_p] = _cond_batching_rule
batching.axis_primitive_batchers[cond_p] = partial(_cond_batching_rule, None)
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
core.custom_typechecks[cond_p] = partial(_cond_typecheck, False)
core.axis_substitution_rules[cond_p] = _cond_axis_substitution
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,10 @@ def known(*ins_known):
new_vars = [*new_inst, *intensive_res, *extensive_res]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars

def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
num_carry, jaxpr, linear, unroll):
if not bind_time:
_, *in_atoms = in_atoms
avals = [x.aval for x in in_atoms]
tc = partial(_typecheck_param, 'scan')
tc(reverse, 'reverse', 'bool', type(reverse) is bool)
Expand Down Expand Up @@ -1546,7 +1548,7 @@ def fun(*args):
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
return z

def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts,
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,7 +2812,7 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
return shape

def _broadcast_in_dim_typecheck_rule(
operand, *dyn_shape, shape, broadcast_dimensions):
_, operand, *dyn_shape, shape, broadcast_dimensions):
if not dyn_shape:
out_aval, effects = broadcast_in_dim_p.abstract_eval(
operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions)
Expand Down Expand Up @@ -3271,7 +3271,7 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
raise TypeError(msg.format(dimensions, np.shape(operand)))
return tuple(new_sizes)

def _reshape_typecheck_rule(operand, *dyn_shape, new_sizes, dimensions):
def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
if not dyn_shape:
out_aval, effects = reshape_p.abstract_eval(
operand.aval, new_sizes=new_sizes, dimensions=dimensions)
Expand Down Expand Up @@ -4506,7 +4506,7 @@ def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension):
return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params)
pe.custom_staging_rules[iota_p] = _iota_staging_rule

def _iota_typecheck_rule(*dyn_shape, dtype, shape, dimension):
def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension):
if not dyn_shape:
out_aval, effects = iota_p.abstract_eval(
dtype=dtype, shape=shape, dimension=dimension)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes):
*starts_and_dyn_sizes,
slice_sizes=slice_sizes)

def _dynamic_slice_typecheck_rule(x, *starts_and_dyn_sizes, slice_sizes):
def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim])
if not dyn:
out_aval, effects = dynamic_slice_p.abstract_eval(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ def unmap_zero(zero, axes):


def _typecheck_xmap(
*in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars,
_, *in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes):
in_avals = [x.aval for x in in_atoms]
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,9 +1343,13 @@ def pjit_staging_rule(trace, *args, **params):
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
else:
return trace.default_process_primitive(pjit_p, args, params)

pe.custom_staging_rules[pjit_p] = pjit_staging_rule

def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
return core._check_call(ctx_factory, pjit_p, in_atoms,
dict(params, call_jaxpr=jaxpr.jaxpr))
core.custom_typechecks[pjit_p] = _pjit_typecheck


def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_):
return jaxpr.out_avals, jaxpr.effects
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue

# Type-checking

def _shard_map_typecheck(*in_atoms, jaxpr, mesh, in_names, out_names,
def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
check_rep):
for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
Expand Down
9 changes: 9 additions & 0 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,15 @@ def test_check_jaxpr_cond_correct(self):
jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
core.check_jaxpr(jaxpr)

def test_check_jaxpr_jit_invalid(self):
jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr
pjit_eqn, = jaxpr.eqns
jaxpr._eqns[0] = pjit_eqn._replace(invars=())
self.assertRaisesRegex(
core.JaxprTypeError,
'0 operands cannot call jaxpr with 2 inputs',
lambda: core.check_jaxpr(jaxpr))

def test_check_jaxpr_cond_invalid(self):
jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
Expand Down

0 comments on commit e39578c

Please sign in to comment.