torch.compile-functorch interaction: update docs (pytorch#108130)
kshitij12345 authored and pytorchmergebot committed Sep 5, 2023
1 parent 42f94d7 commit a74f50d
Showing 1 changed file with 120 additions and 30 deletions.
150 changes: 120 additions & 30 deletions docs/source/torch.compiler_faq.rst
CUDA graphs with Triton are enabled by default in inductor, but removing them
may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.
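
For instance, a minimal sketch of flipping this flag before compiling (the toy
function and shapes here are illustrative, not from the FAQ):

.. code-block:: python

    import torch
    import torch._inductor.config

    # Keep inductor as the backend, but disable Triton CUDA graphs.
    torch._inductor.config.triton.cudagraphs = False

    @torch.compile
    def f(x):
        return torch.sin(x) + x

    # Assumes a CUDA-capable device is available.
    f(torch.randn(8, device="cuda"))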

``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Applying a ``torch.func`` transform to a function that uses ``torch.compile``
does not work:

.. code-block:: python

    import torch

    @torch.compile
    def f(x):
        return torch.sin(x)

    def g(x):
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

This code will not work. There is an `issue <https://github.com/pytorch/pytorch/issues/100320>`__
that you can track for this.

As a workaround, use ``torch.compile`` outside of the ``torch.func`` function:

.. note::

    This is an experimental feature and can be used by setting
    ``torch._dynamo.config.capture_func_transforms = True``.

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def f(x):
        return torch.sin(x)

    @torch.compile
    def g(x):
        return torch.vmap(f)(x)

    x = torch.randn(2, 3)
    g(x)

Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
------------------------------------------------------------------------------------


Compiling ``torch.func.grad`` with ``torch.compile``
----------------------------------------------------

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def wrapper_fn(x):
        return torch.func.grad(lambda x: x.sin().sum())(x)

    x = torch.randn(3, 3, 3)
    grad_x = torch.compile(wrapper_fn)(x)

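As a quick sanity check, the compiled gradient can be compared against running
the same wrapper eagerly (an illustrative check, not part of the original FAQ):

.. code-block:: python

    # Eager torch.func.grad should agree with the compiled version.
    expected = wrapper_fn(x)
    assert torch.allclose(grad_x, expected)
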
Compiling ``torch.vmap`` with ``torch.compile``
-----------------------------------------------

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def my_fn(x):
        return torch.vmap(lambda x: x.sum(1))(x)

    x = torch.randn(3, 3, 3)
    output = torch.compile(my_fn)(x)

Limitations
-----------

There are currently a few cases which are not supported and lead to graph breaks
(that is, ``torch.compile`` falls back to eager-mode PyTorch on these). We are
working on improving the situation for the next release (PyTorch 2.2).

1. The inputs and outputs of the function being transformed over must be tensors.
   We do not yet support things like a tuple of Tensors.

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def fn(x):
        x1, x2 = x
        return x1 + x2

    def my_fn(x):
        return torch.func.vmap(fn)(x)

    x1 = torch.randn(3, 3, 3)
    x2 = torch.randn(3, 3, 3)

    # Unsupported, falls back to eager-mode PyTorch
    output = torch.compile(my_fn)((x1, x2))

2. Keyword arguments are not supported.

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def fn(x, y):
        return (x + y).sum()

    def my_fn(x, y):
        return torch.func.grad(fn)(x, y=y)

    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    # Unsupported, falls back to eager-mode PyTorch
    output = torch.compile(my_fn)(x, y)

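A possible rewrite, not from the original docs, is to pass every argument
positionally, since ``torch.func.grad`` forwards extra positional arguments to
the function it transforms:

.. code-block:: python

    def my_fn_positional(x, y):
        # y is passed positionally, so no keyword arguments reach the transform.
        return torch.func.grad(fn)(x, y)

    output = torch.compile(my_fn_positional)(x, y)
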
3. Functions with observable side effects are not supported. For example, it is OK
   to mutate a list created inside the function, but not OK to mutate a list created
   outside of the function; a sketch of the allowed case follows the example below.

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    some_list = []

    def f(x, y):
        some_list.append(1)
        return x + y

    def my_fn(x, y):
        return torch.func.vmap(f)(x, y)

    x = torch.ones(2, 3)
    y = torch.randn(2, 3)

    # Unsupported, falls back to eager-mode PyTorch
    output = torch.compile(my_fn)(x, y)

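By contrast, a minimal sketch of the allowed case (assuming the same experimental
flag is set), where the mutated list is created inside the transformed function:

.. code-block:: python

    def g(x, y):
        local_list = []      # created inside the function, so mutation is fine
        local_list.append(1)
        return x + y

    def my_ok_fn(x, y):
        return torch.func.vmap(g)(x, y)

    # Supported: no side effect is observable outside the function.
    output = torch.compile(my_ok_fn)(x, y)
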
4. ``torch.vmap`` over a function that calls one or more operators in the
   following list is not supported.

.. note::

    'stride', 'requires_grad', 'storage_offset', 'layout', 'data', 'is_coalesced', 'is_complex',
    'is_conj', 'is_contiguous', 'is_cpu', 'is_cuda', 'is_distributed', 'is_floating_point',
    'is_inference', 'is_ipu', 'is_leaf', 'is_meta', 'is_mkldnn', 'is_mps', 'is_neg', 'is_nested',
    'is_nonzero', 'is_ort', 'is_pinned', 'is_quantized', 'is_same_size', 'is_set_to', 'is_shared',
    'is_signed', 'is_sparse', 'is_sparse_csr', 'is_vulkan', 'is_xla', 'is_xpu'

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def bad_fn(x):
        x.stride()
        return x

    def my_fn(x):
        return torch.func.vmap(bad_fn)(x)

    x = torch.randn(3, 3, 3)

    # Unsupported, falls back to eager-mode PyTorch
    output = torch.compile(my_fn)(x)

Compiling functions besides the ones which are supported (escape hatch)
-----------------------------------------------------------------------

For other transforms, as a workaround, use ``torch._dynamo.allow_in_graph``.

``allow_in_graph`` is an escape hatch. If your code does not work with
``torch.compile``, which introspects Python bytecode, but you believe it will
work via a symbolic tracing approach (like ``jax.jit``), then use
``torch._dynamo.allow_in_graph``.
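
For illustration, a minimal sketch of the registration pattern; here
``my_custom_function`` is a hypothetical stand-in for code that Dynamo cannot trace:

.. code-block:: python

    import torch
    import torch._dynamo

    def my_custom_function(x):
        return x + 1

    # Put calls to this function straight into the graph instead of
    # introspecting its bytecode.
    torch._dynamo.allow_in_graph(my_custom_function)

    @torch.compile
    def fn(x):
        x = torch.add(x, 1)
        x = my_custom_function(x)
        x = torch.add(x, 1)
        return x

    fn(torch.randn(10))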
