Fix bf16i4bf16 unit test failure (pytorch#2864)
Summary:
Pull Request resolved: pytorch#2864

Revert the changes to the bf16i4bf16 GEMM routine: they were causing errors in the unit test below.

https://www.internalfb.com/intern/test/562950059123389/

```
_h100#link-tree/gen_ai/llm_inference/fb/llm/llama_layers.py", line 352, in matmul_nt
    return torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
  File "/re_cwd/buck-out/v2/gen/fbcode/c2e398f2bd191d93/gen_ai/llm_inference/fb/llm/__llama_tests_h100__/llama_tests_h100#link-tree/torch/_ops.py", line 1124, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: cutlass cannot implement
```
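
For context: the reverted code derived M by flattening every leading dim of the activation (so a {b, M, K} input produced a {b, M, N} output) and read K from the weight, while the restored code assumes a plain 2-D {M, K} activation. A minimal sketch of the two M computations, using hypothetical shapes and a stand-in for the size_to_dim_ helper:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative stand-in for the removed size_to_dim_ helper:
// product of dims[0, k).
int64_t size_to_dim(size_t k, const std::vector<int64_t>& dims) {
  return std::accumulate(
      dims.begin(), dims.begin() + k, int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  // Hypothetical batched activation shape {b, M, K}.
  std::vector<int64_t> x_sizes = {4, 128, 256};

  // Reverted computation: fold all leading dims into M,
  // so the output keeps the batch dim ({b, M, N}).
  int64_t m_flattened = size_to_dim(x_sizes.size() - 1, x_sizes);  // 512

  // Restored computation: assume a 2-D {M, K} input, so M is dim 0.
  int64_t m_2d = x_sizes[0];  // 4 (only correct when the input is 2-D)

  return (m_flattened == 512 && m_2d == 4) ? 0 : 1;
}
```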

Reviewed By: jiawenliu64

Differential Revision: D59924899

fbshipit-source-id: 97663d8f5274c688b3d7dcbdefc76083dcf7c49f
jianyuh authored and facebook-github-bot committed Jul 19, 2024
1 parent d6790d0 commit 521c3ad
Showing 1 changed file with 6 additions and 20 deletions.
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
```diff
@@ -1719,16 +1719,9 @@ at::Tensor bf16i4bf16_rowwise_impl(
     at::Tensor WQ, // INT4
     at::Tensor w_scale,
     at::Tensor w_zp) {
-  // XQ: M x K
-  // WQ: N x K
-  // output: M x N
-  int M = size_to_dim_(X.dim() - 1, X.sizes());
+  int M = X.size(0);
   int N = WQ.size(0);
-  int K = WQ.size(1);
-  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
-  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
-  auto out_sizes = X.sizes().vec();
-  out_sizes.back() = N;
+  int K = X.size(1);
 
   int num_groups = w_scale.size(0);
 
@@ -1740,7 +1733,7 @@
 
   int group_size = K / num_groups;
 
-  auto Y = at::empty(out_sizes, X.options().dtype(at::kBFloat16));
+  auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16));
 
   using ElementInputA = cutlass::bfloat16_t;
   using LayoutInputA = cutlass::layout::ColumnMajor;
@@ -1980,16 +1973,9 @@ at::Tensor f8i4bf16_rowwise_impl(
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor w_zp) {
-  // XQ: M x K
-  // WQ: N x K
-  // output: M x N
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int M = XQ.size(0);
   int N = WQ.size(0);
-  int K = WQ.size(1);
-  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
-  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
-  auto out_sizes = XQ.sizes().vec();
-  out_sizes.back() = N;
+  int K = XQ.size(1);
 
   int num_groups = w_scale.size(0);
 
@@ -2002,7 +1988,7 @@
 
   int group_size = K / num_groups;
 
-  auto Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));
+  auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
 
   using ElementInputA = INPUT_DTYPE;
   using LayoutInputA = cutlass::layout::ColumnMajor;
```
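
On the error string itself: "cutlass cannot implement" is the usual guard around a CUTLASS launch, where the kernel is asked up front whether it supports the requested problem shape. A hedged sketch of that pattern (illustrative only; the actual guard in cutlass_extensions.cu is outside this diff's context lines):

```cpp
#include <stdexcept>

#include <cutlass/cutlass.h>

// Illustrative launch guard; `Gemm` stands for any CUTLASS device-level GEMM.
template <typename Gemm>
void run_checked(Gemm& gemm, const typename Gemm::Arguments& arguments) {
  // can_implement() rejects problem shapes/alignments the kernel cannot
  // handle; surfacing it as a runtime_error yields the RuntimeError above.
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }
  status = gemm(arguments);  // launch the kernel
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass kernel run failed");
  }
}
```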
