[Inductor Intel GPU backend Upstream] Step 1/3: Generalize device-bias code in code generation. (pytorch#116020)

As the [RFC](pytorch#114856) mentions, this is step 1 of adding the Intel GPU backend as an alternative Inductor backend.

### Design

Typically, to integrate an Intel GPU backend into Inductor, we would inherit from `WrapperCodegen` and `TritonScheduling` and implement the corresponding subclasses. However, since `WrapperCodegen` and `TritonScheduling` have device-bias code generation **scattered** throughout their methods, overriding them in subclasses would duplicate a lot of parent-class code. For example:

https://github.com/pytorch/pytorch/blob/2a440348958b3f0a2b09458bd76fe5959b371c0c/torch/_inductor/codegen/wrapper.py#L487
https://github.com/pytorch/pytorch/blob/2a440348958b3f0a2b09458bd76fe5959b371c0c/torch/_inductor/codegen/triton.py#L1996

So we abstract the device-bias code scattered across `WrapperCodegen` and `TritonScheduling` behind a unified interface, `DeviceOpOverrides`. This way, when integrating a new backend, we can maximize reuse of the `WrapperCodegen` and `TritonScheduling` code by inheriting from and implementing this interface for device flexibility.

Currently, `DeviceOpOverrides` only covers Python wrapper code generation. We can further extend it to cover C++ wrapper code generation on demand.

Pull Request resolved: pytorch#116020
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
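To make the extension point concrete, here is a minimal sketch of what an Intel GPU implementation of this interface could look like, assuming a `torch.xpu` runtime API analogous to `torch.cuda`. The class name, module paths, and method bodies below are illustrative assumptions, not part of this commit:

```python
from torch._inductor.codegen.common import DeviceOpOverrides


class XPUDeviceOpOverrides(DeviceOpOverrides):
    """Hypothetical Intel GPU overrides, mirroring the CUDA version in the diff below.

    Assumes torch.xpu exposes set_device/synchronize/_DeviceGuard and a
    raw-stream getter analogous to the CUDA ones.
    """

    def import_get_raw_stream_as(self, name):
        return f"from torch._C import _xpu_getCurrentRawStream as {name}"

    def set_device(self, device_idx):
        return f"torch.xpu.set_device({device_idx})"

    def synchronize(self):
        return "torch.xpu.synchronize()"

    def device_guard(self, device_idx):
        return f"torch.xpu._DeviceGuard({device_idx})"
```

With such a class registered for the `xpu` device (via whatever per-device lookup the shared codegen uses), the common `WrapperCodegen` and `TritonScheduling` paths can emit device-correct stream, device-guard, and synchronize calls without any CUDA-specific branches.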
1 parent 7d0ad6e · commit 7a6cb9f · 5 changed files with 69 additions and 25 deletions.
New file (`@@ -0,0 +1,15 @@`):

```python
from ..common import DeviceOpOverrides


class CUDADeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name):
        return f"from torch._C import _cuda_getCurrentRawStream as {name}"

    def set_device(self, device_idx):
        return f"torch.cuda.set_device({device_idx})"

    def synchronize(self):
        return "torch.cuda.synchronize()"

    def device_guard(self, device_idx):
        return f"torch.cuda._DeviceGuard({device_idx})"
```