support flash attention with sparse mask (PaddlePaddle#62029)
* add flash attention with sparse mask

* fix doc

* Update python/paddle/nn/functional/flash_attention.py

* Update python/paddle/nn/functional/flash_attention.py

* Update python/paddle/nn/functional/flash_attention.py

* Update python/paddle/nn/functional/flash_attention.py

* fix docstring

---------

Co-authored-by: zachary sun <[email protected]>
Co-authored-by: zachary sun <[email protected]>
3 people authored Mar 28, 2024
1 parent c178bda commit e05764a
Showing 11 changed files with 722 additions and 116 deletions.
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -859,6 +859,17 @@
func : flash_attn_unpadded_grad
data_type: q

- backward_op : flash_attn_with_sparse_mask_grad
forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_with_sparse_mask_grad
data_type: q

- backward_op : flatten_grad
forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -1055,6 +1055,18 @@
intermediate : softmax_lse, seed_offset
backward : flash_attn_unpadded_grad

- op : flash_attn_with_sparse_mask
args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
kernel :
func : flash_attn_with_sparse_mask
data_type : q
backward : flash_attn_with_sparse_mask_grad

- op : flatten
args : (Tensor x, int start_axis = 1, int stop_axis = 1)
output : Tensor(out), Tensor(xshape)
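For illustration only (not part of this diff): a minimal call sketch against the op defined above. The argument order follows the ops.yaml entry, and the eager binding name paddle._C_ops.flash_attn_with_sparse_mask is assumed from the usual code-generation convention; the supported entry point is the Python wrapper this commit adds to python/paddle/nn/functional/flash_attention.py, so consult that file for the public signature. A GPU build with FlashAttention enabled and float16/bfloat16 inputs are required.

import numpy as np
import paddle

# Sketch only: the binding name and the semantics of the start-row indices are
# assumptions inferred from the ops.yaml entry above.
batch, seqlen, heads, head_dim = 2, 1024, 8, 128
q = paddle.randn([batch, seqlen, heads, head_dim]).astype("float16")
k = paddle.randn([batch, seqlen, heads, head_dim]).astype("float16")
v = paddle.randn([batch, seqlen, heads, head_dim]).astype("float16")

# One int32 start row per (batch, head, key column); the assumed meaning is that
# query rows at or beyond this index no longer attend to that key column.
start_rows = np.full([batch, heads, seqlen], seqlen, dtype="int32")
start_rows[:, :, seqlen // 2:] = seqlen // 2
start_rows = paddle.to_tensor(start_rows)

out, softmax, softmax_lse, seed_offset = paddle._C_ops.flash_attn_with_sparse_mask(
    q, k, v, start_rows,
    None,          # fixed_seed_offset (optional)
    0.0,           # dropout
    True,          # causal
    seqlen // 2,   # attn_mask_start_row
    False,         # return_softmax
    False,         # is_test
    "")            # rng_name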
18 changes: 18 additions & 0 deletions paddle/phi/kernels/flash_attn_grad_kernel.h
@@ -56,4 +56,22 @@ void FlashAttnGradKernel(const Context& ctx,
DenseTensor* dk,
DenseTensor* dv);

template <typename T, typename Context>
void FlashAttnWithSparseMaskGradKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& attn_mask_start_row_indices,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
int attn_mask_start_row,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);

} // namespace phi
19 changes: 19 additions & 0 deletions paddle/phi/kernels/flash_attn_kernel.h
@@ -59,4 +59,23 @@ void FlashAttnKernel(const Context& ctx,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);

template <typename T, typename Context>
void FlashAttnWithSparseMaskKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& attn_mask_start_row_indices,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
int attn_mask_start_row,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);

} // namespace phi
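To make the new attn_mask_start_row_indices argument concrete, the NumPy sketch below (not part of this diff) expands the sparse encoding into the dense additive mask it is assumed to represent: index [b, h, j] names the first query row from which key column j is masked, and attn_mask_start_row is presumably the smallest such index, letting the kernel skip fully visible row blocks. This reading is inferred from the parameter names; the docstring updated in python/paddle/nn/functional/flash_attention.py is authoritative.

import numpy as np

def dense_mask_from_start_rows(start_rows: np.ndarray, seq_len: int) -> np.ndarray:
    """Expand [batch, heads, seq_len] start-row indices into an additive
    [batch, heads, seq_len, seq_len] mask (0 = visible, -inf = masked),
    under the assumed semantics described above."""
    rows = np.arange(seq_len).reshape(1, 1, seq_len, 1)   # query row i
    masked = rows >= start_rows[:, :, None, :]            # i >= start_row[j]
    return np.where(masked, -np.inf, 0.0).astype(np.float32)

# Example: the second half of the key columns becomes invisible from row 4 onward.
seq_len = 8
start = np.full([1, 1, seq_len], seq_len, dtype=np.int32)
start[:, :, seq_len // 2:] = 4
print(dense_mask_from_start_rows(start, seq_len)[0, 0])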
121 changes: 105 additions & 16 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -119,8 +119,10 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
dropout,
scale,
causal,
0, // attn_mask_start_row
q.dtype(),
attn_mask,
nullptr, // attn_mask_start_row_indices
seed_offset.data<int64_t>());

VLOG(10) << "FlashAttn bwd seed: " << params.seed
@@ -174,22 +176,24 @@
RaiseNotSupportedError();
#endif
}

template <typename T, typename Context>
- void FlashAttnGradKernel(const Context& ctx,
- const DenseTensor& q,
- const DenseTensor& k,
- const DenseTensor& v,
- const DenseTensor& out,
- const DenseTensor& softmax_lse,
- const DenseTensor& seed_offset,
- const paddle::optional<DenseTensor>& attn_mask,
- const DenseTensor& dout,
- float dropout,
- bool causal,
- DenseTensor* dq,
- DenseTensor* dk,
- DenseTensor* dv) {
+ void FlashAttnGradBaseKernel(
+ const Context& ctx,
+ const DenseTensor& q,
+ const DenseTensor& k,
+ const DenseTensor& v,
+ const DenseTensor& out,
+ const DenseTensor& softmax_lse,
+ const DenseTensor& seed_offset,
+ const paddle::optional<DenseTensor>& attn_mask,
+ const paddle::optional<DenseTensor>& attn_mask_start_row_indices,
+ const DenseTensor& dout,
+ float dropout,
+ bool causal,
+ int attn_mask_start_row,
+ DenseTensor* dq,
+ DenseTensor* dk,
+ DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
// q, k, v [batch_size, seq_len, num_heads, head_dim]
const auto& dims = q.dims();
@@ -259,8 +263,10 @@ void FlashAttnGradKernel(const Context& ctx,
dropout,
softmax_scale,
causal,
attn_mask_start_row,
q.dtype(),
attn_mask,
attn_mask_start_row_indices,
seed_offset.data<int64_t>());

VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.dims() << "], k.shape=["
@@ -308,7 +314,14 @@ void FlashAttnGradKernel(const Context& ctx,
params.seed,
params.offset,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
- params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
+ params.attn_mask_tensor ? params.mask_dims.data() : nullptr,
+ params.attn_mask_start_row_indices_tensor
+ ? params.attn_mask_start_row_indices_tensor->data()
+ : nullptr,
+ params.attn_mask_start_row_indices_tensor
+ ? params.attn_mask_start_row_indices_dims.data()
+ : nullptr,
+ params.attn_mask_start_row);
CheckFlashAttnStatus(succ);
if (!is_mha) {
if (dk) {
@@ -323,6 +336,73 @@ void FlashAttnGradKernel(const Context& ctx,
#endif
}

template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
FlashAttnGradBaseKernel<T, Context>(ctx,
q,
k,
v,
out,
softmax_lse,
seed_offset,
attn_mask,
paddle::none,
dout,
dropout,
causal,
0,
dq,
dk,
dv);
}

template <typename T, typename Context>
void FlashAttnWithSparseMaskGradKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& attn_mask_start_row_indices,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
int attn_mask_start_row,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
FlashAttnGradBaseKernel<T, Context>(ctx,
q,
k,
v,
out,
softmax_lse,
seed_offset,
paddle::none,
attn_mask_start_row_indices,
dout,
dropout,
causal,
attn_mask_start_row,
dq,
dk,
dv);
}
} // namespace phi

PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
@@ -342,3 +422,12 @@ PD_REGISTER_KERNEL(flash_attn_grad,
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}

PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnWithSparseMaskGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
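As a quick numerical sanity check (illustration only, reusing the helpers and assumptions from the two sketches above): with dropout disabled and causal=False, the fused kernel should agree with a naive O(n^2) attention computed under the equivalent dense additive mask.

import numpy as np

def reference_attention(q, k, v, dense_mask):
    """Naive float32 attention for comparison; q, k, v are numpy arrays shaped
    [batch, seq_len, heads, head_dim] and dense_mask is an additive
    [batch, heads, seq_len, seq_len] mask of 0 / -inf entries."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("bihd,bjhd->bhij", q, k) * scale + dense_mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhij,bjhd->bihd", probs, v)

# ref = reference_attention(q.astype("float32").numpy(), k.astype("float32").numpy(),
#                           v.astype("float32").numpy(),
#                           dense_mask_from_start_rows(start_rows.numpy(), seqlen))
# np.testing.assert_allclose(out.astype("float32").numpy(), ref, rtol=2e-2, atol=2e-2)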
Diffs for the remaining 6 changed files are not shown here.