more CK FP8 rowwise GEMM instances and tuning (pytorch#3455)
Summary:
Pull Request resolved: pytorch#3455

X-link: facebookresearch/FBGEMM#539

Added some MFMA 16x16 instances that seem to help with power efficiency, and used them for the EMU 1.7 shapes.

Reviewed By: jwfromm

Differential Revision: D66776945

fbshipit-source-id: e8c6bf6b626b7528c49c1c0ec0578d4681eb2941
mxz297 authored and facebook-github-bot committed Dec 6, 2024
1 parent 0f8f5e6 commit 82bc8f4
Showing 4 changed files with 168 additions and 8 deletions.
@@ -175,19 +175,19 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
     // EMU 1.7 shapes.
     {{1536, 4096, 4096},
-     fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+     fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3},
     {{3600, 4096, 4096},
-     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+     fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3},
     {{3600, 11008, 4096},
-     fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+     fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3},
     {{3600, 4096, 11008},
-     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+     fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3},
     {{4096, 4096, 4096},
-     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     {{4096, 11008, 4096},
-     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     {{4096, 4096, 11008},
-     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
     // Pro Shapes.
     {{32768, 128, 8192},
      fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
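(For context: the entries retuned above live in a lookup table keyed on the problem shape (M, N, K). Below is a minimal sketch of how such a table is consulted; the RowwiseKernel alias, the IntTupleHash combine, and the fallback path are simplified assumptions modeled on this diff, not the exact FBGEMM implementation.)

#include <functional>
#include <tuple>
#include <unordered_map>

#include <ATen/ATen.h>

// Assumed signature: every tuned instance takes the quantized inputs, the
// rowwise scales, and a preallocated output tensor.
using RowwiseKernel = std::function<at::Tensor(
    at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;

// Assumed hash for (M, N, K) keys; the real IntTupleHash may combine bits
// differently.
struct IntTupleHash {
  size_t operator()(const std::tuple<int, int, int>& t) const {
    size_t h0 = std::hash<int>{}(std::get<0>(t));
    size_t h1 = std::hash<int>{}(std::get<1>(t));
    size_t h2 = std::hash<int>{}(std::get<2>(t));
    return h0 ^ (h1 << 1) ^ (h2 << 2);
  }
};

// Dispatch sketch: an exact (M, N, K) match, such as the EMU 1.7 entries
// above, hits a hand-tuned instance; any other shape would fall through to
// a generic selection heuristic (elided here).
at::Tensor dispatch_fp8_rowwise(
    const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTupleHash>& lookup,
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  int M = XQ.size(0);
  int N = WQ.size(0);
  int K = WQ.size(1);
  auto it = lookup.find({M, N, K});
  if (it != lookup.end()) {
    return it->second(XQ, WQ, x_scale, w_scale, Y);
  }
  // Fallback heuristic for untuned shapes would go here.
  TORCH_CHECK(false, "no tuned instance for this shape");
}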
@@ -0,0 +1,72 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  // A kernel that seems to work well on mid-sized tensors.

  // Check if this input needs to be padded.
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
  int N = WQ.size(0);
  int K = WQ.size(1);
  bool pad = (K % 128 != 0);

  // Dispatch based on whether padding is needed or not.
  if (pad) {
    using DeviceGemmInstance = DeviceGemmHelper<
        256, // Block size
        192, // M per block
        128, // N per block
        128, // K per block
        16, // MFMA instruction tile M
        16, // MFMA instruction tile N
        6, // MFMA tiles per wave in M
        4, // MFMA tiles per wave in N
        S<8, 32, 1>, // A block transfer thread cluster
        S<8, 32, 1>, // B block transfer thread cluster
        S<1, 32, 1, 8>, // C shuffle block transfer cluster
        S<8, 8, 1>, // C shuffle scalars per vector
        2, // C shuffle tiles per wave per shuffle in M
        2, // C shuffle tiles per wave per shuffle in N
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v3,
        ck::tensor_operation::device::GemmSpecialization::KPadding>;
    // Run kernel instance.
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y);
  } else {
    // Same tile configuration, but without the K padding specialization.
    using DeviceGemmInstance = DeviceGemmHelper<
        256,
        192,
        128,
        128,
        16,
        16,
        6,
        4,
        S<8, 32, 1>,
        S<8, 32, 1>,
        S<1, 32, 1, 8>,
        S<8, 8, 1>,
        2,
        2,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v3,
        ck::tensor_operation::device::GemmSpecialization::Default>;
    // Run kernel instance.
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y);
  }
}
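(A minimal usage sketch of the instance above, assuming row-major inputs already quantized to FP8 with per-row scales. The float8 dtype, scale shapes, and device handling are illustrative assumptions, not part of this diff; in practice the kernel is reached through the shape dispatch table rather than called by name.)

#include <ATen/ATen.h>

at::Tensor run_example() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  // Shapes match the {1536, 4096, 4096} table entry above.
  at::Tensor XQ = at::randn({1536, 4096}, opts).to(at::kFloat8_e4m3fnuz); // M x K
  at::Tensor WQ = at::randn({4096, 4096}, opts).to(at::kFloat8_e4m3fnuz); // N x K
  at::Tensor x_scale = at::ones({1536}, opts); // assumed per-row scale of X
  at::Tensor w_scale = at::ones({4096}, opts); // assumed per-row scale of W
  at::Tensor Y = at::empty({1536, 4096}, opts.dtype(at::kBFloat16)); // M x N
  // K = 4096 is divisible by the 128 K-tile, so the Default (unpadded)
  // branch above is taken.
  return fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3(
      XQ, WQ, x_scale, w_scale, Y);
}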
@@ -0,0 +1,72 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  // A kernel that seems to work well on mid-sized tensors.

  // Check if this input needs to be padded.
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
  int N = WQ.size(0);
  int K = WQ.size(1);
  bool pad = (K % 128 != 0);

  // Dispatch based on whether padding is needed or not.
  if (pad) {
    using DeviceGemmInstance = DeviceGemmHelper<
        256, // Block size
        192, // M per block
        256, // N per block
        128, // K per block
        16, // MFMA instruction tile M
        16, // MFMA instruction tile N
        6, // MFMA tiles per wave in M
        8, // MFMA tiles per wave in N
        S<8, 32, 1>, // A block transfer thread cluster
        S<8, 32, 1>, // B block transfer thread cluster
        S<1, 32, 1, 8>, // C shuffle block transfer cluster
        S<8, 8, 1>, // C shuffle scalars per vector
        2, // C shuffle tiles per wave per shuffle in M
        2, // C shuffle tiles per wave per shuffle in N
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v3,
        ck::tensor_operation::device::GemmSpecialization::KPadding>;
    // Run kernel instance.
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y);
  } else {
    // Same tile configuration, but without the K padding specialization.
    using DeviceGemmInstance = DeviceGemmHelper<
        256,
        192,
        256,
        128,
        16,
        16,
        6,
        8,
        S<8, 32, 1>,
        S<8, 32, 1>,
        S<1, 32, 1, 8>,
        S<8, 8, 1>,
        2,
        2,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v3,
        ck::tensor_operation::device::GemmSpecialization::Default>;
    // Run kernel instance.
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y);
  }
}
@@ -214,7 +214,7 @@ fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor w_scale,
     at::Tensor Y);
 
-// Another varient of larger batch size support.
+// Another variant of larger batch size support.
 at::Tensor
 fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2(
     at::Tensor XQ,
@@ -273,3 +273,19 @@ fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y);
+
+at::Tensor
+fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
+at::Tensor
+fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
