[inductor] Add FileCheck to flex attention epilogue test (pytorch#129343)

Pull Request resolved: pytorch#129343
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#128893, pytorch#129325
peterbell10 authored and pytorchmergebot committed Jul 2, 2024
1 parent 7955cd3 commit 45844e0
Showing 1 changed file with 14 additions and 1 deletion.
test/inductor/test_flex_attention.py: 14 additions & 1 deletion
@@ -930,7 +930,20 @@ def f(q, k, v):
 
 q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
 metrics.reset()
-f(q, k, v)
+_, code = run_and_get_code(f, q, k, v)
+# TODO: attention output is not being DCE'd
+fc = FileCheck()
+fc.check("buf0 = empty_strided_cuda((1, 1, 1)") # SPARSE_KV_NUM_BLKS
+fc.check("buf1 = empty_strided_cuda((1, 1, 1, 1)") # SPARSE_KV_IDX
+fc.check("buf4 = empty_strided_cuda") # logsumexp
+fc.check("buf5 = empty_strided_cuda") # attention output
+fc.check("buf7 = empty_strided_cuda") # cos(attention)
+fc.run(code[0])
+fc = FileCheck()
+fc.check_not("buf2 =") # Dead buffer
+fc.check_not("buf3 =") # Dead buffer
+fc.check_not("buf6 =") # Mutation-buffer, not allocated
+fc.run(code[0])
 accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
 num_accesses = 4 # q, k, v reads, one output.
 # TODO: Get rid of this fudge factor
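
For readers unfamiliar with the testing pattern above, here is a minimal, self-contained sketch of how run_and_get_code and FileCheck are typically combined to assert on Inductor-generated code. It is not part of this commit: the toy function g, the input shape, and the string being checked are illustrative assumptions (buffer names and allocation helpers in the generated wrapper vary across PyTorch versions), and it needs a CUDA device like the test itself.

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

@torch.compile
def g(x):
    # Toy stand-in for the flex attention epilogue under test.
    return torch.cos(x) + 1

x = torch.randn(1024, device="cuda")
# Compile and run g, capturing the Inductor-generated source as a list of strings.
_, code = run_and_get_code(g, x)

fc = FileCheck()
# Assumed pattern: Inductor's CUDA wrapper allocates output buffers via
# empty_strided_cuda; the exact buffer names (buf0, buf1, ...) are version-dependent.
fc.check("empty_strided_cuda")
fc.run(code[0])  # raises if the pattern is not found, in order, in the first generated source

The actual test strengthens this idea by checking which specific buffers are allocated (and, via check_not, which dead or mutation-only buffers are not), so a regression in epilogue fusion or dead-code elimination shows up as a FileCheck failure rather than only as a change in measured memory traffic.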
