forked from microsoft/onnxruntime
LLaMA Model Optimization (microsoft#18021)
### Description

This PR contains fusion-level and kernel-level optimizations for [Meta's LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/).

Some of the added optimizations include:
- SimplifiedLayerNorm changes
  - Fusions for multiple variants
- SkipSimplifiedLayerNorm changes
  - Kernel support for CPU
- Rotary embeddings (previously did not exist)
  - Fusions for multiple variants
  - CPU and CUDA kernels
  - Supports interleaving and non-interleaving in the same kernels
  - Optimized cache that requires half of its originally exported sizes
    - Reduced from `(max_sequence_length, head_size)` to `(max_sequence_length, head_size / 2)`
- Multi-head attention
  - Support for 2D and 3D attention masks
- Group query attention (for FP16 CUDA and INT4 CUDA)
  - Integration with flash attention v2 and past-present buffer sharing
  - Removes need for `attention_mask` input as it is supported in the kernel
- 4 bit quantization
  - `block_size` parameter is available for customizing
- Support the new changes for [Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
- Support combinations of the below variants (ex: export ORT version and run with Optimum)

Supported variants of LLaMA-2 include:
- [ORT version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama)
  - Produces one ONNX file that is already optimized (and quantized if requested)
  - Integrates with Optimum
- [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Already exported and available off-the-shelf
  - Faster versions of those models will be uploaded there soon
- [Hugging Face version](https://huggingface.co/meta-llama)
  - Models that end with `-hf`
  - Some older and current versions of [`transformers`](https://github.com/huggingface/transformers) and [`optimum`](https://github.com/huggingface/optimum) that export the model to ONNX differently
    - Note that while some older versions are supported, it is recommended to use the latest package versions.

### Usage

To use the optimizations, please see `README.md` for details. Please note the various `requirements.txt` files for the package versions recommended in order to use these changes.

To run the ORT transformer optimizer separately, run the script as follows:
```
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type gpt2 --num_heads <number of attention heads> --hidden_size <attention hidden size> --use_external_data_format --opt_level 0
```

### Motivation and Context

This PR helps the following issues:
- microsoft#14997
- microsoft#16254
- microsoft#17681
- microsoft#17925
- microsoft/onnxruntime-inference-examples#320

This PR uses changes from the following PRs:
- pytorch/pytorch#104468
- pytorch/pytorch#109759
- microsoft#17020
- microsoft#17674
- microsoft#17890
- microsoft#17920
- huggingface/transformers#26162
- huggingface/optimum#1257
- huggingface/optimum#1289
- huggingface/optimum#1462

### New TorchDynamo Exporter (experimental stage)

This PR uses changes from the following issues and PRs to begin supporting the [new TorchDynamo exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter):
- huggingface/transformers#26307
- pytorch/pytorch#104903
- pytorch/pytorch#105040
- microsoft/onnxscript#847
- microsoft/onnxscript#862
- microsoft/onnxscript#493
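The halved rotary cache mentioned in the description, `(max_sequence_length, head_size / 2)`, can be illustrated with a minimal sketch. It assumes the conventional rotary frequencies `theta_i = base^(-2i / head_size)` with `base = 10000`, which this commit page does not state. `RotaryCaches` and `BuildHalfCaches` are hypothetical names for illustration, not ONNX Runtime APIs. A half-width table is enough because the CPU kernel in this commit indexes the cache modulo `head_size / 2`, so both halves of a head read the same entry.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container (not an ONNX Runtime type): two row-major tables,
// each of shape (max_sequence_length, head_size / 2).
struct RotaryCaches {
  std::vector<float> cos;
  std::vector<float> sin;
};

// Hypothetical helper. Assumes theta_i = base^(-2i / head_size), the usual
// rotary-embedding frequency schedule; the commit itself does not specify it.
inline RotaryCaches BuildHalfCaches(int max_sequence_length, int head_size,
                                    float base = 10000.0f) {
  assert(max_sequence_length >= 0);
  assert(head_size > 0 && head_size % 2 == 0);
  const int half = head_size / 2;

  RotaryCaches caches;
  caches.cos.resize(static_cast<std::size_t>(max_sequence_length) * half);
  caches.sin.resize(caches.cos.size());

  for (int pos = 0; pos < max_sequence_length; ++pos) {
    for (int i = 0; i < half; ++i) {
      // One angle per (position, frequency) pair; no duplicated second half.
      const double theta = std::pow(static_cast<double>(base), -2.0 * i / head_size);
      const double angle = pos * theta;
      const std::size_t idx = static_cast<std::size_t>(pos) * half + i;
      caches.cos[idx] = static_cast<float>(std::cos(angle));
      caches.sin[idx] = static_cast<float>(std::sin(angle));
    }
  }
  return caches;
}
```

For example, `BuildHalfCaches(4, 8)` yields two tables of 4 x 4 entries each, where a full-width table would hold 4 x 8.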
1 parent 8a12b2c, commit 2a17d5c

Showing 49 changed files with 5,897 additions and 563 deletions.
@@ -0,0 +1,115 @@

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
    RotaryEmbedding,
    kMSDomain,
    1,
    float,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
    RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
  scale = info.GetAttrOrDefault<float>("scale", 1.0);
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* position_ids = context->Input<Tensor>(1);
  const Tensor* cos_cache = context->Input<Tensor>(2);
  const Tensor* sin_cache = context->Input<Tensor>(3);

  RotaryParameters parameters = {};
  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
                                                                   position_ids,
                                                                   cos_cache,
                                                                   sin_cache,
                                                                   &parameters));

  Tensor* output = context->Output(0, input->Shape());

  if (parameters.sequence_length > parameters.max_sequence_length) {
    // Launch update_cos_sin_cache kernel with scale
    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
  }

  const T* input_src = input->Data<T>();
  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
  const T* cos_cache_data = cos_cache->Data<T>();
  const T* sin_cache_data = sin_cache->Data<T>();
  T* output_dest = output->MutableData<T>();

  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int num_heads = parameters.num_heads;
  const int head_size = parameters.head_size;
  const int position_ids_format = parameters.position_ids_format;
  const int half_head_size = head_size / 2;

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
  auto* tp = context->GetOperatorThreadPool();

  const int loop_len = batch_size * sequence_length * num_heads;
  const double cost = static_cast<double>(head_size);
  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
      const int n = static_cast<int>(ptr % num_heads);

      const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
      const int data_offset = block_offset * head_size;

      const T* input_data = input_src + data_offset;
      T* output_data = output_dest + data_offset;

      // Cache is (M, H/2)
      const int position_id = (position_ids_format == 0)
                                  ? static_cast<int>(pos_ids_data[0]) + s
                                  : static_cast<int>(pos_ids_data[b * sequence_length + s]);
      const int cache_offset = position_id * half_head_size;
      const T* cos_data = cos_cache_data + cache_offset;
      const T* sin_data = sin_cache_data + cache_offset;

      int cache_idx = 0;
      T sign = 0;
      int j = 0;
      for (int i = 0; i < head_size; i++) {
        if (interleaved) {
          cache_idx = (i / 2) % half_head_size;
          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
        } else {
          cache_idx = i % half_head_size;
          sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i + half_head_size) % head_size;
        }
        output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
      }
    }
  });

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime
```
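The inner loop of the kernel above can be tried in isolation with the following sketch, which rotates a single head vector in either layout. `RotateHead` is a hypothetical standalone helper written for illustration; it mirrors the index and sign logic of the diff but uses `std::vector` instead of ONNX Runtime tensors, and it takes one cache row (`head_size / 2` entries) for `cos` and `sin`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of ONNX Runtime): applies the rotary update
// to one head of size head_size, using one half-width cache row.
inline std::vector<float> RotateHead(const std::vector<float>& x,
                                     const std::vector<float>& cos_row,
                                     const std::vector<float>& sin_row,
                                     bool interleaved) {
  const int head_size = static_cast<int>(x.size());
  const int half = head_size / 2;
  assert(head_size % 2 == 0);
  assert(static_cast<int>(cos_row.size()) == half);
  assert(static_cast<int>(sin_row.size()) == half);

  std::vector<float> out(x.size());
  for (int i = 0; i < head_size; ++i) {
    int cache_idx = 0;
    int j = 0;         // index of the element paired with i
    float sign = 0.f;  // -1 for the first element of a pair, +1 for the second
    if (interleaved) {
      // Pairs are adjacent: (0,1), (2,3), ...
      cache_idx = i / 2;
      sign = (i % 2 == 0) ? -1.0f : 1.0f;
      j = (i % 2 == 0) ? i + 1 : i - 1;
    } else {
      // Pairs are split across halves: (0, half), (1, half + 1), ...
      cache_idx = i % half;
      sign = (i < half) ? -1.0f : 1.0f;
      j = (i + half) % head_size;
    }
    out[i] = x[i] * cos_row[cache_idx] + sign * x[j] * sin_row[cache_idx];
  }
  return out;
}
```

With `cos = 0` and `sin = 1` (a quarter turn), the input `{1, 2, 3, 4}` becomes `{-3, -4, 1, 2}` in the non-interleaved layout and `{-2, 1, -4, 3}` in the interleaved one; the only difference between the two layouts is which elements are treated as a pair.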
@@ -0,0 +1,23 @@

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class RotaryEmbedding final : public OpKernel {
 public:
  RotaryEmbedding(const OpKernelInfo& info);
  Status Compute(OpKernelContext* context) const override;

 protected:
  float scale;
  bool interleaved;
};

}  // namespace contrib
}  // namespace onnxruntime
```
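The `Compute` method declared here parallelizes over a flat index of `batch_size * sequence_length * num_heads` work items and resolves a cache row for each one. The sketch below reproduces that bookkeeping outside ONNX Runtime. `WorkItem`, `Decompose`, and `ResolvePosition` are hypothetical names, and the reading of `position_ids_format` (0 meaning a single starting offset, any other value meaning one id per `(batch, sequence)` element) is inferred from the kernel code in this commit rather than from documentation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical struct: coordinates of one unit of work in the parallel loop.
struct WorkItem {
  int b;  // batch index
  int s;  // sequence position within the batch entry
  int n;  // attention head index
};

// Splits a flat index into (b, s, n), with n varying fastest, as in the kernel.
inline WorkItem Decompose(std::ptrdiff_t idx, int sequence_length, int num_heads) {
  assert(idx >= 0 && sequence_length > 0 && num_heads > 0);
  WorkItem item;
  item.b = static_cast<int>((idx / num_heads) / sequence_length);
  item.s = static_cast<int>((idx / num_heads) % sequence_length);
  item.n = static_cast<int>(idx % num_heads);
  return item;
}

// Picks the cache row for a token. Assumption inferred from the kernel:
// format 0 stores one starting offset, otherwise one id per (b, s) element.
inline int ResolvePosition(const std::vector<std::int64_t>& position_ids,
                           int position_ids_format, int b, int s,
                           int sequence_length) {
  assert(!position_ids.empty());
  if (position_ids_format == 0) {
    return static_cast<int>(position_ids[0]) + s;
  }
  const std::size_t flat = static_cast<std::size_t>(b) * sequence_length + s;
  assert(flat < position_ids.size());
  return static_cast<int>(position_ids[flat]);
}
```

The resolved position, multiplied by `head_size / 2`, gives the offset of the `cos` and `sin` rows used for that token, which is why a single offset is sufficient when all tokens in a batch continue from the same past length.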