Skip to content

Commit

Permalink
Remove redundant name-stack setting in DynamicJaxprTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Mar 30, 2022
1 parent 9062959 commit c233a97
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,13 +1499,11 @@ def process_call(self, call_primitive, f, tracers, params):
dim_tracers = _get_tracers_only_in_shapes(tracers)
in_avals = _tracers_to_avals(dim_tracers + tracers)
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
name_stack = source_info_util.current_name_stack()
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
if params.get('inline', False):
with source_info_util.set_name_stack(name_stack):
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
source_info = source_info_util.current()
env = {v: t for v, t in zip((*jaxpr.constvars, *jaxpr.invars),
(*consts, *dim_tracers, *tracers))
Expand Down

0 comments on commit c233a97

Please sign in to comment.