[Inductor Intel GPU backend Upstream] Step 1/3: Generalize device-bias code in code generation. (pytorch#116020)

As the [RFC](pytorch#114856) mentions, this is step 1 of adding the Intel GPU backend as an alternative Inductor backend.

### Design
Typically, in order to integrate the Intel GPU backend into Inductor, we need to inherit from `WrapperCodegen` and `TritonScheduling` and implement corresponding subclasses for each. However, since `WrapperCodegen` and `TritonScheduling` have some device-bias code generation **scattered** across their methods, overriding them in subclasses would introduce a lot of duplicated parent-class code.
For example:
https://github.com/pytorch/pytorch/blob/2a440348958b3f0a2b09458bd76fe5959b371c0c/torch/_inductor/codegen/wrapper.py#L487

https://github.com/pytorch/pytorch/blob/2a440348958b3f0a2b09458bd76fe5959b371c0c/torch/_inductor/codegen/triton.py#L1996

So we abstract the device-bias code scattered across `WrapperCodegen` and `TritonScheduling` into a unified interface, `DeviceOpOverrides`. This way, when integrating a new backend, we can maximize the reuse of `WrapperCodegen` and `TritonScheduling` code by inheriting from and implementing this interface for device flexibility.
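
As an illustrative sketch (not part of this PR), a hypothetical Intel GPU ("xpu") backend could implement the interface by mirroring the CUDA overrides added below; the class name, the `torch.xpu.*` calls, and the `_xpu_getCurrentRawStream` symbol are assumptions:

```python
# Hypothetical sketch: an Intel GPU ("xpu") implementation of the interface,
# mirroring the CUDADeviceOpOverrides added in this PR. The torch.xpu.* calls
# and the _xpu_getCurrentRawStream symbol are assumptions, not part of this change.
from torch._inductor.codegen.common import DeviceOpOverrides


class XPUDeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name):
        return f"from torch._C import _xpu_getCurrentRawStream as {name}"

    def set_device(self, device_idx):
        return f"torch.xpu.set_device({device_idx})"

    def synchronize(self):
        return "torch.xpu.synchronize()"

    def device_guard(self, device_idx):
        return f"torch.xpu._DeviceGuard({device_idx})"
```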

Currently, `DeviceOpOverrides` only covers Python wrapper code generation. We can further extend it to cover C++ wrapper code generation on demand.
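
To make the effect concrete, here is a minimal sketch (assuming a PyTorch build that contains this change) of the Python wrapper fragments the shared codegen now obtains through the interface instead of hard-coding `torch.cuda.*`:

```python
# Minimal illustration (assuming a PyTorch build that contains this change):
# the shared codegen only asks device_ops for source strings, so
# WrapperCodegen/TritonScheduling stay device-agnostic. print() stands in
# for the IndentedBuffer.writeline() calls used by the real codegen.
from torch._inductor.codegen.common import get_device_op_overrides

device_ops = get_device_op_overrides("cuda")  # a new backend would return its own subclass here
index = 0

print(device_ops.import_get_raw_stream_as("get_raw_stream"))
print(f"with {device_ops.device_guard(index)}:")
print(f"    {device_ops.set_device(index)}  # no-op to ensure context")
print(device_ops.synchronize())
```

Running it prints the same `torch.cuda.*` lines that the wrapper and Triton codegen previously hard-coded.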

Pull Request resolved: pytorch#116020
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
etaf authored and pytorchmergebot committed Dec 22, 2023
1 parent 7d0ad6e commit 7a6cb9f
Showing 5 changed files with 69 additions and 25 deletions.
24 changes: 24 additions & 0 deletions torch/_inductor/codegen/common.py
@@ -99,6 +99,16 @@ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def get_device_op_overrides(device: str):
assert isinstance(device, str)
if device == "cuda":
from .cuda.device_op_overrides import CUDADeviceOpOverrides

return CUDADeviceOpOverrides()

return DeviceOpOverrides()


@functools.lru_cache(None)
def boolean_ops():
return (
@@ -461,6 +471,20 @@ def load_seed(name, offset):
return ops.load(name, sympy.Integer(offset))


class DeviceOpOverrides:
def import_get_raw_stream_as(self, name):
raise NotImplementedError()

def set_device(self, device_idx):
raise NotImplementedError()

def synchronize(self):
raise NotImplementedError()

def device_guard(self, device_idx):
raise NotImplementedError()


class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

15 changes: 15 additions & 0 deletions torch/_inductor/codegen/cuda/device_op_overrides.py
@@ -0,0 +1,15 @@
from ..common import DeviceOpOverrides


class CUDADeviceOpOverrides(DeviceOpOverrides):
def import_get_raw_stream_as(self, name):
return f"from torch._C import _cuda_getCurrentRawStream as {name}"

def set_device(self, device_idx):
return f"torch.cuda.set_device({device_idx})"

def synchronize(self):
return "torch.cuda.synchronize()"

def device_guard(self, device_idx):
return f"torch.cuda._DeviceGuard({device_idx})"
18 changes: 10 additions & 8 deletions torch/_inductor/codegen/triton.py
@@ -2032,10 +2032,10 @@ def codegen_kernel_benchmark(self):
extra_args_str = None
index = V.graph.scheduler.current_device.index
with result.indent():
result.writeline(f"with torch.cuda._DeviceGuard({index}):")
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
with result.indent():
result.writeline(
f"torch.cuda.set_device({index})"
V.graph.device_ops.set_device(index)
) # no-op to ensure context
for tree in self.range_trees:
expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
@@ -2045,7 +2045,7 @@ def codegen_kernel_benchmark(self):
grid.append(expr)

stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_cuda_stream({index})")
result.writeline(f"{stream_name} = get_raw_stream({index})")

if self.need_numel_args():
extra_args_str = ", ".join(map(str, extra_args)) + ", "
@@ -2059,10 +2059,10 @@ def codegen_kernel_benchmark(self):
# benchmark all configs
result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
with result.indent():
result.writeline(f"with torch.cuda._DeviceGuard({index}):")
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
with result.indent():
result.writeline(
f"torch.cuda.set_device({index})"
V.graph.device_ops.set_device(index)
) # no-op to ensure context
result.writeline(
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))" # noqa: B950 line too long
@@ -2093,10 +2093,12 @@ def imports_for_benchmark_kernel(self):
return textwrap.dedent(
"""
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
{}
import torch
from torch._inductor.triton_heuristics import grid
"""
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
)

@staticmethod
@@ -2971,7 +2973,7 @@ def codegen_template(self, template_node, epilogue_nodes):
self.scheduler.free_buffers()

def codegen_sync(self):
V.graph.wrapper_code.writeline("torch.cuda.synchronize()")
V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())

def codegen_foreach(self, foreach_node):
from .triton_foreach import ForeachKernel
32 changes: 15 additions & 17 deletions torch/_inductor/codegen/wrapper.py
@@ -202,7 +202,7 @@ def push(self, key, item: "FreeIfNotReusedLine"):


@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
class EnterDeviceContextManagerLine:
device_idx: int
last_seen_device_guard_index: Optional[int]

@@ -237,14 +237,12 @@ def codegen(self, code: IndentedBuffer, device_cm_stack: contextlib.ExitStack):
else:
# Note _DeviceGuard has less overhead than device, but only accepts
# integers
code.writeline(f"with torch.cuda._DeviceGuard({self.device_idx}):")
code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:")
device_cm_stack.enter_context(code.indent())
code.writeline(
f"torch.cuda.set_device({self.device_idx}) # no-op to ensure context"
)
code.writeline(V.graph.device_ops.set_device(self.device_idx))


class ExitCudaDeviceContextManagerLine:
class ExitDeviceContextManagerLine:
def codegen(self, code: IndentedBuffer, device_cm_stack: contextlib.ExitStack):
if not V.graph.cpp_wrapper:
device_cm_stack.close()
@@ -449,8 +447,10 @@ def write_triton_header_once(self):
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
"""
{}
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
)

def add_meta_once(self, meta):
@@ -492,7 +492,7 @@ def call(args):
)
with self.prefix.indent():
if config.triton.debug_sync_graph:
self.prefix.writeline("torch.cuda.synchronize()")
self.prefix.writeline(V.graph.device_ops.synchronize())
inp_len = len(V.graph.graph_inputs.keys())
if inp_len != 0:
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
@@ -506,22 +506,20 @@ def write_get_raw_stream(self, index):
def write_get_raw_stream(self, index):
self.write_triton_header_once()
name = f"stream{index}"
self.writeline(f"{name} = get_cuda_stream({index})")
self.writeline(f"{name} = get_raw_stream({index})")
return name

def next_kernel_suffix(self):
return f"{next(self._names_iter)}"

def codegen_device_guard_enter(self, device_idx):
self.writeline(
EnterCudaDeviceContextManagerLine(
device_idx, self.last_seen_device_guard_index
)
EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
)
self.last_seen_device_guard_index = device_idx

def codegen_device_guard_exit(self):
self.writeline(ExitCudaDeviceContextManagerLine())
self.writeline(ExitDeviceContextManagerLine())

def generate_return(self, output_refs):
if output_refs:
@@ -637,8 +635,8 @@ def generate(self, is_inference):
elif isinstance(
line,
(
EnterCudaDeviceContextManagerLine,
ExitCudaDeviceContextManagerLine,
EnterDeviceContextManagerLine,
ExitDeviceContextManagerLine,
),
):
line.codegen(self.wrapper_call, device_cm_stack)
@@ -648,7 +646,7 @@ def generate(self, is_inference):
output_refs = self.get_output_refs()
self.mark_output_type()
if config.triton.debug_sync_graph:
self.wrapper_call.writeline("torch.cuda.synchronize()")
self.wrapper_call.writeline(V.graph.device_ops.synchronize())

if config.profile_bandwidth:
self.generate_end_graph()
5 changes: 5 additions & 0 deletions torch/_inductor/graph.py
@@ -23,6 +23,8 @@

from . import config, ir
from .codegen.common import (
DeviceOpOverrides,
get_device_op_overrides,
get_scheduling_for_device,
get_wrapper_codegen_for_device,
register_backend_for_device,
@@ -220,6 +222,7 @@ def __init__(
self.mutated_buffers: Set[str] = set()
self.never_reuse_buffers: Set[str] = set()
self.inplaced_to_remove: Set[str] = set()
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
@@ -1006,6 +1009,8 @@ def init_wrapper_code(self):
)
only_cpu = len(device_types) == 0
device_type = "cpu" if only_cpu else device_types.pop()

self.device_ops = get_device_op_overrides(device_type)
wrapper_code_gen_cls = get_wrapper_codegen_for_device(device_type)
assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
self.wrapper_code = wrapper_code_gen_cls()
