GEMM custom op enablement (pytorch#3046)
Summary:
X-link: facebookresearch/FBGEMM#144

Pull Request resolved: pytorch#3046

Inspired by the following Medium article, we wanted to benchmark the block-wise FP8 GEMM to see whether we get any perf gains.

This diff registers the block-wise FP8 GEMM (matmul_fp8_block) as a torch.library custom op, together with a fake (meta) implementation so it can be traced by torch.compile.
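
As a rough, standalone sketch of the mechanism (illustrative names only, not the code in this diff): torch.library.custom_op wraps the eager/Triton implementation, and register_fake supplies a shape-only meta function so torch.compile can trace the op without running the kernel.

```python
import torch

# Illustrative only: a stand-in op showing the custom_op + register_fake pattern.
@torch.library.custom_op("demo::scaled_matmul", mutates_args=())
def scaled_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # The real op launches a Triton kernel; a plain matmul stands in here.
    return (a.float() @ b.float().t()).to(torch.bfloat16)

@scaled_matmul.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake/meta impl: report only output shape and dtype, no computation.
    M = a.shape[0]
    N = b.shape[0]
    return torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
```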

Reviewed By: jwfromm

Differential Revision: D61800794

fbshipit-source-id: da61af1a61e7adf4072911a6c2020edfe7048d36
Saman Keon authored and facebook-github-bot committed Aug 28, 2024
1 parent dba7263 commit a9a3713
Showing 1 changed file with 20 additions and 0 deletions.
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -1467,6 +1467,7 @@ def _kernel_matmul_fp8_block_slowacc(
tl.atomic_add(c, acc, mask=mask)


@torch.library.custom_op("triton::matmul_fp8_block", mutates_args=())
def matmul_fp8_block(
a: torch.Tensor,
b: torch.Tensor,
@@ -1587,6 +1588,25 @@ def grid(META):
return c.view(output_shape)


@matmul_fp8_block.register_fake
def matmul_fp8_block_meta(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
scale_block_m: int = 256,
scale_block_n: int = 256,
scale_block_k: int = 256,
dot_out_dtype: Optional[torch.dtype] = None,
allow_tf32: bool = True,
fp8_fast_accum: bool = True,
) -> torch.Tensor:
"""Shape function for torch compile."""
M, K = a.shape
N, K = b.shape
return torch.empty((M, N), device=a.device, dtype=torch.bfloat16)


def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:
"""
Generate a simplified matmul tune key for A @ B.T

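With the fake implementation registered, matmul_fp8_block can be called from inside a torch.compile'd function. A hedged usage sketch follows; the scale-tensor layout (one scale per 256x256 block, matching the scale_block_* defaults) is an assumption for illustration, not something this commit documents.

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import matmul_fp8_block

@torch.compile
def fused(a, b, a_scale, b_scale):
    return matmul_fp8_block(a, b, a_scale, b_scale)

M = N = K = 512
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# Assumed scale layout: one scale per scale_block (default 256) along each dim.
a_scale = torch.ones((M + 255) // 256, (K + 255) // 256, device="cuda")
b_scale = torch.ones((N + 255) // 256, (K + 255) // 256, device="cuda")
out = fused(a, b, a_scale, b_scale)  # expected shape (M, N), dtype bfloat16
```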