Add meta registrations for kv_cache operators (pytorch#3442)
Summary:
Pull Request resolved: pytorch#3442

X-link: facebookresearch/FBGEMM#527

This diff adds full meta registrations for the FBGEMM KV cache operators, making them compatible with torch.compile.
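
For context, a minimal sketch of what a meta (or "fake") registration does and why it matters for torch.compile. This example is not part of the commit; the op name (mylib::scale_copy) and the torch.library.custom_op / register_fake API (PyTorch 2.4+) are stand-ins for illustration, playing the same role the *_meta functions in kv_cache.cpp play for the fbgemm operators:

import torch

# A toy custom op standing in for a real CUDA kernel.
@torch.library.custom_op("mylib::scale_copy", mutates_args=())
def scale_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    return (x * scale).contiguous()

# The meta/fake implementation: it only describes output shape and dtype and
# does no real compute, which is exactly what torch.compile needs when tracing.
@scale_copy.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    return scale_copy(x, 2.0)

print(f(torch.randn(4)))  # traces via the fake impl, executes the real one

Without the fake/meta implementation, torch.compile would graph-break (or fail under fullgraph=True) on the custom op; with it, the op can be traced end to end.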

Reviewed By: SungMinCho

Differential Revision: D66716728

fbshipit-source-id: 9413f2ef8b837c6a93cd5e1f177b924397a9dd73
jwfromm authored and facebook-github-bot committed Dec 6, 2024
1 parent fa4bc6b commit 32f154e
Showing 2 changed files with 225 additions and 20 deletions.
219 changes: 211 additions & 8 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -174,31 +174,234 @@ at::Tensor mqa_attn(
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("rope_qkv_varseq_prefill", rope_qkv_varseq_prefill);
m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("rope_qkv_decoding", rope_qkv_decoding);
m.def(
"nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None) -> Tensor");
m.impl("nope_qkv_varseq_prefill", nope_qkv_varseq_prefill);
m.def("nope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None) -> Tensor");
m.impl("nope_qkv_decoding", nope_qkv_decoding);
m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill);
m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("xpos_qkv_decoding", xpos_qkv_decoding);

m.def(
"dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)");
m.impl("dequantize_int4_cache", dequantize_int4_cache);
m.def(
"dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("rope_qkv_varseq_prefill", rope_qkv_varseq_prefill);
m.impl("rope_qkv_decoding", rope_qkv_decoding);
m.impl("nope_qkv_varseq_prefill", nope_qkv_varseq_prefill);
m.impl("nope_qkv_decoding", nope_qkv_decoding);
m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill);
m.impl("xpos_qkv_decoding", xpos_qkv_decoding);
m.impl("dequantize_int4_cache", dequantize_int4_cache);
m.impl("dequantize_fp8_cache", dequantize_fp8_cache);
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("rope_qkv_varseq_prefill", rope_qkv_varseq_prefill);
m.impl("rope_qkv_decoding", rope_qkv_decoding);
m.impl("nope_qkv_varseq_prefill", nope_qkv_varseq_prefill);
m.impl("nope_qkv_decoding", nope_qkv_decoding);
m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill);
m.impl("xpos_qkv_decoding", xpos_qkv_decoding);
m.impl("dequantize_int4_cache", dequantize_int4_cache);
m.impl("dequantize_fp8_cache", dequantize_fp8_cache);
}

at::Tensor rope_qkv_varseq_prefill_meta(
at::Tensor XQ,
at::Tensor /* XK */,
at::Tensor /* XV */,
at::Tensor /* cache_K */,
at::Tensor /* cache_V */,
at::Tensor /* varseq_batch */,
at::Tensor /* varseq_seqpos */,
double /* theta */,
std::optional<int64_t> /* num_groups */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */,
std::optional<at::Tensor> /* varseq_cache_seqpos */,
int64_t /* cache_logical_dtype_int */,
bool /* rope_scaling */,
int64_t /* old_context_len */,
double /* scaling_factor */,
double /* lo_freq_factor */,
double /* hi_freq_factor */,
std::optional<at::Tensor> /* qparam_k */,
std::optional<at::Tensor> /* qparam_v */
) {
return at::empty_like(XQ);
}

at::Tensor rope_qkv_decoding_meta(
at::Tensor XQ,
at::Tensor /* XK */,
at::Tensor /* XV */,
at::Tensor /* cache_K */,
at::Tensor /* cache_V */,
at::Tensor /* seqpos */,
double /* theta */,
std::optional<int64_t> /* num_groups */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */,
std::optional<at::Tensor> /* actual_batch_size */,
std::optional<at::Tensor> /* batch */,
std::optional<at::Tensor> /* cache_seqpos */,
int64_t /* cache_logical_dtype_int */,
bool /* rope_scaling */,
int64_t /* old_context_len */,
double /* scaling_factor */,
double /* lo_freq_factor */,
double /* hi_freq_factor */,
std::optional<at::Tensor> /* qparam_k */,
std::optional<at::Tensor> /* qparam_v */
) {
return at::empty_like(XQ);
}

at::Tensor nope_qkv_varseq_prefill_meta(
at::Tensor XQ,
at::Tensor /* XK */,
at::Tensor /* XV */,
at::Tensor /* cache_K */,
at::Tensor /* cache_V */,
at::Tensor /* varseq_batch */,
at::Tensor /* varseq_seqpos */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */,
std::optional<at::Tensor> /* varseq_cache_seqpos */
) {
return at::empty_like(XQ);
}

at::Tensor nope_qkv_decoding_meta(
at::Tensor XQ,
at::Tensor /* XK */,
at::Tensor /* XV */,
at::Tensor /* cache_K */,
at::Tensor /* cache_V */,
at::Tensor /* seqpos */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */,
std::optional<at::Tensor> /* actual_batch_size */,
std::optional<at::Tensor> /* batch */,
std::optional<at::Tensor> /* cache_seqpos */
) {
return at::empty_like(XQ);
}

at::Tensor xpos_qkv_varseq_prefill_meta(
at::Tensor XQ,
at::Tensor /* XK */,
at::Tensor /* XV */,
at::Tensor /* cache_K */,
at::Tensor /* cache_V */,
at::Tensor /* varseq_batch */,
at::Tensor /* varseq_seqpos */,
double /* theta */,
double /* gamma */,
double /* scale_base */,
double /* exponent_offset */,
std::optional<int64_t> /* num_groups */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */,
std::optional<at::Tensor> /* varseq_cache_seqpos */,
int64_t /* cache_logical_dtype_int */,
bool /* rope_scaling */,
int64_t /* old_context_len */,
double /* scaling_factor */,
double /* lo_freq_factor */,
double /* hi_freq_factor */,
std::optional<at::Tensor> /* qparam_k */,
std::optional<at::Tensor> /* qparam_v */
) {
return at::empty_like(XQ);
}

at::Tensor xpos_qkv_decoding_meta(
at::Tensor XQ,
at::Tensor /* XK */,
at::Tensor /* XV */,
at::Tensor /* cache_K */,
at::Tensor /* cache_V */,
at::Tensor /* seqpos */,
double /* theta */,
double /* gamma */,
double /* scale_base */,
double /* exponent_offset */,
std::optional<int64_t> /* num_groups */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */,
std::optional<at::Tensor> /* actual_batch_size */,
std::optional<at::Tensor> /* batch */,
std::optional<at::Tensor> /* cache_seqpos */,
int64_t /* cache_logical_dtype_int */,
bool /* rope_scaling */,
int64_t /* old_context_len */,
double /* scaling_factor */,
double /* lo_freq_factor */,
double /* hi_freq_factor */,
std::optional<at::Tensor> /* qparam_k */,
std::optional<at::Tensor> /* qparam_v */
) {
return at::empty_like(XQ);
}

std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache_meta(
at::Tensor cache_K,
at::Tensor /* cache_V */,
at::Tensor /* kv_seqlen */,
std::optional<int64_t> num_groups) {
const at::SymInt B = cache_K.sym_size(0);
const at::SymInt MAX_T = cache_K.sym_size(1);
const at::SymInt N_KVH = cache_K.sym_size(2);
const at::SymInt D_HQ = cache_K.sym_size(3);
auto num_groups_ = num_groups ? num_groups.value() : 1;
auto int4_qparam_offset = 4 * num_groups_;
const at::SymInt D_H = (D_HQ - int4_qparam_offset) * 2;
auto cache_K_dq = at::empty_symint(
{B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
auto cache_V_dq = at::empty_symint(
{B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
return {cache_K_dq, cache_V_dq};
}
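// Shape note on the int4 meta kernel above: assuming two int4 values packed
// per byte and 4 bytes of inline quantization parameters per group (the layout
// is inferred, not spelled out here), D_HQ = D_H / 2 + 4 * num_groups, which
// is why the dequantized head dimension is D_H = (D_HQ - 4 * num_groups) * 2.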

std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache_meta(
at::Tensor cache_K,
at::Tensor /* cache_V */,
at::Tensor /* kv_seqlen */,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> /* qparam_v */,
std::optional<at::Tensor> /* block_tables */,
int64_t /* page_size */) {
const at::SymInt B_KV = cache_K.sym_size(0);
const at::SymInt MAX_T = cache_K.sym_size(1);
const at::SymInt N_KVH = cache_K.sym_size(2);
const at::SymInt D_HQ = cache_K.sym_size(3);
auto fp8_qparam_offset = qparam_k ? 0 : 4;
const at::SymInt D_H = (D_HQ - fp8_qparam_offset);
auto cache_K_dq = at::empty_symint(
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
auto cache_V_dq = at::empty_symint(
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
return {cache_K_dq, cache_V_dq};
}
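// Shape note on the fp8 meta kernel above: the cache row is assumed to carry
// 4 bytes of inline quantization parameters unless qparam_k is supplied
// separately, so D_H = D_HQ - 4 without external qparams and D_H = D_HQ with
// them.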

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("rope_qkv_varseq_prefill", rope_qkv_varseq_prefill_meta);
m.impl("rope_qkv_decoding", rope_qkv_decoding_meta);
m.impl("nope_qkv_varseq_prefill", nope_qkv_varseq_prefill_meta);
m.impl("nope_qkv_decoding", nope_qkv_decoding_meta);
m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill_meta);
m.impl("xpos_qkv_decoding", xpos_qkv_decoding_meta);
m.impl("dequantize_int4_cache", dequantize_int4_cache_meta);
m.impl("dequantize_fp8_cache", dequantize_fp8_cache_meta);
}

} // namespace fbgemm_gpu
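
A quick illustration of what the new Meta registrations enable (not part of the commit; the import path, shapes, and dtypes below are assumptions for illustration): once a Meta kernel exists, the operator can be called on "meta" tensors to propagate shapes and dtypes without touching a GPU, which is what torch.compile and FakeTensor tracing rely on.

import torch
# import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- loads the fbgemm ops; exact module path is an assumption

B, MAX_T, N_KVH, D_HQ = 2, 1024, 8, 68  # 68 = 128 // 2 + 4 qparam bytes for one group (assumed layout)
cache_k = torch.empty(B, MAX_T, N_KVH, D_HQ, dtype=torch.uint8, device="meta")
cache_v = torch.empty(B, MAX_T, N_KVH, D_HQ, dtype=torch.uint8, device="meta")
kv_seqlen = torch.empty(B, dtype=torch.int32, device="meta")

k_dq, v_dq = torch.ops.fbgemm.dequantize_int4_cache(cache_k, cache_v, kv_seqlen, num_groups=1)
print(k_dq.shape, k_dq.dtype)  # torch.Size([2, 1024, 8, 128]) torch.bfloat16, per dequantize_int4_cache_meta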
26 changes: 14 additions & 12 deletions fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py
@@ -130,7 +130,7 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
)

xq_out_bf16 = torch.ops.fbgemm.rope_qkv_varseq_prefill(
xq_out_bf16 = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
xq,
xk,
xv,
@@ -152,7 +152,7 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
dtype=torch.uint8,
device="cuda",
)
xq_out = torch.ops.fbgemm.rope_qkv_varseq_prefill(
xq_out = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
xq,
xk,
xv,
@@ -166,12 +166,13 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
)
torch.testing.assert_close(xq_out_bf16, xq_out)

cache_k, cache_v = torch.ops.fbgemm.dequantize_int4_cache(
dequantized_cache = torch.compile(torch.ops.fbgemm.dequantize_int4_cache)(
cache_k_int4,
cache_v_int4,
attn_bias.k_seqinfo.seqlen,
num_groups=num_groups,
)
cache_k, cache_v = dequantized_cache

torch.testing.assert_close(
cache_k[:, :T], cache_k_bf16[:, :T], atol=1.0e-2, rtol=1.0e-2
@@ -260,7 +261,7 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
)

xq_out_bf16 = torch.ops.fbgemm.rope_qkv_varseq_prefill(
xq_out_bf16 = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
xq,
xk,
xv,
@@ -282,7 +283,7 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
dtype=torch.uint8,
device="cuda",
)
xq_out = torch.ops.fbgemm.rope_qkv_varseq_prefill(
xq_out = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
xq,
xk,
xv,
@@ -295,11 +296,12 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
)
torch.testing.assert_close(xq_out_bf16, xq_out)

cache_k, cache_v = torch.ops.fbgemm.dequantize_fp8_cache(
dequantized_cache = torch.compile(torch.ops.fbgemm.dequantize_fp8_cache)(
cache_k_fp8,
cache_v_fp8,
attn_bias.k_seqinfo.seqlen,
)
cache_k, cache_v = dequantized_cache

torch.testing.assert_close(
cache_k[:, :T], cache_k_bf16[:, :T], atol=1.0e-2, rtol=5.0e-2
@@ -390,9 +392,9 @@ def test_positional_encoding_with_paged_attention(

if rope_theta is not None:
func = (
torch.ops.fbgemm.rope_qkv_varseq_prefill
torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)
if prefill
else torch.ops.fbgemm.rope_qkv_decoding
else torch.compile(torch.ops.fbgemm.rope_qkv_decoding)
)
xq_out_ref = func(
xq,
@@ -418,9 +420,9 @@ def test_positional_encoding_with_paged_attention(
)
else:
func = (
torch.ops.fbgemm.xpos_qkv_varseq_prefill
torch.compile(torch.ops.fbgemm.xpos_qkv_varseq_prefill)
if prefill
else torch.ops.fbgemm.xpos_qkv_decoding
else torch.compile(torch.ops.fbgemm.xpos_qkv_decoding)
)
xq_out_ref = func(
xq,
@@ -537,9 +539,9 @@ def test_rope_positional_encoding_only(
seqpos_args = (seq_positions,)

func = (
torch.ops.fbgemm.rope_qkv_varseq_prefill
torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)
if prefill
else torch.ops.fbgemm.rope_qkv_decoding
else torch.compile(torch.ops.fbgemm.rope_qkv_decoding)
)
xq_out = func(
xq,
