[TUTORIALS] Minor fix for tutorial 06 (triton-lang#3986)
Based on
`@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])`
and
`BATCH, N_HEADS, HEAD_DIM = 4, 32, 64`
HEAD_DIM is `64` in both the pytest parametrization and the benchmark, which triggers an
assertion failure from `tl.static_assert(BLOCK_N <= HEAD_DIM)`, since
BLOCK_N can be `128` during tuning.
Therefore, change the tuning values for `BN` to `[32, 64]`.

Both the pytest and the benchmark run fine on GPU after the fix.
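For illustration, here is a minimal sketch (not part of this commit) of how the tuning space interacts with the kernel's compile-time check. The config list mirrors the tutorial's; the plain-Python assertion at the end only restates the `BLOCK_N <= HEAD_DIM` constraint that `tl.static_assert` enforces when each candidate config is compiled during tuning.

```python
# Illustrative sketch only: mirrors the tutorial's autotune config space after the fix.
# With HEAD_DIM = 64, any config whose BLOCK_N exceeds HEAD_DIM would fail
# tl.static_assert(BLOCK_N <= HEAD_DIM) when the kernel is compiled during tuning.
import triton

HEAD_DIM = 64  # value used by both the pytest parametrization and the benchmark

configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
    for BM in [64, 128]
    for BN in [32, 64]  # previously [64, 128]; BN = 128 > HEAD_DIM trips the assert
    for s in [3, 4, 7]  # the tutorial uses [1] on HIP via its is_hip() helper
    for w in [4, 8]
]

# Plain-Python restatement of the kernel's compile-time constraint
# (cfg.kwargs holds the meta-parameter dict passed to triton.Config):
assert all(cfg.kwargs['BLOCK_N'] <= HEAD_DIM for cfg in configs)
```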
lancerts authored May 29, 2024
1 parent 284f292 commit a07826d
Showing 1 changed file with 6 additions and 6 deletions.
python/tutorials/06-fused-attention.py: 6 additions & 6 deletions
@@ -75,13 +75,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
    return acc, l_i, m_i


-# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
+# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64, 128]\
-    for BN in [64, 128]\
+    for BN in [32, 64]\
    for s in ([1] if is_hip() else [3, 4, 7])\
    for w in [4, 8]\
]
@@ -551,7 +551,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
    rtol = 0.0
    # Relative tolerance workaround for known hardware limitation of MI200 GPU.
-    # For detailss see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
        rtol = 1e-2
    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
@@ -603,9 +603,9 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev
    rep = 100
    dtype = torch.float16
    if "triton" in provider:
-        q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda", requires_grad=True)
-        k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda", requires_grad=True)
-        v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda", requires_grad=True)
+        q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        if mode == "fwd" and "fp8" in provider:
            q = q.to(torch.float8_e5m2)
            k = k.to(torch.float8_e5m2)
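The benchmark hunk also swaps the hard-coded `"cuda"` strings for the `device` argument in `bench_flash_attention`'s signature. A minimal sketch (not part of the commit) of what that parameterization buys, with shapes taken from the commit message and an arbitrary `N_CTX` chosen here for illustration:

```python
# Illustrative sketch only (not the tutorial's code): allocate benchmark inputs on a
# caller-chosen device instead of a hard-coded "cuda" string. torch.cuda.is_available()
# also returns True on ROCm builds of PyTorch, so the same path covers AMD GPUs.
import torch

BATCH, N_HEADS, HEAD_DIM = 4, 32, 64  # values from the commit message
N_CTX = 1024                          # arbitrary example sequence length
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.randn((BATCH, N_HEADS, N_CTX, HEAD_DIM), dtype=torch.float16, device=device, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)  # randn_like reuses q's dtype and device
v = torch.randn_like(q, requires_grad=True)
```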
