Skip to content

Commit

Permalink
More renaming of master to main in JAX internals (jax-ml#4179)
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula authored Aug 30, 2020
1 parent ffbfadd commit 634c625
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 77 deletions.
8 changes: 4 additions & 4 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def copy(self):
return new

class Sublevel(int): pass
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'master_trace'])
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])


class TraceState:
Expand Down Expand Up @@ -1435,7 +1435,7 @@ def pp_kv_pairs(kv_pairs):
@no_type_check
def omnistaging_enabler() -> None:
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
new_master, reset_trace_state, extend_axis_env, axis_frame, \
new_main, reset_trace_state, extend_axis_env, axis_frame, \
new_base_main, eval_context, \
TraceStack, TraceState
del initial_style_staging
Expand Down Expand Up @@ -1573,8 +1573,8 @@ def bind(self, *args, **params):
Primitive.bind = bind

@contextmanager
def extend_axis_env(axis_name, size: int, master_trace: Optional[MasterTrace]):
frame = AxisEnvFrame(axis_name, size, master_trace)
def extend_axis_env(axis_name, size: int, main_trace: Optional[MainTrace]):
frame = AxisEnvFrame(axis_name, size, main_trace)
thread_local_state.trace_state.axis_env.append(frame)
try:
yield
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def callback_fun(fun : lu.WrappedFun, in_vals, callback, strip_calls):
return fun.call_wrapped(*in_vals)

@lu.transformation
def callback_subtrace(master, *in_vals, **params):
trace = CallbackTrace(master, core.cur_sublevel())
def callback_subtrace(main, *in_vals, **params):
trace = CallbackTrace(main, core.cur_sublevel())
in_tracers = [CallbackTracer(trace, val) for val in in_vals]
outs = yield in_tracers, params
out_tracers = map(trace.full_raise, outs)
Expand Down
8 changes: 4 additions & 4 deletions jax/experimental/doubledouble.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def process_call(self, call_primitive, f, tracers, params):


@lu.transformation
def doubling_subtrace(master, heads, tails):
trace = DoublingTrace(master, core.cur_sublevel())
def doubling_subtrace(main, heads, tails):
trace = DoublingTrace(main, core.cur_sublevel())
in_tracers = [DoublingTracer(trace, h, t) if t is not None else h
for h, t in zip(heads, tails)]
ans = yield in_tracers, {}
Expand All @@ -109,8 +109,8 @@ def screen_nones(num_heads, in_tree_def, *heads_and_tails):

@lu.transformation
def doubling_transform(*args):
with core.new_main(DoublingTrace) as master:
trace = DoublingTrace(master, core.cur_sublevel())
with core.new_main(DoublingTrace) as main:
trace = DoublingTrace(main, core.cur_sublevel())
in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
outputs = yield in_tracers, {}
if isinstance(outputs, Sequence):
Expand Down
16 changes: 8 additions & 8 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,16 @@ def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:

