Commit
Address Segfault (ROCm#537)
This is a combination of 6 commits.

Compress

This is a combination of 3 commits.

Compress work

This is a combination of 3 commits.

ignore stuff

find segfault

save

save diff

save

add bias test

clean up

address feedback

fix bug

add assert
micmelesse authored Mar 21, 2024
1 parent 6c88111 commit 11e0175
Showing 1 changed file with 66 additions and 6 deletions.
72 changes: 66 additions & 6 deletions python/perf-kernels/flash-attention.py
@@ -375,8 +375,9 @@ def attn_fwd(
         order=(1, 0)
     )
     if BIAS_TYPE != 0:
+        b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs
         bias_ptr = tl.make_block_ptr(
-            base=bias + off_h_q * stride_bh,
+            base=bias + b_offset,
             shape=(seqlen_q, seqlen_k),
             strides=(stride_bm, stride_bn),
             offsets=(start_m * BLOCK_M, 0),
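
Why the temporary matters: off_h_q comes from the kernel's program id and stride_bh from the bias strides, and (as the commit's own NOTE below says) their product overflows during pointer arithmetic once the bias gets large. A rough sketch of the magnitudes involved, assuming a (1, H, seqlen_q, seqlen_k) bias layout; the shapes here are illustrative, not taken from the commit:

    # When does off_h_q * stride_bh exceed a signed 32-bit integer?
    # Assumes bias shape (1, H, seqlen_q, seqlen_k), so stride_bh == seqlen_q * seqlen_k.
    H, seqlen_q, seqlen_k = 48, 8192, 8192
    stride_bh = seqlen_q * seqlen_k      # 67,108,864 elements per head
    max_offset = (H - 1) * stride_bh     # 3,154,116,608 for the last head
    print(max_offset > 2 ** 31 - 1)      # True: the 32-bit offset wraps around
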
@@ -783,6 +784,10 @@ def _bwd_kernel_dq(
 class _attention(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, o, metadata):
+        # NOTE: a large bias tensor leads to overflow during pointer arithmetic
+        if (metadata.bias is not None):
+            assert(metadata.bias.numel() < 2 ** 31)
+
         if o is None:
             o = torch.empty_like(q, dtype=v.dtype)
         metadata.check_args(q, k, v, o)
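
The assert turns the silent wraparound into an explicit failure: any bias with 2**31 or more elements is rejected before the kernel launches. A minimal illustration of the boundary, with hypothetical shapes; the meta device avoids actually allocating the multi-GiB tensors:

    import torch

    ok_bias = torch.empty((1, 16, 8192, 8192), device="meta")   # 2**30 elements
    big_bias = torch.empty((1, 32, 8192, 8192), device="meta")  # 2**31 elements
    for bias in (ok_bias, big_bias):
        print(bias.numel(), bias.numel() < 2 ** 31)  # mirrors the new check in forward()
    # prints: 1073741824 True, then 2147483648 False
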
@@ -964,11 +969,66 @@ def backward(ctx, do, _):
                           (4, 4, 113, 123, 1),
                           ])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('use_bias', [False])
-def test_op_fwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
-    # TODO: using bias causes coredump for certain configs and must be fixed.
-    if use_bias:
-        pytest.skip()
+def test_op_fwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, dtype=torch.float16):
+    torch.manual_seed(20)
+    sm_scale = D_HEAD ** -0.5
+    input_metadata = MetaData(sm_scale=sm_scale)
+    input_metadata.max_seqlens_q = N_CTX_Q
+    input_metadata.max_seqlens_k = N_CTX_K
+    if causal:
+        input_metadata.need_causal()
+
+    q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    if TORCH_HAS_FP8E5:
+        q = q.to(torch_dtype)
+        k = k.to(torch_dtype)
+    o = torch.empty_like(q)
+
+    # triton implementation
+    tri_out, _ = attention(q, k, v, o, input_metadata)
+    # reference implementation
+
+    scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale
+    if causal:
+        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
+                          diagonal=N_CTX_K - N_CTX_Q)
+        scores[:, :, mask == 0] = float("-inf")
+
+    p = torch.softmax(scores, dim=-1)
+    if causal:
+        # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into
+        # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix
+        # this by converting the NaNs to 0s, which is what they should be out of the softmax.
+        nan_mask = torch.isnan(p)
+        p[nan_mask == 1] = 0
+    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
+    # compare
+    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
+
+
+@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD',
+                         [(4, 48, 1024, 1024, 64),
+                          (4, 24, 8192, 8192, 64),
+                          (2, 4, 16384, 16384, 128),
+                          (2, 16, 1020, 987, 128),
+                          (2, 16, 15498, 2, 128),
+                          (2, 16, 7, 16219, 64),
+                          (4, 48, 1, 1, 64),
+                          (4, 48, 1, 1, 128),
+                          (4, 48, 3, 3, 128),
+                          (4, 48, 1001, 990, 64),
+                          (1, 8, 8081, 7099, 64),
+                          (1, 8, 16330, 15989, 128),
+                          (4, 4, 1024, 1024, 33),
+                          (4, 4, 65, 1019, 65),
+                          (4, 4, 128, 128, 65),
+                          (4, 4, 113, 123, 1),
+                          ])
+@pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('use_bias', [True])
+def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
     torch.manual_seed(20)
     sm_scale = D_HEAD ** -0.5
     input_metadata = MetaData(sm_scale=sm_scale)
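
One detail of the rewritten test_op_fwd is worth isolating: with N_CTX_Q > N_CTX_K, tril with a negative diagonal fully masks the first N_CTX_Q - N_CTX_K query rows, so every score in those rows is -inf and the softmax returns NaN (0/0); the test then zeroes those rows, as its own comment explains. A standalone reproduction of that behavior:

    import torch

    row = torch.full((1, 4), float("-inf"))  # a fully masked query row
    p = torch.softmax(row, dim=-1)
    print(p)                   # tensor([[nan, nan, nan, nan]]) -- exp(-inf) sums to 0
    p[torch.isnan(p)] = 0.0    # the test's fix: such rows attend to nothing
    print(p)                   # tensor([[0., 0., 0., 0.]])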
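
The body of test_op_fwd_bias is truncated in this view. For orientation only: a reference check for an additive attention bias generally folds the bias into the scores before the softmax. A sketch under that assumption, with a (1, H, N_CTX_Q, N_CTX_K) bias broadcast over the batch; the bias layout is an assumption, not taken from the commit:

    import torch

    Z, H, N_CTX_Q, N_CTX_K, D_HEAD = 4, 48, 1024, 1024, 64
    sm_scale = D_HEAD ** -0.5
    q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=torch.float16, device="cuda")
    k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=torch.float16, device="cuda")
    v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=torch.float16, device="cuda")
    bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float16, device="cuda")

    # reference: add the bias to the raw scores, then softmax as usual
    scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + bias.float()
    p = torch.softmax(scores, dim=-1)
    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)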
