Add einx.experimental.shard
fferflo committed Apr 26, 2024
1 parent 85c2f51 commit 4d1a293
Showing 13 changed files with 330 additions and 36 deletions.
7 changes: 6 additions & 1 deletion docs/source/api.rst
@@ -122,4 +122,9 @@ Keras
.. autoclass:: einx.nn.keras.Norm
.. autoclass:: einx.nn.keras.Dropout

.. autofunction:: einx.nn.keras.param
.. autofunction:: einx.nn.keras.param

Experimental
============

.. autofunction:: einx.experimental.shard
1 change: 1 addition & 0 deletions einx/__init__.py
@@ -7,3 +7,4 @@
from . import expr
from .op import *
from . import nn
from . import experimental
1 change: 1 addition & 0 deletions einx/experimental/__init__.py
@@ -0,0 +1 @@
from .op import *
1 change: 1 addition & 0 deletions einx/experimental/op/__init__.py
@@ -0,0 +1 @@
from .shard import *
214 changes: 214 additions & 0 deletions einx/experimental/op/shard.py
@@ -0,0 +1,214 @@
import einx
import einx.op.util as util
import numpy as np
from functools import partial
from typing import Callable, Union, Any
import numpy.typing as npt

tP = einx.tracer.import_("PartitionSpec", "P", from_="jax.sharding")
tNamedSharding = einx.tracer.import_("NamedSharding", from_="jax.sharding")
tMesh = einx.tracer.import_("Mesh", from_="jax.sharding")
tjax = einx.tracer.import_("jax")
tnp = einx.tracer.import_("numpy", as_="np")


@einx.jit(
trace=lambda t, c: lambda expr_in, tensor_in, expr_out, backend=None: c(
expr_in,
t(tensor_in),
expr_out,
)
)
def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None):
import jax

for root in [expr_in, expr_out]:
for expr in root.all():
if isinstance(expr, einx.expr.stage3.Concatenation):
raise ValueError("Concatenation not allowed")

# Call tensor factories
tensor_in = einx.tracer.call_factory(tensor_in, expr_in.shape, backend=backend)
(tensor_in,) = backend.all_to_tensor([tensor_in])

# Flatten expressions
(expr_in,), (tensor_in,) = util.flatten([expr_in], [tensor_in], backend=backend)
marked_axes = tuple(
axis
for axis in expr_in
if isinstance(axis, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(axis)
)

if mesh is None:
# Construct new mesh
devices = tnp.array(tjax.devices()).reshape(tuple(a.value for a in marked_axes))
mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))
elif isinstance(mesh, jax.sharding.Mesh):
# Got mesh -> check that marked axes match mesh
marked_names = set(a.name for a in marked_axes)
mesh_names = set(str(a) for a in mesh.axis_names)
if not marked_names.issubset(mesh_names):
raise ValueError(
f"Marked axes must be subset of mesh axes. Got marked axes {marked_names} and mesh axes {mesh_names}"
)
else:
# Got list of devices -> construct new mesh
devices = tnp.array(mesh).reshape(tuple(a.value for a in marked_axes))
mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))

# Construct partition spec
axes = tuple(axis for axis in expr_in if isinstance(axis, einx.expr.stage3.Axis))
partition_spec = [axis.name if einx.expr.stage3.is_marked(axis) else None for axis in axes]

# Shard tensor
sharding = tNamedSharding(mesh, tP(*partition_spec))
tensor_in = tjax.device_put(tensor_in, sharding)

# Unflatten output expressions
(tensor_in,) = util.unflatten([expr_in], [tensor_in], [expr_out], backend=backend)

return tensor_in, expr_in


@einx.lru_cache
def parse(description, tensor_shape, cse=True, mesh=None, jax_devices=None, **parameters):
import jax

description, parameters = einx.op.util._clean_description_and_parameters(
description, parameters
)

op = einx.expr.stage1.parse_op(description)

if len(op) != 1:
raise ValueError(f"Expected exactly one expression, got {len(op)}")

def solve(eqs):
return einx.expr.solve(
[einx.expr.Equation(op[0][0], tensor_shape)]
+ eqs
+ [
einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None)
for k, v in parameters.items()
],
cse=cse,
)[0]

if mesh is None:
# If no mesh is given, create new mesh of all devices
try:
expr_in = solve([])
except einx.expr.SolveException as e:
# Try with additional constraint of total number of devices
expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
mesh_eq = einx.expr.Equation(expr_mesh, [len(jax.devices())])
try:
expr_in = solve([mesh_eq])
except einx.expr.SolveException:
# If it still fails, reraise original exception
raise e
elif isinstance(mesh, jax.sharding.Mesh):
# Add constraints for existing mesh axes
expr_mesh = einx.expr.stage1.Marker(
einx.expr.stage1.List.maybe([
einx.expr.stage1.NamedAxis(name) for name in mesh.axis_names
])
)
mesh_eq = einx.expr.Equation(expr_mesh, mesh.devices.shape)

expr_in = solve([mesh_eq])
elif isinstance(mesh, (list, tuple)):
# Add constraint for number of devices
expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
mesh_eq = einx.expr.Equation(expr_mesh, [len(mesh)])
expr_in = solve([mesh_eq])

expr_out = expr_in.__deepcopy__()

return expr_in, expr_out


@einx.traceback_util.filter
@einx.jit(
trace=lambda t, c: lambda description, tensor, mesh=None, backend=None, **kwargs: c(
description, t(tensor), mesh=mesh, **kwargs
)
)
def shard(
description: str,
tensor: einx.Tensor,
mesh: Any = None,
backend: Union[einx.Backend, str, None] = "jax",
cse: bool = True,
**parameters: npt.ArrayLike,
) -> einx.Tensor:
"""Shards a tensor over a mesh of devices.
*This function is currently only supported for Jax: A sharding is created
based on the given expression, and applied to the tensor using* ``jax.device_put``.
The tensor is sharded across the marked axes in the input expression. The marked axes
match the axis names and shape of the mesh:
>>> x = jnp.ones((2, 4, 128))
>>> x = einx.experimental.shard("[d1 d2] c")
>>> x.sharding
NamedSharding(mesh=Mesh('d1': 2, 'd2': 4), spec=PartitionSpec('d1', 'd2', None))
Axis compositions can be used to apply the
`sharding rules of Jax <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_,
where tensor axes are evenly divided by the number of shards:

>>> x = jnp.ones((128, 640, 480, 3))
>>> x = einx.experimental.shard("([batch] _) ...", x)
>>> x.sharding
NamedSharding(mesh=Mesh('batch': 8), spec=PartitionSpec('batch',))
If possible, the sharding is created over all devices. ``_`` is a regular axis name,
and its value is determined by :doc:`einx's expression solver </faq/solver>`.

Optionally, an existing mesh can be passed:

>>> from jax.sharding import Mesh
>>> devices = np.asarray(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, axis_names=("d1", "d2"))
>>> x = jnp.ones((4, 1024, 1024))
>>> x = einx.experimental.shard("a ([d2] b) ([d1] c)", x, mesh=mesh)
>>> x.sharding
NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd2', 'd1'))
The array is replicated over all mesh axes that are not part of the expression:

>>> x = jnp.ones((1024, 1024))
>>> x = einx.experimental.shard("a ([d1] b)", x, mesh=mesh)
>>> x.sharding
NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd1',))

**This function is currently experimental and will likely change in future versions.**

Args:
description: Description string in Einstein notation (see above).
tensor: Input tensor or tensor factory matching the description string.
mesh: Mesh or list of devices to shard the tensor over. If not given, a new mesh over all
available devices will be created matching the axes in the given expression.
Defaults to ``None``.
cse: Whether to apply common subexpression elimination to the expressions. Defaults
to True.
graph: Whether to return the graph representation of the operation instead of
computing the result. Defaults to False.
**parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.
Returns:
The sharded tensor if ``graph=False``, otherwise the graph
representation of the operation.
"""
if backend.name != "jax":
raise NotImplementedError("einx.shard is currently only supported for Jax")
expr_in, expr_out = parse(
description, einx.tracer.get_shape(tensor), mesh=mesh, cse=cse, **parameters
)
tensor, expr_out = shard_stage3(expr_in, tensor, expr_out, mesh=mesh, backend=backend)
return tensor


shard.parse = parse
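
A minimal usage sketch (editor's note, not part of the commit), mirroring the docstring
examples above. It assumes a JAX setup with 8 visible devices, so the device counts and
mesh shapes are illustrative only.

import jax
import jax.numpy as jnp
import numpy as np
import einx

# Shard the first two axes over a freshly created 2x4 device mesh
# (requires 8 devices; adjust the axis sizes to your device count).
x = jnp.ones((2, 4, 128))
x = einx.experimental.shard("[d1 d2] c", x)
print(x.sharding)  # NamedSharding(..., spec=PartitionSpec('d1', 'd2', None))

# Shard over an existing mesh; mesh axes not named in the expression
# ('d2' here) are replicated.
from jax.sharding import Mesh
devices = np.asarray(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("d1", "d2"))
y = jnp.ones((1024, 1024))
y = einx.experimental.shard("a ([d1] b)", y, mesh=mesh)
print(y.sharding)  # spec=PartitionSpec(None, 'd1')
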
1 change: 1 addition & 0 deletions einx/expr/__init__.py
@@ -1,2 +1,3 @@
from . import stage1, stage2, stage3
from .util import *
from .solver import SolveException
2 changes: 1 addition & 1 deletion einx/expr/solver.py
@@ -141,7 +141,7 @@ def to_term(x):

class SolveException(Exception):
def __init__(self, message):
self.message = message
super().__init__(message)


def solve(equations):
38 changes: 20 additions & 18 deletions einx/expr/stage2.py
@@ -189,7 +189,7 @@ def all(self):
yield from self.inner.all()


class SolveDepthException(Exception):
class SolveDepthException(solver.SolveException):
def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, message):
assert (
len({
@@ -208,22 +208,23 @@ def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, m
self.expansions2 = expansions2
self.depths1 = depths1
self.depths2 = depths2
self.message = (
message_in = message
message = (
"Failed to solve for the depth of axes, i.e. the number of outer ellipses.\n"
"Equations:\n"
)
for expr1, expr2 in zip(exprs1, exprs2):
if expr1 is not None and expr2 is not None:
self.message += " "
self.message += f"{einx.expr.util._to_str(expr1)}"
self.message += " = "
self.message += f"{einx.expr.util._to_str(expr2)}"
self.message += "\n"
self.message += f"Reason: {message}"
super().__init__(self.message)
message += " "
message += f"{einx.expr.util._to_str(expr1)}"
message += " = "
message += f"{einx.expr.util._to_str(expr2)}"
message += "\n"
message += f"Reason: {message_in}"
super().__init__(message)


class SolveExpansionException(Exception):
class SolveExpansionException(solver.SolveException):
def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, message):
assert (
len({
@@ -242,16 +243,17 @@ def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, m
self.expansions2 = expansions2
self.depths1 = depths1
self.depths2 = depths2
self.message = "Failed to solve for the number of axes in the expressions.\nEquations:\n"
message_in = message
message = "Failed to solve for the number of axes in the expressions.\nEquations:\n"
for expr1, expr2 in zip(exprs1, exprs2):
if expr1 is not None and expr2 is not None:
self.message += " "
self.message += f"{einx.expr.util._to_str(expr1)}"
self.message += " = "
self.message += f"{einx.expr.util._to_str(expr2)}"
self.message += "\n"
self.message += f"Reason: {message}"
super().__init__(self.message)
message += " "
message += f"{einx.expr.util._to_str(expr1)}"
message += " = "
message += f"{einx.expr.util._to_str(expr2)}"
message += "\n"
message += f"Reason: {message_in}"
super().__init__(message)


def solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2):
10 changes: 4 additions & 6 deletions einx/expr/stage3.py
@@ -223,16 +223,14 @@ def all(self):
yield from self.inner.all()


class SolveValueException(Exception):
class SolveValueException(solver.SolveException):
def __init__(self, exprs1, exprs2, message):
self.exprs1 = exprs1
self.exprs2 = exprs2
self.message = f"Failed to solve values of expressions. {message}\nInput:\n"
message = f"Failed to solve values of expressions. {message}\nInput:\n"
for expr1, expr2 in zip(exprs1, exprs2):
self.message += (
f" '{einx.expr.util._to_str(expr1)} = {einx.expr.util._to_str(expr2)}'\n"
)
super().__init__(self.message)
message += f" '{einx.expr.util._to_str(expr1)} = {einx.expr.util._to_str(expr2)}'\n"
super().__init__(message)


def solve(exprs1, exprs2):
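
The exception changes above give the stage-specific solver errors a common base class,
which is what the new ``except einx.expr.SolveException`` fallback in
einx/experimental/op/shard.py relies on. A small sketch (not part of the commit) of the
resulting hierarchy, assuming einx at or after this commit is installed:

import einx
from einx.expr import stage2, stage3

# All stage-specific solver failures can now be caught with a single clause.
assert issubclass(stage2.SolveDepthException, einx.expr.SolveException)
assert issubclass(stage2.SolveExpansionException, einx.expr.SolveException)
assert issubclass(stage3.SolveValueException, einx.expr.SolveException)
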
12 changes: 7 additions & 5 deletions einx/tracer/compile.py
@@ -213,11 +213,13 @@ def execute_application(self, application):

use_dynamic_output_check = False
if isinstance(application.op, Import):
import_str = f"import {application.op.module}"
name = application.op.module
if not application.op.shorthand is None:
import_str += f" as {application.op.shorthand}"
name = application.op.shorthand
import_str = f"import {application.op.import_}"
name = application.op.import_
if not application.op.as_ is None:
import_str = f"{import_str} as {application.op.as_}"
name = application.op.as_
if not application.op.from_ is None:
import_str = f"from {application.op.from_} {import_str}"

# Import only once
if not any(
11 changes: 6 additions & 5 deletions einx/tracer/tracer.py
@@ -131,17 +131,18 @@ def __copy__(self):


class Import(Tracer):
def __init__(self, module, shorthand=None):
def __init__(self, import_, as_, from_):
Tracer.__init__(self, origin="constant")
self.module = module
self.shorthand = shorthand
self.import_ = import_
self.as_ = as_
self.from_ = from_

def __call__(self): # Overwrite allowed arguments
return apply(self)


def import_(module, shorthand=None):
return Import(module, shorthand)()
def import_(import_, as_=None, from_=None):
return Import(import_, as_, from_)()


class MemberAccess(Tracer):
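
For context (not part of the commit), the standalone sketch below reproduces the
import-string logic from the einx/tracer/compile.py hunk above, showing what the new
``(import_, as_, from_)`` fields on the Import tracer translate to. The helper name
``build_import_str`` is hypothetical and only used for illustration.

def build_import_str(import_, as_=None, from_=None):
    # Mirrors the string construction in execute_application() above.
    s = f"import {import_}"
    if as_ is not None:
        s = f"{s} as {as_}"
    if from_ is not None:
        s = f"from {from_} {s}"
    return s

print(build_import_str("numpy", as_="np"))
# import numpy as np
print(build_import_str("PartitionSpec", "P", from_="jax.sharding"))
# from jax.sharding import PartitionSpec as P
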
