[CPU EP] Refactor CPU mha (microsoft#16247)
Follow-up of microsoft#16075
cloudhan authored Jun 7, 2023
1 parent f013965 commit 3373160
Showing 1 changed file with 34 additions and 66 deletions.
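
The refactor collapses three nearly identical Q/K/V preparation blocks in MultiHeadAttention<T>::Compute into one templated helper, MaybeTransposeToBNSHAndAddBias. The helper dispatches on two things: whether the combined QKV bias exists (bias_offset selects the Q, K, or V slice of it, at 0, qk_hidden_size, and 2 * qk_hidden_size) and, when it does, whether the sequence length is 1. A minimal standalone sketch of that dispatch logic, with prints standing in for the real onnxruntime routines named in the diff below:

#include <cstdio>

// Mirrors the branch structure of MaybeTransposeToBNSHAndAddBias; the prints
// stand in for the real onnxruntime helpers.
void MaybeTransposeToBNSHAndAddBias(bool has_bias, int rank, int sequence_length) {
  if (!has_bias) {
    if (rank == 3) std::puts("Reshape_BSD_to_BSNH");  // (B, S, D) -> (B, S, N, H)
    std::puts("Transpose_BSNH_to_BNSH");              // (B, S, N, H) -> (B, N, S, H)
  } else if (sequence_length == 1) {
    std::puts("AddBiasReshape");    // with S == 1, BSNH and BNSH share a layout, so no transpose
  } else {
    std::puts("AddBiasTranspose");  // bias add fused with the BSNH -> BNSH transpose
  }
}

int main() {
  MaybeTransposeToBNSHAndAddBias(/*has_bias=*/false, /*rank=*/3, /*sequence_length=*/8);
  MaybeTransposeToBNSHAndAddBias(/*has_bias=*/true, /*rank=*/3, /*sequence_length=*/1);
  MaybeTransposeToBNSHAndAddBias(/*has_bias=*/true, /*rank=*/3, /*sequence_length=*/8);
}

The seq_len == 1 fast path works because (B, 1, N, H) and (B, N, 1, H) describe the same memory layout, so a reshape suffices where the general case needs a transpose.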
onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc (100 changes: 34 additions & 66 deletions)
@@ -225,6 +225,33 @@ Status AddBiasReshape(const Tensor* qkv,  // Input: Q/K/V data - query is
  return Status::OK();
}

template <typename T>
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
                                      int batch_size, int num_heads, int sequence_length, int head_size,
                                      const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) {
  auto element_type = DataTypeImpl::GetType<T>();
  std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
  gsl::span<const int64_t> new_dims_span{new_dims};
  TensorShape v_BNLH(new_dims_span);
  Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
  if (bias == nullptr) {
    std::unique_ptr<Tensor> reshaped;
    if (in->Shape().GetDims().size() == 3) {
      reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
      ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
    }
    ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
  } else {
    const auto* qkv_bias = bias->Data<T>();
    if (sequence_length == 1) {
      ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
    } else {
      ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
    }
  }
  return Status::OK();
};

template <typename T>
Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
  const Tensor* query = context->Input<Tensor>(0);
@@ -277,8 +304,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
  output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
  Tensor* output = context->Output(0, output_shape);

  auto element_type = DataTypeImpl::GetType<T>();
  const auto* qkv_bias = (bias == nullptr) ? nullptr : bias->Data<T>();
  constexpr int q_bias_offset = 0;
  const int k_bias_offset = qk_hidden_size;
  const int v_bias_offset = 2 * qk_hidden_size;
@@ -292,6 +317,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
  Tensor* present_v = context->Output(2, present_v_shape);

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

  // For each of Q/K/V, there are multiple scenarios:
  // 1) Combined QKV bias is null
@@ -302,27 +328,8 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
  //    b) Q/K/V has seq_len > 1

  OrtValue Q;
  {
    ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
    std::vector<int64_t> new_dims({batch_size, num_heads_, q_sequence_length, qk_head_size});
    gsl::span<const int64_t> new_dims_span{new_dims};
    TensorShape q_BNSH(new_dims_span);
    Tensor::InitOrtValue(element_type, q_BNSH, allocator, Q);
    if (qkv_bias == nullptr) {
      std::unique_ptr<Tensor> query_reshaped;
      if (query->Shape().GetDims().size() == 3) {
        query_reshaped = std::make_unique<Tensor>(query->DataType(), query->Shape(), const_cast<void*>(query->DataRaw()), query->Location());
        ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(query_reshaped.get(), batch_size, q_sequence_length, num_heads_, qk_head_size));
      }
      ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((query_reshaped == nullptr) ? query : query_reshaped.get(), Q));
    } else {
      if (q_sequence_length == 1) {
        ORT_RETURN_IF_ERROR(AddBiasReshape(query, qkv_bias, Q, q_bias_offset, batch_size, q_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context));
      } else {
        ORT_RETURN_IF_ERROR(AddBiasTranspose(query, qkv_bias, Q, q_bias_offset, batch_size, q_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context));
      }
    }
  }
  ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
      context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q));

  if (kv_BNSH) {
    // No bias add needed for K/V, key already of shape BxNxLxH, value already of shape BxNxLxH_v
@@ -334,50 +341,11 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
  }

  OrtValue K;
  {
    ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
    std::vector<int64_t> new_dims({batch_size, num_heads_, kv_sequence_length, qk_head_size});
    gsl::span<const int64_t> new_dims_span{new_dims};
    TensorShape k_BNLH(new_dims_span);
    Tensor::InitOrtValue(element_type, k_BNLH, allocator, K);
    if (qkv_bias == nullptr) {
      std::unique_ptr<Tensor> key_reshaped;
      if (key->Shape().GetDims().size() == 3) {
        key_reshaped = std::make_unique<Tensor>(key->DataType(), key->Shape(), const_cast<void*>(key->DataRaw()), key->Location());
        ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(key_reshaped.get(), batch_size, kv_sequence_length, num_heads_, qk_head_size));
      }
      ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((key_reshaped == nullptr) ? key : key_reshaped.get(), K));
    } else {
      if (kv_sequence_length == 1) {
        ORT_RETURN_IF_ERROR(AddBiasReshape(key, qkv_bias, K, k_bias_offset, batch_size, kv_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context));
      } else {
        ORT_RETURN_IF_ERROR(AddBiasTranspose(key, qkv_bias, K, k_bias_offset, batch_size, kv_sequence_length, num_heads_, qk_head_size, qk_hidden_size, context));
      }
    }
  }

  OrtValue V;
  {
    ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
    std::vector<int64_t> new_dims({batch_size, num_heads_, kv_sequence_length, v_head_size});
    gsl::span<const int64_t> new_dims_span{new_dims};
    TensorShape v_BNLH(new_dims_span);
    Tensor::InitOrtValue(element_type, v_BNLH, allocator, V);
    if (qkv_bias == nullptr) {
      std::unique_ptr<Tensor> value_reshaped;
      if (value->Shape().GetDims().size() == 3) {
        value_reshaped = std::make_unique<Tensor>(value->DataType(), value->Shape(), const_cast<void*>(value->DataRaw()), value->Location());
        ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(value_reshaped.get(), batch_size, kv_sequence_length, num_heads_, v_head_size));
      }
      ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((value_reshaped == nullptr) ? value : value_reshaped.get(), V));
    } else {
      if (kv_sequence_length == 1) {
        ORT_RETURN_IF_ERROR(AddBiasReshape(value, qkv_bias, V, v_bias_offset, batch_size, kv_sequence_length, num_heads_, v_head_size, v_hidden_size, context));
      } else {
        ORT_RETURN_IF_ERROR(AddBiasTranspose(value, qkv_bias, V, v_bias_offset, batch_size, kv_sequence_length, num_heads_, v_head_size, v_hidden_size, context));
      }
    }
  }
  ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
      context, allocator, batch_size, num_heads_, kv_sequence_length, qk_head_size, key, bias, k_bias_offset, K));
  ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
      context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));

  // Compute the attention score and apply the score to V
  return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(), K.GetMutable<Tensor>()->MutableData<T>(), V.GetMutable<Tensor>()->MutableData<T>(),
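Both the removed blocks and the new helper funnel every input into the B x N x S x H layout that the attention computation consumes. As a reference for what Transpose_BSNH_to_BNSH does, here is a minimal sketch on a raw row-major buffer; the layout semantics are assumed from the function name, and the real routine operates on onnxruntime Tensors rather than arrays:

#include <cstdio>
#include <vector>

int main() {
  const int B = 1, S = 2, N = 2, H = 3;
  std::vector<int> bsnh(B * S * N * H), bnsh(B * N * S * H);
  for (int i = 0; i < B * S * N * H; ++i) bsnh[i] = i;

  // out[b][n][s][h] = in[b][s][n][h]: after the transpose, each head's
  // sequence rows are contiguous, which is what the per-head GEMMs want.
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          bnsh[((b * N + n) * S + s) * H + h] = bsnh[((b * S + s) * N + n) * H + h];

  for (int v : bnsh) std::printf("%d ", v);  // prints: 0 1 2 6 7 8 3 4 5 9 10 11
  std::printf("\n");
}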

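ApplyAttention itself is untouched by this refactor; per the comment in the diff, it computes the attention scores and applies them to V. Conceptually, that is softmax(Q K^T / sqrt(head_size)) V per head in the standard formulation. A single-batch, single-head sketch of that math, as an illustration only (the real ApplyAttention also handles masks, past/present state, and threading):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // q_sequence_length, kv_sequence_length, head_size
  const int S = 2, L = 2, H = 2;
  std::vector<float> Q = {1, 0, 0, 1}, K = {1, 0, 0, 1}, V = {1, 2, 3, 4}, out(S * H);

  for (int s = 0; s < S; ++s) {
    // Scaled dot-product scores for query row s against every key row.
    std::vector<float> score(L);
    float max_score = -1e30f, sum = 0.0f;
    for (int l = 0; l < L; ++l) {
      float dot = 0.0f;
      for (int h = 0; h < H; ++h) dot += Q[s * H + h] * K[l * H + h];
      score[l] = dot / std::sqrt(static_cast<float>(H));
      max_score = std::max(max_score, score[l]);
    }
    // Numerically stable softmax over the scores.
    for (int l = 0; l < L; ++l) { score[l] = std::exp(score[l] - max_score); sum += score[l]; }
    // Apply the normalized scores to V.
    for (int h = 0; h < H; ++h) {
      float acc = 0.0f;
      for (int l = 0; l < L; ++l) acc += (score[l] / sum) * V[l * H + h];
      out[s * H + h] = acc;
    }
  }
  for (float v : out) std::printf("%.3f ", v);
  std::printf("\n");
}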