Add SDPA backend test with KV cache
carmocca committed Oct 24, 2023
1 parent 50b1974 commit 4dd2e17
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/test_model.py
@@ -519,3 +519,51 @@ def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
    expected = SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 else SDPBackend.MATH
    with torch.backends.cuda.sdp_kernel(enable_flash=False):
        model(x)


@pytest.mark.skipif(not SUPPORTS_FUSED_ATTENTION, reason="Unsupported")
@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
@torch.inference_mode()
def test_sdpa_choice_kv_cache(config):
    from torch.backends.cuda import SDPBackend

    from lit_gpt import GPT

    torch.set_default_dtype(torch.float16)

    def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
        # `expected` is read from the enclosing test scope at call time
        choice = torch._fused_sdp_choice(q, k, v, mask, is_causal=True)
        assert choice == expected
        return original_fn(q, k, v, mask)

    config["n_layer"] = 1
    config = config_module.Config(**config)

    try:
        with torch.device("cuda"):
            model = GPT(config)
            model.max_seq_length = 1
            model.set_kv_cache(2)
            x = torch.randint(0, 10, (2, 1), dtype=torch.int32)
            input_pos = torch.tensor([0], dtype=torch.long)
    except torch.cuda.OutOfMemoryError:
        # best effort: only run if the GPU can fit the model
        pytest.xfail()

    for h in model.transformer.h:
        h.attn.scaled_dot_product_attention = partial(assert_sdpa_uses_flash, h.attn.scaled_dot_product_attention)

    if SUPPORTS_FLASH_ATTENTION:
        # flash attention does not support an attention mask
        expected = SDPBackend.MATH
        with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
            model(x, input_pos)

    if SUPPORTS_MEM_EFF_ATTENTION:
        expected = (
            SDPBackend.EFFICIENT_ATTENTION
            if config.head_size % 8 == 0 and config.n_query_groups != 1
            else SDPBackend.MATH
        )
        with torch.backends.cuda.sdp_kernel(enable_flash=False):
            model(x, input_pos)
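
For reference, the dispatch probe that the test wraps around each attention call can be reproduced standalone. The sketch below is a minimal illustration, assuming PyTorch 2.1 and a CUDA device that supports the memory-efficient kernel; torch._fused_sdp_choice is a private API whose signature may change between releases, and the tensor shapes here are arbitrary.

import torch
from torch.backends.cuda import SDPBackend, sdp_kernel

# (batch, n_heads, seq_len, head_size); head_size 64 is a multiple of 8,
# so the memory-efficient kernel is eligible once flash is disabled
q = k = v = torch.randn(2, 8, 16, 64, dtype=torch.float16, device="cuda")

with sdp_kernel(enable_flash=False):
    # ask the dispatcher which fused kernel it would pick for these inputs
    choice = torch._fused_sdp_choice(q, k, v, attn_mask=None, is_causal=True)
    assert choice == SDPBackend.EFFICIENT_ATTENTION

The test splices the same check into every transformer block via functools.partial, so each model(x, input_pos) call asserts the dispatcher's choice under the active sdp_kernel restriction.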
