Skip to content

Commit

Permalink
fMHA: Add backward pass (#844)
Browse files Browse the repository at this point in the history
* fMHA: Add backward pass

* Better checks for strides/alignments

* Remove fb-internal URL

* torch.Tensor.untyped_storage requires pytorch 2.0+

* minor changes

* make test

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <[email protected]>
  • Loading branch information
danthe3rd and hwu36 authored Apr 7, 2023
1 parent e2d439e commit 9b8166e
Show file tree
Hide file tree
Showing 9 changed files with 2,931 additions and 17 deletions.
12 changes: 12 additions & 0 deletions examples/41_fused_multi_head_attention/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,20 @@ cutlass_example_add_executable(
fused_multihead_attention_variable_seqlen.cu
)

cutlass_example_add_executable(
41_fused_multi_head_attention_backward
fused_multi_head_attention_backward.cu
DISABLE_TESTS ON
)


add_custom_target(41_fused_multi_head_attention
DEPENDS 41_fused_multi_head_attention_fixed_seqlen
41_fused_multi_head_attention_variable_seqlen
41_fused_multi_head_attention_backward
)

add_test(
NAME ctest_examples_41_fmha_backward_python
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/fmha_backward_test.py $<TARGET_FILE:41_fused_multi_head_attention_backward>
)
199 changes: 199 additions & 0 deletions examples/41_fused_multi_head_attention/fmha_backward_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import argparse
import torch
import sys
import os
from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME
import math


parser = argparse.ArgumentParser()
parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable")
args = parser.parse_args()

torch.manual_seed(0)
dtype = torch.float16
B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128
causal = True
repeat_count = 100

ATOL = {
torch.float: 5e-4,
torch.half: 9.5e-2,
torch.bfloat16: 7e-1,
}[dtype]

RTOL = {
torch.float: 1e-4,
torch.half: 2e-2,
torch.bfloat16: 1e-1,
}[dtype]


assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ"

fmha_bw_binary = args.example_exe
if not os.path.isfile(fmha_bw_binary):
print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""")
sys.exit(1)

def create_lower_triangular_mask():
return torch.triu(torch.full( # type: ignore
[1, Mq, Mkv],
dtype=dtype,
fill_value=float("-inf"),
), diagonal=1)

def ref_mha_bmk(q, k, v, mask):
# Multi-head attention with inputs/outputs in BMK format
q = q.float()
k = k.float()
v = v.float()

q = q * (1 / q.shape[-1] ** 0.5)
attn = q @ k.transpose(-2, -1)
if mask is not None:
attn += mask
attn_max = attn.max(-1, True).values
attn_norm = (attn - attn_max).exp().sum(-1, True)
attn = attn.softmax(-1)
lse = attn_max + attn_norm.log()
lse = lse.squeeze(2)
return attn @ v, lse


def bmhk2bmk(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)

def ref_mha_bmhk(q, k, v, mask):
# Multi-head attention with inputs/outputs in BMHK format
assert q.ndim == 4

out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]])

def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta):
lse = lse[:, :, :q.shape[1]] #BMH, unpad Q dimension
delta = delta.reshape([-1, delta.shape[-1], 1])

# bmhk -> bmk
q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)]

attn_T = k @ q.transpose(-2, -1)
if mask is not None:
attn_T += mask.transpose(-2, -1)
attn_T = attn_T * (1 / q.shape[-1] ** 0.5)
attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]])
attn_T = attn_T.exp()

grad_v = attn_T @ grad_out

dov = grad_out @ v.transpose(-2, -1)
tmp = (dov - delta) * attn_T.transpose(-2, -1)
tmp = tmp / (q.shape[-1] ** 0.5)

grad_q = tmp @ k
grad_k = tmp.transpose(-2, -1) @ q

return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]]


print("initializing tensors...")
query = torch.randn([B, Mq, H, K], dtype=dtype)
key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype)
value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype)
mask = create_lower_triangular_mask() if causal else None

# let PyTorch compute gradients
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)

print("computing fw...")
out, lse = ref_mha_bmhk(query, key, value, mask=mask)
out = out.to(dtype).contiguous()
grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype)

print("computing bw with autograd...")
out.backward(grad_out)
scale = (1 / query.shape[-1] ** 0.5)


# Additional data needed by the kernel
delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous()
pad_amount = (32 - (lse.shape[2] % 32)) % 32
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)

print("computing bw with reference implem...")
gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta)

with PipedSubprocess(fmha_bw_binary) as bw_kernel:
# Send kernel arguments
bw_kernel.write(
TORCH_DTYPE_NAME[query.dtype],
"scale", scale,
"head_dim", K,
"head_dim_value", Kv,
"num_queries", Mq,
"num_keys", Mkv,
"num_heads", H,
"custom_mask_type", (1 if causal else 0),
"num_batches", B,
"repeat_count", repeat_count,
)
bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"])
bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"])
bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"])
bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"])
bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"])

if bw_kernel.read() != "OK":
print("Got unexpected output")
print(bw_kernel.subp.communicate()[0])
sys.exit(0)

# Read kernel output
gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float()
gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float()
gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float()
runtime_ms = float(bw_kernel.readNamed("runtime_ms"))

float_ops = B * H * sum([
# att = Q @ K.transpose
Mq * Mkv * K * 2,
# att @ dO
Mkv * Mq * Kv * 2,
# dov = dO @ V
Mq * Kv * Mkv * 2,
# dov @ K
Mq * K * Mkv * 2,
# dov @ Q
Mq * K * Mkv * 2,
])
if causal:
float_ops //= 2

print(f"""
Fused multi-head attention - backward
batch_size={B}
num_queries={Mq}
num_keys={Mkv}
num_heads={H}
head_dim={K}
head_dim_value={Kv}
Correctness:
grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()})
grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()})
grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()})
(atol={ATOL} / rtol={RTOL})
Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops)
""")

assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
Loading

0 comments on commit 9b8166e

Please sign in to comment.