def _interpret_fun(fun: lu.WrappedFun,
in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
with core.new_main(TensorFlowTrace) as master:
fun = _interpret_subtrace(fun, master)
with core.new_main(TensorFlowTrace) as main:
fun = _interpret_subtrace(fun, main)
out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
del master
del main
return out_vals


@lu.transformation
def _interpret_subtrace(master: core.MainTrace, *in_vals: TfValOrUnit):
trace = TensorFlowTrace(master, core.cur_sublevel())
def _interpret_subtrace(main: core.MainTrace, *in_vals: TfValOrUnit):
trace = TensorFlowTrace(main, core.cur_sublevel())
in_tracers = tuple(TensorFlowTracer(trace, val) for val in in_vals)
outs = yield in_tracers, {} # type: Sequence[TfValOrUnit]
out_tracers: Iterable[TensorFlowTracer] = map(trace.full_raise, outs) # type: ignore
Expand Down Expand Up @@ -295,7 +295,7 @@ def pure(self, val: TfValOrUnit):
return TensorFlowTracer(self, val)

def lift(self, val: core.Tracer):
"""Lifts a core.Tracer from a lower-level master into the TensorFlowTrace."""
"""Lifts a core.Tracer from a lower-level main into the TensorFlowTrace."""
# TODO(necula): this should never be needed
return TensorFlowTracer(self, val)

Expand Down Expand Up @@ -342,9 +342,9 @@ def post_process_call(self, call_primitive: core.Primitive,
# (out_tracers) include TensorFlowTracer that were not passed through
# its arguments (captured from the environment).
vals = tuple(t.val for t in out_tracers)
master = self.main
main = self.main
def todo(vals: Sequence[TfValOrUnit]):
trace = TensorFlowTrace(master, core.cur_sublevel())
trace = TensorFlowTrace(main, core.cur_sublevel())
return map(functools.partial(TensorFlowTracer, trace), vals)
return vals, todo

Expand Down
8 changes: 4 additions & 4 deletions jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def jet_fun(order, primals, series):
yield out_primals, out_terms

@lu.transformation
def jet_subtrace(master, primals, series):
trace = JetTrace(master, core.cur_sublevel())
def jet_subtrace(main, primals, series):
trace = JetTrace(main, core.cur_sublevel())
in_tracers = map(partial(JetTracer, trace), primals, series)
ans = yield in_tracers, {}
out_tracers = map(trace.full_raise, ans)
Expand Down Expand Up @@ -145,10 +145,10 @@ def post_process_call(self, call_primitive, out_tracers, params):
primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
out, treedef = tree_flatten((primals, series))
del primals, series
master = self.main
main = self.main
def todo(x):
primals, series = tree_unflatten(treedef, x)
trace = JetTrace(master, core.cur_sublevel())
trace = JetTrace(main, core.cur_sublevel())
return map(partial(JetTracer, trace), primals, series)
return out, todo

Expand Down
12 changes: 6 additions & 6 deletions jax/experimental/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,16 @@ def start_subtrace(self):
# TODO: This follows the __enter__ part of core.new_main.
if config.omnistaging_enabled:
level = core.thread_local_state.trace_state.trace_stack.next_level()
master = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(master)
main = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(main)
self._count_subtraces += 1
return pe.JaxprTrace(master, core.cur_sublevel())
return pe.JaxprTrace(main, core.cur_sublevel())
else:
level = core.thread_local_state.trace_state.trace_stack.next_level(False)
master = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(master, False)
main = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(main, False)
self._count_subtraces += 1
return pe.JaxprTrace(master, core.cur_sublevel())
return pe.JaxprTrace(main, core.cur_sublevel())

def end_subtrace(self):
# TODO: This follows the __exit__ part of core.new_main
Expand Down
18 changes: 9 additions & 9 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:

@lu.transformation
def jvpfun(instantiate, primals, tangents):
with core.new_main(JVPTrace) as master:
out_primals, out_tangents = yield (master, primals, tangents), {}
del master
with core.new_main(JVPTrace) as main:
out_primals, out_tangents = yield (main, primals, tangents), {}
del main
if type(instantiate) is bool:
instantiate = [instantiate] * len(out_tangents)
out_tangents = [instantiate_zeros(t) if inst else t for t, inst
in zip(out_tangents, instantiate)]
yield out_primals, out_tangents

@lu.transformation
def jvp_subtrace(master, primals, tangents):
trace = JVPTrace(master, core.cur_sublevel())
def jvp_subtrace(main, primals, tangents):
trace = JVPTrace(main, core.cur_sublevel())
for x in list(primals) + list(tangents):
if isinstance(x, Tracer):
assert x._trace.level < trace.level
Expand All @@ -69,8 +69,8 @@ def jvp_subtrace(master, primals, tangents):
for out_tracer in out_tracers])

@lu.transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
trace = JVPTrace(master, core.cur_sublevel())
def jvp_subtrace_aux(main, primals, tangents):
trace = JVPTrace(main, core.cur_sublevel())
for x in list(primals) + list(tangents):
if isinstance(x, Tracer):
assert x._trace.level < trace.level
Expand Down Expand Up @@ -280,10 +280,10 @@ def post_process_call(self, call_primitive, out_tracers, params):
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
del primals, tangents
master = self.main
main = self.main
def todo(x):
primals, tangents = tree_unflatten(treedef, x)
trace = JVPTrace(master, core.cur_sublevel())
trace = JVPTrace(main, core.cur_sublevel())
return map(partial(JVPTracer, trace), primals, tangents)
return out, todo

Expand Down
36 changes: 18 additions & 18 deletions jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def batch(fun: lu.WrappedFun, in_vals, in_dims, out_dim_dests, axis_name):
return batched_fun.call_wrapped(*in_vals)

@lu.transformation_with_aux
def batch_subtrace(master, in_dims, *in_vals, **params):
trace = BatchTrace(master, core.cur_sublevel())
def batch_subtrace(main, in_dims, *in_vals, **params):
trace = BatchTrace(main, core.cur_sublevel())
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_dims)]
outs = yield in_tracers, params
Expand All @@ -60,10 +60,10 @@ def _batch_fun(axis_name, sum_match, in_dims, out_dims_thunk, out_dim_dests,
canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim
for val, dim in zip(in_vals, in_dims)]
size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
with core.new_main(BatchTrace) as master:
with core.extend_axis_env(axis_name, size, master):
out_vals = yield (master, in_dims,) + in_vals, params
del master
with core.new_main(BatchTrace) as main:
with core.extend_axis_env(axis_name, size, main):
out_vals = yield (main, in_dims,) + in_vals, params
del main
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_dims = out_dims_thunk()
for od, od_dest in zip(out_dims, out_dim_dests):
Expand All @@ -80,9 +80,9 @@ def batch_fun2(fun : lu.WrappedFun, in_dims):

@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
with core.new_main(BatchTrace) as master:
out_vals = yield (master, in_dims,) + in_vals, params
del master
with core.new_main(BatchTrace) as main:
out_vals = yield (main, in_dims,) + in_vals, params
del main
yield out_vals


Expand Down Expand Up @@ -174,9 +174,9 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):

def post_process_call(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.main
main = self.main
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
trace = BatchTrace(main, core.cur_sublevel())
return map(partial(BatchTracer, trace), vals, dims)
return vals, todo

Expand All @@ -198,9 +198,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):

