Skip to content

Commit

Permalink
Avoid decomposing _unsafe_index in Inductor (pytorch#107882)
Browse files Browse the repository at this point in the history
`_unsafe_index` was previously added to the core ATen decomp table in pytorch#106814, but this has performance ramifications for Inductor. Therefore, this diff removes it from the decomposition table used by Inductor.

Differential Revision: [D48649210](https://our.internmc.facebook.com/intern/diff/D48649210/)

Pull Request resolved: pytorch#107882
Approved by: https://github.com/SherlockNoMad
  • Loading branch information
SS-JIA authored and pytorchmergebot committed Aug 25, 2023
1 parent e00bd83 commit 86f9fec
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
19 changes: 19 additions & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,25 @@ def get_decompositions(
return decompositions


def remove_decompositions(
decompositions: Dict[OpOverload, Callable],
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
) -> None:
"""
Given a dictionary of decompositions obtained from get_decompositions(), removes
operators associated with a list of operator overloads and overload packets passed
as input. If the decomposition dictionary does not contain a decomposition that is
specified to be removed, it is silently ignored.
"""
for op in aten_ops:
if isinstance(op, OpOverloadPacket):
for overload_name in op.overloads():
opo = getattr(op, overload_name)
decompositions.pop(opo, None)
elif isinstance(op, OpOverload):
decompositions.pop(op, None)


# populate the table
import torch._decomp.decompositions
import torch._refs
Expand Down
14 changes: 13 additions & 1 deletion torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import torch
import torch._decomp as decomp
import torch.ao.quantization.fx._decomposed
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp import (
core_aten_decompositions,
get_decompositions,
remove_decompositions,
)
from torch._decomp.decompositions import pw_cast_for_opmath
from torch._decomp.decompositions_for_rng import extra_random_decomps

Expand Down Expand Up @@ -56,6 +60,14 @@
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
aten._unsafe_index,
]

remove_decompositions(decompositions, decomps_to_exclude)


def register_decomposition(ops):
for op in [ops] if callable(ops) else ops:
Expand Down

0 comments on commit 86f9fec

Please sign in to comment.