Skip to content

Commit

Permalink
protocols for the lowering and executable underlying Lowered and `C…
Browse files Browse the repository at this point in the history
…ompiled`
  • Loading branch information
froystig committed Mar 23, 2022
1 parent 622107c commit e317863
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 38 deletions.
22 changes: 16 additions & 6 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,24 @@

from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.masking as masking
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax.errors import UnexpectedTracerError
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src import device_array
from jax._src import dtypes
from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src import traceback_util

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -482,7 +483,7 @@ def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
else h(*device_put(x, device)) for h, x in zip(handlers, outs)]


class XlaComputation:
class XlaComputation(stages.Computation):
name: str
_is_trivial: bool
_executable: Optional['XlaCompiledComputation']
Expand Down Expand Up @@ -583,7 +584,7 @@ def compile_or_get_cached(backend, computation, compile_options):
return backend_compile(backend, computation, compile_options)


class XlaCompiledComputation:
class XlaCompiledComputation(stages.Executable):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
self._xla_executable = xla_executable
self.in_avals = in_avals
Expand Down Expand Up @@ -624,6 +625,7 @@ def is_trivial(self):

@property
def xla_executable(self):
# TODO(frostig): remove in favor of runtime_executable?
if self.is_trivial():
raise ValueError("A trivial compiled computation has no XLA executable")
return self._xla_executable
Expand All @@ -636,6 +638,14 @@ def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
out_avals, result_handlers, kept_var_idx)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)

# -- stages.Executable protocol

def runtime_executable(self):
return self.xla_executable

def hlo_modules(self):
return self.xla_executable.hlo_modules()

def call(self, *args):
arg_specs = unsafe_map(arg_spec, args)
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
Expand Down
51 changes: 37 additions & 14 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Sequence, Tuple
from typing_extensions import Protocol

from jax import core
from jax import tree_util
from jax.interpreters import pxla

from jax._src import dispatch
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
Expand All @@ -34,15 +32,6 @@
zip, unsafe_zip = util.safe_zip, zip


Computation = Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]

Executable = Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]


@dataclass
class ArgInfo:
aval: core.ShapedArray
Expand Down Expand Up @@ -78,6 +67,38 @@ def make_args_info(in_tree, in_avals, donate_argnums):
for i, aval in enumerate(flat_avals)])


class Executable(Protocol):
"""Protocol for compiled executable, which a ``Compiled`` encapsulates."""

def call(*args_flat) -> Sequence[Any]:
"""Invoke this on the flat list of arguments, returning flat outputs."""
raise NotImplementedError

def hlo_modules(self) -> Sequence[Any]:
"""Return a sequence of HLO modules representing this computation."""
raise NotImplementedError

def runtime_executable(self) -> Any:
"""Return an opaque reference to the executable known to the runtime."""
raise NotImplementedError


class Computation(Protocol):
"""Protocol for lowered computation, which a ``Lowered`` encapsulates."""

def compile(self) -> Executable:
"""Compile and return a corresponding ``Exectuable``."""
raise NotImplementedError

def hlo(self) -> Any:
"""Return an HLO representation of this computation."""
raise NotImplementedError

def mhlo(self) -> Any:
"""Return an MHLO representation of this computation."""
raise NotImplementedError


class Compiled(Stage):
"""Compiled representation of a function specialized to types/values.
Expand Down Expand Up @@ -107,10 +128,10 @@ def compiler_ir(self):
representation of the program after such passes, whenever
possible.
"""
return self._executable.xla_executable.hlo_modules()
return self._executable.hlo_modules()

def runtime_executable(self):
return self._executable.xla_executable
return self._executable.runtime_executable()

def _xla_executable(self):
# TODO(frostig): finalize API. For now, return the underlying
Expand Down Expand Up @@ -221,6 +242,7 @@ def _xla_computation(self):
class Wrapped(Protocol):
def __call__(self, *args, **kwargs):
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

def lower(self, *args, **kwargs) -> Lowered:
"""Lower this function for the given arguments.
Expand All @@ -232,3 +254,4 @@ def lower(self, *args, **kwargs) -> Lowered:
Returns:
A ``Lowered`` instance representing the lowering.
"""
raise NotImplementedError
57 changes: 39 additions & 18 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,35 @@
import numpy as np

import jax
from jax._src.config import config
from jax import core
from jax import linear_util as lu
from jax._src import abstract_arrays
from jax._src.abstract_arrays import array_types
from jax.core import ConcreteArray, ShapedArray
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_map

from jax._src import abstract_arrays
from jax._src import device_array
from jax._src import source_info_util
from jax._src import util
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from jax.errors import JAXTypeError
from jax._src import dispatch
from jax._src import profiler
from jax._src import stages
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax.tree_util import tree_flatten, tree_map
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import ad
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)


# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
Expand Down Expand Up @@ -1063,8 +1066,10 @@ def lower_parallel_callable(
shards=shards, tuple_args=tuple_args)


class PmapComputation:
class PmapComputation(stages.Computation):
_hlo: Union[ir.Module, xc.XlaComputation]
_executable: Optional['PmapExecutable']

def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
self._executable = None
self._hlo = hlo
Expand All @@ -1087,13 +1092,13 @@ def mhlo(self) -> ir.Module:
return self._hlo

@profiler.annotate_function
def compile(self):
def compile(self) -> 'PmapExecutable':
if self._executable is None:
self._executable = PmapExecutable.from_hlo(self._hlo, **self.compile_args)
return self._executable


class PmapExecutable:
class PmapExecutable(stages.Executable):
__slots__ = ['xla_executable', 'unsafe_call', 'fingerprint', 'in_avals']

def __init__(self, xla_executable, unsafe_call, fingerprint, in_avals):
Expand Down Expand Up @@ -1223,6 +1228,14 @@ def from_hlo(xla_computation,

return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)

# -- stages.Executable protocol

def runtime_executable(self):
return self.xla_executable

def hlo_modules(self):
return self.xla_executable.hlo_modules()

@profiler.annotate_function
def call(self, *args):
# TODO(frostig): do we need to check sharding and sharded avals?
Expand Down Expand Up @@ -2246,7 +2259,7 @@ def lower_mesh_computation(
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_global=in_is_global)


class MeshComputation:
class MeshComputation(stages.Computation):
_hlo: Union[ir.Module, xc.XlaComputation]
_executable: Optional['MeshExecutable']

Expand Down Expand Up @@ -2313,7 +2326,7 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_global):
return input_specs, input_indices, input_avals


class MeshExecutable:
class MeshExecutable(stages.Executable):
__slots__ = ['xla_executable', 'unsafe_call', '_input_avals']

def __init__(self, xla_executable, unsafe_call, input_avals):
Expand Down Expand Up @@ -2374,6 +2387,14 @@ def from_hlo(name: str,

return MeshExecutable(xla_executable, unsafe_call, input_avals)

# -- stages.Executable protocol

def runtime_executable(self):
return self.xla_executable

def hlo_modules(self):
return self.xla_executable.hlo_modules()

def call(self, *args):
# TODO(yashkatariya): Add a AOT lowering test where GDA is an input.
arg_avals = map(xla.abstractify, args)
Expand Down

0 comments on commit e317863

Please sign in to comment.