[inductor] Make UserDefinedTritonKernel a multi-output operation (pytorch#129325)

Previously, each mutation was represented by a `MutationOutput` operation, a separate
scheduler node that had to be scheduled immediately after the mutating node.

Now we have a single scheduler node, which produces multiple `MutationOutput`
buffers as its output.
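
As a rough illustration of the new shape, here is a minimal standalone sketch. The
names (`KernelOp`, `MutationOutputBuf`) are made up for this example and are not the
actual inductor classes; in the diff below, `UserDefinedTritonKernel` plays the role of
the operation and `MutationOutput(Buffer)` the role of the per-mutation buffer. A single
operation owns one mutation-output buffer per mutated input, and its write-set is
derived from `get_outputs()`:

    from dataclasses import dataclass, field
    from typing import List, Optional, Set, Tuple

    @dataclass
    class MutationOutputBuf:
        """Stands in for the mutated state of a pre-existing buffer."""
        name: str
        mutates: str                              # name of the buffer mutated in place
        defining_op: Optional["KernelOp"] = None  # the single op performing the mutation

        def get_defining_op(self) -> Optional["KernelOp"]:
            return self.defining_op

        def should_allocate(self) -> bool:
            # No new storage is needed; the mutated buffer already exists.
            return False

    @dataclass
    class KernelOp:
        """One scheduler node that mutates several of its inputs in place."""
        name: str
        mutable_args: List[str]
        outputs: List[MutationOutputBuf] = field(default_factory=list)

        def __post_init__(self) -> None:
            # The single op owns one mutation-output buffer per mutated input,
            # instead of spawning a separate scheduler node per mutation.
            self.outputs = [
                MutationOutputBuf(f"{self.name}_out{i}", arg, self)
                for i, arg in enumerate(self.mutable_args)
            ]

        def get_outputs(self) -> List[MutationOutputBuf]:
            return self.outputs

        def get_read_writes(self) -> Tuple[Set[str], Set[str]]:
            # Reads come from the inputs; writes are derived from get_outputs(),
            # mirroring how the diff below builds ReadWrites.
            reads = set(self.mutable_args)
            writes = {buf.name for buf in self.get_outputs()}
            return reads, writes

    op = KernelOp("triton_kernel_0", ["buf0", "buf1"])
    print([o.name for o in op.get_outputs()])
    # ['triton_kernel_0_out0', 'triton_kernel_0_out1']
    print(op.get_read_writes())
    # e.g. ({'buf0', 'buf1'}, {'triton_kernel_0_out0', 'triton_kernel_0_out1'}); set order may vary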

Pull Request resolved: pytorch#129325
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#128893
peterbell10 authored and pytorchmergebot committed Jul 2, 2024
1 parent fb078c2 commit 7955cd3
106 changes: 59 additions & 47 deletions torch/_inductor/ir.py
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import contextlib
import dataclasses
import functools
@@ -3941,20 +3940,22 @@ def get_read_writes_input(self, x):
return dependencies.StarDep(x.get_name())

def get_read_writes(self):
star_dep = []
reads: Set[dependencies.Dep] = set()
StarDep = dependencies.StarDep
for input in self.inputs:
if isinstance(input, list):
star_dep.extend([self.get_read_writes_input(x) for x in input])
reads.update({StarDep(x.get_name()) for x in input})
else:
star_dep.append(self.get_read_writes_input(input))
reads.add(StarDep(input.get_name()))

writes: Set[dependencies.Dep] = {
StarDep(buf.get_name()) for buf in self.get_outputs()
}

return dependencies.ReadWrites(
set(star_dep),
{dependencies.StarDep(self.get_name())},
set(),
[],
None,
op_counts=collections.Counter(),
reads=reads,
writes=writes,
index_exprs=set(),
)

@classmethod
@@ -4860,6 +4861,29 @@ def apply_constraint(self):
raise NotImplementedError


class MutationOutput(Buffer):
"""
An output buffer that represents the mutation of a pre-existing buffer
"""

def __init__(self, layout, mutated_node, mutating_node: Operation):
super().__init__(name=None, layout=layout)
mutated_node_name = mutated_node.get_name()
V.graph.mark_buffer_mutated(mutated_node_name)
self.mutation_names = [mutated_node_name]
self.mutating_node: Operation = mutating_node
self.name = V.graph.register_buffer(self)

def get_defining_op(self) -> Operation:
return self.mutating_node

def get_mutation_names(self):
return self.mutation_names

def should_allocate(self):
return False


class UserDefinedTritonKernel(ExternKernel):
def get_kernel_and_configs(self):
from triton.runtime.autotuner import Autotuner
@@ -4901,14 +4925,6 @@ def codegen(self, wrapper):
new_name, self.grid, configs, args, triton_meta, raw_args
)

def should_allocate(self):
return False

def has_side_effects(self):
# UserDefinedTritonKernel does not return anything, but rather
# modifies input in place, do not let it get DCEd
return True

def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
# add unbacked symbols used in the grid to the ones used
# in the kwargs (the latter is generated by ExternKernel)
@@ -4917,12 +4933,6 @@ def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()

def get_mutation_names(self):
# NB: Inductor only allows a node to mutate 0 or 1 buffers.
# To get around that, we create MutationOutputs which marks their
# assigned input as mutable, thus, adhering to Inductor's constraint.
return []

def __init__(self, *, kernel_idx, grid, kernel_args):
inputs = []
kwargs = dict()
@@ -4937,17 +4947,15 @@ def __init__(self, *, kernel_idx, grid, kernel_args):
kwargs[k] = v

assert len(inputs) != 0
device = inputs[0].get_device()
self.device = inputs[0].get_device()

super().__init__(
None,
NoneLayout(device), # type: ignore[arg-type]
NoneLayout(self.device), # type: ignore[arg-type]
inputs,
tuple(constant_args),
kwargs,
)
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
self.kernel_idx = kernel_idx
self.grid = grid

@@ -4966,10 +4974,18 @@ def __init__(self, *, kernel_idx, grid, kernel_args):
kernel, {**kernel_args, **autotuned_kwargs}
)
]
mark_node_as_mutating(self, *self.mutable_args)

def get_inputs_that_alias_output(self):
return [i.get_name() for i in self.mutable_args]
self.outputs: List[Buffer] = [
MutationOutput(NoneLayout(self.device), buf, self)
for buf in self.mutable_args
]
V.graph.register_operation(self)

def get_device(self) -> torch.device:
return self.device

def get_outputs(self) -> List[Buffer]:
return self.outputs


def mark_node_as_mutating(cur_buffer, *mutated_nodes: IRNode):
@@ -4983,33 +4999,29 @@ def mark_node_as_mutating(cur_buffer, *mutated_nodes: IRNode):
assert isinstance(
node, IRNode
), f"{node} node is type {type(node)} and is not an IRNode"
V.graph.mark_buffer_mutated(node.get_name())
MutationOutput(node.get_layout(), node, cur_buffer)

MutationOperation(node.get_layout(), node, cur_buffer)

class MutationOutput(ExternKernel):
def get_mutation_names(self):
return [self.inputs[0].get_name()]

class MutationOperation(InputsKernel):
# TODO: Remove this, and use MutationOutput directly
def __init__(self, layout, mutated_node, node_doing_mutating):
# NB: Do not directly construct this - use `mark_node_as_mutating`
super().__init__(None, layout, [mutated_node, node_doing_mutating], ())
self.node_doing_mutating = node_doing_mutating
self.name = V.graph.register_buffer(self)
super().__init__(None, layout, inputs=[node_doing_mutating])
self.device = node_doing_mutating.get_device()
self.outputs: List[Buffer] = [MutationOutput(layout, mutated_node, self)]
V.graph.register_operation(self)

def get_device(self):
return self.device

def get_outputs(self) -> List[Buffer]:
return self.outputs

def should_allocate(self):
return False

def is_no_op(self):
return True

def has_side_effects(self):
return True

def get_inputs_that_alias_output(self):
return [self.inputs[0].get_name()]


class InplaceBernoulliFallback(ExternKernel):
"""