From 33731608639476f2cf3d3555975e13f8549b8554 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Wed, 7 Jun 2023 14:41:14 +0800 Subject: [PATCH] [CPU EP] Refactor CPU mha (#16247) Followup of #16075 --- .../cpu/bert/multihead_attention.cc | 100 ++++++------------ 1 file changed, 34 insertions(+), 66 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index b025bf4f57bc5..0b55cb7804c61 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -225,6 +225,33 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is return Status::OK(); } +template +Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) { + auto element_type = DataTypeImpl::GetType(); + std::vector new_dims({batch_size, num_heads, sequence_length, head_size}); + gsl::span new_dims_span{new_dims}; + TensorShape v_BNLH(new_dims_span); + Tensor::InitOrtValue(element_type, v_BNLH, allocator, out); + if (bias == nullptr) { + std::unique_ptr reshaped; + if (in->Shape().GetDims().size() == 3) { + reshaped = std::make_unique(in->DataType(), in->Shape(), const_cast(in->DataRaw()), in->Location()); + ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size)); + } + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out)); + } else { + const auto* qkv_bias = bias->Data(); + if (sequence_length == 1) { + ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context)); + } else { + ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context)); + } + } + return Status::OK(); +}; + template Status MultiHeadAttention::Compute(OpKernelContext* context) const { const Tensor* query = context->Input(0); @@ -277,8 +304,6 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { output_shape[2] = static_cast(parameters.v_hidden_size); Tensor* output = context->Output(0, output_shape); - auto element_type = DataTypeImpl::GetType(); - const auto* qkv_bias = (bias == nullptr) ? nullptr : bias->Data(); constexpr int q_bias_offset = 0; const int k_bias_offset = qk_hidden_size; const int v_bias_offset = 2 * qk_hidden_size; @@ -292,6 +317,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { Tensor* present_v = context->Output(2, present_v_shape); AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); // For each of Q/K/V, there are multiple scenarios: // 1) Combined QKV bias is null @@ -302,27 +328,8 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { // b) Q/K/V has seq_len > 1 OrtValue Q; - { - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - std::vector new_dims({batch_size, num_heads_, q_sequence_length, qk_head_size}); - gsl::span new_dims_span{new_dims}; - TensorShape q_BNSH(new_dims_span); - Tensor::InitOrtValue(element_type, q_BNSH, allocator, Q); - if (qkv_bias == nullptr) { - std::unique_ptr query_reshaped; - if (query->Shape().GetDims().size() == 3) { - query_reshaped = std::make_unique(query->DataType(), query->Shape(), const_cast(query->DataRaw()), query->Location()); - ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(query_reshaped.get(), batch_size, q_sequence_length, num_heads_, qk_head_size)); - } - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((query_reshaped == nullptr) ? query : query_reshaped.get(), Q)); - } else { - if (q_sequence_length == 1) { - ORT_RETURN_IF_ERROR(AddBiasReshape(query, qkv_bias, Q, q_bias_offset, batch_size, q_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context)); - } else { - ORT_RETURN_IF_ERROR(AddBiasTranspose(query, qkv_bias, Q, q_bias_offset, batch_size, q_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context)); - } - } - } + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q)); if (kv_BNSH) { // No bias add needed for K/V, key already of shape BxNxLxH, value already of shape BxNxLxH_v @@ -334,50 +341,11 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { } OrtValue K; - { - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - std::vector new_dims({batch_size, num_heads_, kv_sequence_length, qk_head_size}); - gsl::span new_dims_span{new_dims}; - TensorShape k_BNLH(new_dims_span); - Tensor::InitOrtValue(element_type, k_BNLH, allocator, K); - if (qkv_bias == nullptr) { - std::unique_ptr key_reshaped; - if (key->Shape().GetDims().size() == 3) { - key_reshaped = std::make_unique(key->DataType(), key->Shape(), const_cast(key->DataRaw()), key->Location()); - ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(key_reshaped.get(), batch_size, kv_sequence_length, num_heads_, qk_head_size)); - } - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((key_reshaped == nullptr) ? key : key_reshaped.get(), K)); - } else { - if (kv_sequence_length == 1) { - ORT_RETURN_IF_ERROR(AddBiasReshape(key, qkv_bias, K, k_bias_offset, batch_size, kv_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context)); - } else { - ORT_RETURN_IF_ERROR(AddBiasTranspose(key, qkv_bias, K, k_bias_offset, batch_size, kv_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context)); - } - } - } - OrtValue V; - { - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - std::vector new_dims({batch_size, num_heads_, kv_sequence_length, v_head_size}); - gsl::span new_dims_span{new_dims}; - TensorShape v_BNLH(new_dims_span); - Tensor::InitOrtValue(element_type, v_BNLH, allocator, V); - if (qkv_bias == nullptr) { - std::unique_ptr value_reshaped; - if (value->Shape().GetDims().size() == 3) { - value_reshaped = std::make_unique(value->DataType(), value->Shape(), const_cast(value->DataRaw()), value->Location()); - ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(value_reshaped.get(), batch_size, kv_sequence_length, num_heads_, v_head_size)); - } - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((value_reshaped == nullptr) ? value : value_reshaped.get(), V)); - } else { - if (kv_sequence_length == 1) { - ORT_RETURN_IF_ERROR(AddBiasReshape(value, qkv_bias, V, v_bias_offset, batch_size, kv_sequence_length, num_heads_, v_head_size, v_hidden_size, context)); - } else { - ORT_RETURN_IF_ERROR(AddBiasTranspose(value, qkv_bias, V, v_bias_offset, batch_size, kv_sequence_length, num_heads_, v_head_size, v_hidden_size, context)); - } - } - } + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, kv_sequence_length, qk_head_size, key, bias, k_bias_offset, K)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V)); // Compute the attention score and apply the score to V return ApplyAttention(Q.GetMutable()->MutableData(), K.GetMutable()->MutableData(), V.GetMutable()->MutableData(),