Remove einx dependency in compiled code: The code for a traced function now directly imports and uses the namespace of the backend (e.g. import torch). For example:
```python
>>> print(einx.dot("b q (h c), b k (h c) -> b q k h", x, y, h=16, graph=True))
import torch
def op0(i0, i1):
    x0 = torch.reshape(i0, (16, 768, 16, 64))
    x1 = torch.reshape(i1, (16, 768, 16, 64))
    x2 = torch.einsum("abcd,aecd->abec", x0, x1)
    return x2
```
In most cases, compiled functions now contain no reference to other einx code.
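For reference, a minimal sketch that reproduces the trace above; the input shapes are an assumption, inferred from the reshape calls in the generated code:

```python
import torch
import einx

# Shapes inferred from the graph: b=16, q=k=768, and h*c=1024
# (h=16 heads of c=64 channels each).
x = torch.randn(16, 768, 1024)
y = torch.randn(16, 768, 1024)

# graph=True prints the generated code, which imports torch directly
# and contains no reference to einx.
print(einx.dot("b q (h c), b k (h c) -> b q k h", x, y, h=16, graph=True))
```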
Improve handling of Python scalars (see #7): einx now converts int, float, and bool inputs to tensor objects (e.g. via torch.asarray) only if the backend function being called does not support Python scalars; previously, all inputs were converted to tensor objects. When using PyTorch, the device argument is used to place the constructed tensor on the correct device. For example, torch.add supports Python scalars, so scalar arguments are passed to it unchanged, as sketched below.
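A short sketch of the distinction at the backend level (tensor shapes and values here are illustrative):

```python
import torch

x = torch.randn(16, 768, device="cpu")

# torch.add accepts Python scalars, so einx can forward the float
# unchanged; no intermediate tensor is constructed.
y = torch.add(x, 1.0)

# For backend functions that only accept tensor operands, the scalar is
# first converted and placed on the same device as the other inputs:
s = torch.asarray(1.0, device=x.device)
```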