def post_process_map(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.main
main = self.main
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
trace = BatchTrace(main, core.cur_sublevel())
return [BatchTracer(trace, v, d + 1 if d is not not_mapped else d)
for v, d in zip(vals, dims)]
return vals, todo
Expand Down Expand Up @@ -392,12 +392,12 @@ def batch_jaxpr(jaxpr, size, batched, instantiate):
@lu.transformation_with_aux
def batched_traceable(size, batched, instantiate, *vals):
in_dims = [0 if b else None for b in batched]
with core.new_main(BatchTrace) as master:
trace = BatchTrace(master, core.cur_sublevel())
with core.new_main(BatchTrace) as main:
trace = BatchTrace(main, core.cur_sublevel())
ans = yield map(partial(BatchTracer, trace), vals, in_dims), {}
out_tracers = map(trace.full_raise, ans)
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
del master, out_tracers
del main, out_tracers
if type(instantiate) is bool:
instantiate = [instantiate] * len(out_vals)
out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
Expand All @@ -409,9 +409,9 @@ def batched_traceable(size, batched, instantiate, *vals):


@lu.transformation_with_aux
def batch_custom_jvp_subtrace(master, in_dims, *in_vals):
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
trace = BatchTrace(master, core.cur_sublevel())
trace = BatchTrace(main, core.cur_sublevel())
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_dims * 2)]
outs = yield in_tracers, {}
Expand Down
14 changes: 7 additions & 7 deletions jax/interpreters/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,19 @@ def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
logical_env_vals = [logical_env[k] for k in env_keys]
# Make padded_env hashable
padded_env = (env_keys, padded_env_vals)
with core.new_main(MaskTrace) as master:
fun, out_shapes = mask_subtrace(fun, master, polymorphic_shapes, padded_env)
with core.new_main(MaskTrace) as main:
fun, out_shapes = mask_subtrace(fun, main, polymorphic_shapes, padded_env)
out_vals = fun.call_wrapped(*(logical_env_vals + in_vals))
del master
del main
return out_vals, out_shapes()

@lu.transformation_with_aux
def mask_subtrace(master, shapes, padded_env, *in_vals):
def mask_subtrace(main, shapes, padded_env, *in_vals):
env_keys, _ = padded_env
logical_env_vals, in_vals = in_vals[:len(env_keys)], in_vals[len(env_keys):]
logical_env = dict(zip(env_keys, logical_env_vals))
padded_env = dict(zip(*padded_env))
trace = MaskTrace(master, core.cur_sublevel())
trace = MaskTrace(main, core.cur_sublevel())
in_tracers = [MaskTracer(trace, x, s).full_lower()
for x, s in zip(in_vals, shapes)]
with extend_shape_envs(logical_env, padded_env):
Expand Down Expand Up @@ -430,9 +430,9 @@ def process_call(self, call_primitive, f, tracers, params):

def post_process_call(self, call_primitive, out_tracers, params):
vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
master = self.main
main = self.main
def todo(vals):
trace = MaskTrace(master, core.cur_sublevel())
trace = MaskTrace(main, core.cur_sublevel())
return map(partial(MaskTracer, trace), vals, shapes)
return vals, todo

Expand Down
24 changes: 12 additions & 12 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def post_process_call(self, primitive, out_tracers, params):
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.main
main = self.main

if primitive.map_primitive:
sz = params['axis_size']
Expand All @@ -249,7 +249,7 @@ def post_process_call(self, primitive, out_tracers, params):
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
trace = JaxprTrace(main, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
Expand Down Expand Up @@ -417,20 +417,20 @@ def fun(ki, ui): # ki will be a known input in this example
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
with core.new_main(trace_type, bottom=bottom) as master:
fun = trace_to_subjaxpr(fun, master, instantiate)
with core.new_main(trace_type, bottom=bottom) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
del master
del main

return jaxpr, out_pvals, consts


@lu.transformation
def trace_to_subjaxpr(master: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
pvals: Sequence[PartialVal]):
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
trace = JaxprTrace(master, core.cur_sublevel())
trace = JaxprTrace(main, core.cur_sublevel())
in_tracers = map(trace.new_arg, pvals)
ans = yield in_tracers, {}
instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
Expand Down Expand Up @@ -869,8 +869,8 @@ def __init__(self):
self.tracer_to_var = {}
self.constid_to_var = {}
self.constvar_to_val = {}
self.tracers = [] # circ refs, frame->tracer->trace->master->frame,
self.eqns = [] # cleared when we pop frame from master
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
self.eqns = [] # cleared when we pop frame from main

def to_jaxpr(self, in_tracers, out_tracers):
invars = [self.tracer_to_var[id(t)] for t in in_tracers]
Expand Down Expand Up @@ -1090,11 +1090,11 @@ def omnistaging_enabler() -> None:
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
with core.new_main(JaxprTrace) as master:
fun = trace_to_subjaxpr(fun, master, instantiate)
with core.new_main(JaxprTrace) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
del master
del main

return jaxpr, out_pvals, consts

Expand Down
Loading

0 comments on commit 634c625

Please sign in to comment.