@fferflo released this 23 Apr 19:29

einx v0.2.1

Changed

  • Remove einx dependency in compiled code: The code for a traced function now directly imports and uses the namespace of the backend (e.g. import torch). For example:
    >>> print(einx.dot("b q (h c), b k (h c) -> b q k h", x, y, h=16, graph=True))
    import torch
    def op0(i0, i1):
        x0 = torch.reshape(i0, (16, 768, 16, 64))
        x1 = torch.reshape(i1, (16, 768, 16, 64))
        x2 = torch.einsum("abcd,aecd->abec", x0, x1)
        return x2
    In most cases, compiled functions now contain no reference to other einx code.
  • Improve handling of Python scalars (see #7): einx now converts int, float and bool arguments to tensor objects (e.g. via torch.asarray) only if the called backend function does not support Python scalars; previously, all inputs were converted to tensor objects. When using PyTorch, the device argument is used to place the constructed tensor on the correct device.
    For example, torch.add supports Python scalars:
    >>> print(einx.add("a,", x, 1, graph=True))
    import torch
    def op0(i0, i1):
        x0 = torch.add(i0, i1)
        return x0
    while torch.maximum does not:
    >>> print(einx.maximum("a,", x, 1, graph=True))
    import torch
    def op0(i0, i1):
        x0 = torch.asarray(i1, device=i0.device)
        x1 = torch.maximum(i0, x0)
        return x1
  • Run unit tests for PyTorch and JAX also on the GPU (if one is available).
  • Run unit tests also with jax.jit and torch.compile, as sketched below.
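    As an illustrative sketch (not taken from the release notes; the function name, shapes and h=16 are assumptions), einx operations can be used inside a function wrapped with torch.compile, since they trace to plain torch calls:
    import torch
    import einx

    @torch.compile
    def scores(q, k):
        # einx traces to plain torch calls (see the first item above), so
        # torch.compile can capture it like ordinary PyTorch code.
        s = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=16)
        # The Python scalar 1 is passed straight to torch.add (see the scalar handling item above).
        return einx.add("b q k h,", s, 1)

    q = torch.randn(2, 8, 64)
    k = torch.randn(2, 8, 64)
    print(scores(q, k).shape)  # torch.Size([2, 8, 8, 16])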

Fixed