Skip to content

Commit

Permalink
[Inductor][FlexAttention] Correct partial/full blocks naming (pytorch#131993)
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#131993
Approved by: https://github.com/drisspg
  • Loading branch information
yanboliang authored and pytorchmergebot committed Jul 30, 2024
1 parent 03e0581 commit 54d4f6b
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions torch/nn/attention/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,19 +538,19 @@ def _create_sparse_block_from_block_mask(
KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
) -> BlockMask:
full_blocks, partial_blocks = block_mask
partial_blocks, full_blocks = block_mask

full_bm = _dense_to_ordered(full_blocks)
if partial_blocks is not None:
partial_bm = _dense_to_ordered(partial_blocks)
partial_bm = _dense_to_ordered(partial_blocks)
if full_blocks is not None:
full_bm = _dense_to_ordered(full_blocks)
else:
partial_bm = (None, None)
full_bm = (None, None)

return BlockMask( # type: ignore[call-arg]
full_bm[0],
full_bm[1],
partial_bm[0],
partial_bm[1],
full_bm[0],
full_bm[1],
BLOCK_SIZE=(KV_BLOCK_SIZE, Q_BLOCK_SIZE),
mask_mod=mask_mod,
)
Expand Down Expand Up @@ -622,14 +622,14 @@ def _create_block_mask_inner(
with the __torch_function__ mode.
"""
mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
full_block_mask, partial_block_mask = _convert_mask_to_block_mask(
partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
mask_tensor,
KV_BLOCK_SIZE=KV_BLOCK_SIZE,
Q_BLOCK_SIZE=Q_BLOCK_SIZE,
separate_full_blocks=True,
)
return _create_sparse_block_from_block_mask(
(full_block_mask, partial_block_mask), mask_mod
(partial_block_mask, full_block_mask), mask_mod
)


Expand Down

0 comments on commit 54d4f6b

Please sign in to comment.