Skip to content

Commit

Permalink
Merge pull request jax-ml#2199 from sharadmv/patch-1
Browse files Browse the repository at this point in the history
Fix inconsistent indentation in `JaxprTrace.default_process_primitive`.
  • Loading branch information
gnecula authored Feb 11, 2020
2 parents 28e802c + 76d77bf commit f4b946e
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,22 @@ def process_primitive(self, primitive, tracers, params):
return self.default_process_primitive(primitive, tracers, params)

def default_process_primitive(self, primitive, tracers, params):
pvs, consts = unzip2(t.pval for t in tracers)
if all(pv is None for pv in pvs):
return primitive.bind(*consts, **params)
tracers = map(self.instantiate_const, tracers)
avals = [t.aval for t in tracers]
out_aval = primitive.abstract_eval(*avals, **params)
if primitive.multiple_results:
out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)
for aval in out_aval]
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params)
for t in out_tracers: t.recipe = eqn
return out_tracers
else:
out_tracer = JaxprTracer(self, PartialVal((out_aval, unit)), None)
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, params)
return out_tracer
pvs, consts = unzip2(t.pval for t in tracers)
if all(pv is None for pv in pvs):
return primitive.bind(*consts, **params)
tracers = map(self.instantiate_const, tracers)
avals = [t.aval for t in tracers]
out_aval = primitive.abstract_eval(*avals, **params)
if primitive.multiple_results:
out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)
for aval in out_aval]
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params)
for t in out_tracers: t.recipe = eqn
return out_tracers
else:
out_tracer = JaxprTracer(self, PartialVal((out_aval, unit)), None)
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, params)
return out_tracer

def process_call(self, call_primitive, f, tracers, params):
name = params.get('name', f.__name__)
Expand Down

0 comments on commit f4b946e

Please sign in to comment.