Update disentangled attention plugin based on DeBERTa v2
Signed-off-by: Kevin Chen <[email protected]>
Haohang Huang authored and kevinch-nv committed Aug 17, 2022
1 parent 52158de commit 568092d
Showing 4 changed files with 206 additions and 283 deletions.
13 changes: 7 additions & 6 deletions plugin/disentangledAttentionPlugin/README.md
@@ -15,10 +15,10 @@ This TensorRT plugin implements an efficient algorithm to perform the calculation
 
 Unlike [BERT](https://arxiv.org/abs/1810.04805), where each word is represented by one vector that sums the content embedding and position embedding, the [DeBERTa](https://arxiv.org/abs/2006.03654) design first proposed the concept of disentangled attention, which uses two vectors to encode content and position respectively and forms attention weights by summing disentangled matrices. A performance gap has been identified between this new attention scheme and the original self-attention, mainly due to the extra indexing and gather operations. The major optimizations implemented in this plugin include: (i) fusion of gather and pointwise operations, (ii) exploiting the pattern of the relative position matrix to shortcut out-of-boundary index calculations, and (iii) parallel index calculation.
 
-This TensorRT plugin is primarily intended to be used together with the DeBERTa network, but it also applies to generic architectures that adopt disentangled attention.
+This TensorRT plugin is primarily intended to be used together with the DeBERTa network (with the HuggingFace [DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta) and [DeBERTa-V2](https://huggingface.co/docs/transformers/model_doc/deberta-v2) implementations), but it also applies to generic architectures that adopt disentangled attention.
 
 ## Structure
-This plugin works for networks with a graph node named `DisentangledAttention_TRT`.
+This plugin works for networks with a graph node named `DisentangledAttention_TRT`. The corresponding graph modification script can be found under the `demo/DeBERTa` folder of TensorRT OSS.
 
 ### Input(s)
 This plugin takes three inputs:
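For reference, the decomposition the plugin evaluates is the one from the DeBERTa paper (our gloss, not part of this diff): the three inputs correspond to the three precomputed score matrices below, and the plugin's `factor` attribute carries the paper's $1/\sqrt{3d}$ scaling:

$$\tilde{A}_{i,j} = \underbrace{Q^c_i (K^c_j)^{\top}}_{\text{content-to-content}} + \underbrace{Q^c_i (K^r_{\delta(i,j)})^{\top}}_{\text{content-to-position}} + \underbrace{K^c_j (Q^r_{\delta(j,i)})^{\top}}_{\text{position-to-content}}$$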
@@ -67,12 +67,13 @@ This plugin generates one output.
 ## Additional Resources
 - [BERT](https://arxiv.org/abs/1810.04805)
 - [DeBERTa](https://arxiv.org/abs/2006.03654)
-- [DeBERTa HuggingFace Implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/deberta_v2)
-
-
+- [DeBERTa HuggingFace Implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/deberta)
+- [DeBERTa-V2 HuggingFace Implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/deberta_v2)
+
 ## License
 For terms and conditions for use, reproduction, and distribution, see the [TensorRT Software License Agreement](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sla/index.html)
 documentation.
 
 ## Changelog
-2022.04: This is the first release of this `README` file.
+- 2022.04: This is the first release of this `README` file.
+- 2022.07: Added log bucket for the relative position index calculation (since DeBERTa V2).
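The 2022.07 entry is the heart of this commit: DeBERTa-V2 maps raw relative positions into log-spaced buckets before indexing the relative embeddings, so long ranges share a fixed number of buckets. As a rough illustration, a helper mirroring the HuggingFace `make_log_bucket_position` logic might look like the sketch below (the function name and exact rounding are our assumptions, not code from this diff):

```cpp
#include <cmath>
#include <cstdint>

// Sketch of DeBERTa-V2-style log bucketing: positions within +/- bucketSize/2
// keep their exact relative offset; farther positions are compressed
// logarithmically so very long ranges fit in a fixed number of buckets.
int32_t makeLogBucketPosition(int32_t relPos, int32_t bucketSize, int32_t maxPosition)
{
    int32_t const mid = bucketSize / 2;
    if (relPos > -mid && relPos < mid)
    {
        return relPos; // near range: identity mapping
    }
    int32_t const sign = (relPos < 0) ? -1 : 1;
    float const absPos = static_cast<float>(sign * relPos);
    // Map [mid, maxPosition) onto [mid, bucketSize) on a log scale.
    float const logPos
        = std::ceil(std::log(absPos / mid) / std::log(static_cast<float>(maxPosition - 1) / mid) * (mid - 1)) + mid;
    return sign * static_cast<int32_t>(logPos);
}
```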
255 changes: 113 additions & 142 deletions plugin/disentangledAttentionPlugin/disentangledAttentionPlugin.cpp
@@ -97,16 +97,9 @@ nvinfer1::DimsExprs DisentangledAttentionPlugin::getOutputDimensions(
     int32_t index, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
 {
     nvinfer1::DimsExprs output;
-    if (kDISENTANGLED_VERSION == 1)
-    {
-        PLUGIN_ASSERT(nbInputs == 4); // 4 inputs
-        output = inputs[1]; // same as input[1] or input[3], i.e. index1 or index2
-    }
-    else if (kDISENTANGLED_VERSION == 2)
-    {
-        PLUGIN_ASSERT(nbInputs == 3); // 3 inputs
-        output = inputs[0]; // same as input[0], i.e. data0
-    }
+
+    PLUGIN_ASSERT(nbInputs == 3); // 3 inputs
+    output = inputs[0]; // same as input[0], i.e. data0
 
     PLUGIN_ASSERT(index < 1); // only one output
 
@@ -136,78 +129,92 @@ int32_t DisentangledAttentionPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
     nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
     cudaStream_t stream) noexcept
 {
-    if (kDISENTANGLED_VERSION == 1)
-    {
-        nvinfer1::Dims dims0 = inputDesc[0].dims;
-        nvinfer1::Dims dims1 = inputDesc[1].dims;
-        nvinfer1::Dims dims2 = inputDesc[2].dims;
-        nvinfer1::Dims dims3 = inputDesc[3].dims;
-        dim3 dimData1(dims0.d[0], dims0.d[1], dims0.d[2]);
-        dim3 dimIndex1(dims1.d[0], dims1.d[1], dims1.d[2]);
-        dim3 dimData2(dims2.d[0], dims2.d[1], dims2.d[2]);
-        dim3 dimIndex2(dims3.d[0], dims3.d[1], dims3.d[2]);
-        dim3 dimResult(dimIndex2);
-
-        dim3 block_optimized(kDISENTANGLED_TILESIZE, kDISENTANGLED_BLOCKDIMY);
-        dim3 grid_optimized((dimResult.z - 1) / kDISENTANGLED_TILESIZE + 1,
-            (dimResult.y - 1) / kDISENTANGLED_TILESIZE + 1, dimResult.x);
-
-        __half const* data1 = static_cast<__half const*>(inputs[0]);
-        int32_t const* index1 = static_cast<int32_t const*>(inputs[1]);
-        __half const* data2 = static_cast<__half const*>(inputs[2]);
-        int32_t const* index2 = static_cast<int32_t const*>(inputs[3]);
-        __half* result = static_cast<__half*>(outputs[0]);
-
-        disentangled_kernel_wrapper_v1<__half>(data1, index1, data2, index2, result, dimData1, dimIndex1, dimData2,
-            dimIndex2, dimResult, block_optimized, grid_optimized, stream);
-    }
-    else if (kDISENTANGLED_VERSION == 2)
-    {
-        nvinfer1::Dims dims0 = inputDesc[0].dims;
-        nvinfer1::Dims dims1 = inputDesc[1].dims;
-        nvinfer1::Dims dims2 = inputDesc[2].dims;
-        dim3 dimData0(dims0.d[0], dims0.d[1], dims0.d[2]);
-        dim3 dimData1(dims1.d[0], dims1.d[1], dims1.d[2]);
-        dim3 dimData2(dims2.d[0], dims2.d[1], dims2.d[2]);
-        dim3 dimResult(dimData0);
-
-        dim3 block_optimized(kDISENTANGLED_TILESIZE, kDISENTANGLED_BLOCKDIMY);
-        dim3 grid_optimized((dimResult.z - 1) / kDISENTANGLED_TILESIZE + 1,
-            (dimResult.y - 1) / kDISENTANGLED_TILESIZE + 1, dimResult.x);
-
-        if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
-        {
-            auto const* data0 = pointer_const_cast<float>(inputs[0]);
-            auto const* data1 = pointer_const_cast<float>(inputs[1]);
-            auto const* data2 = pointer_const_cast<float>(inputs[2]);
-            auto* result = pointer_cast<float>(outputs[0]);
-            disentangled_kernel_wrapper_v2<float, kDISENTANGLED_TILESIZE, kDISENTANGLED_BLOCKDIMY>(data0, data1, data2,
-                result, dimData0, dimData1, dimData2, dimResult, mFactor, mSpan, block_optimized, grid_optimized,
-                stream);
-        }
-        else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
-        {
-            auto const* data0 = pointer_const_cast<__half>(inputs[0]);
-            auto const* data1 = pointer_const_cast<__half>(inputs[1]);
-            auto const* data2 = pointer_const_cast<__half>(inputs[2]);
-            auto* result = pointer_cast<__half>(outputs[0]);
-            __half factor = __float2half(mFactor);
-            disentangled_kernel_wrapper_v2<__half, kDISENTANGLED_TILESIZE, kDISENTANGLED_BLOCKDIMY>(data0, data1, data2,
-                result, dimData0, dimData1, dimData2, dimResult, factor, mSpan, block_optimized, grid_optimized,
-                stream);
-        }
-        else if (inputDesc[0].type == nvinfer1::DataType::kINT8)
-        {
-            auto const* data0 = pointer_const_cast<int8_t>(inputs[0]);
-            auto const* data1 = pointer_const_cast<int8_t>(inputs[1]);
-            auto const* data2 = pointer_const_cast<int8_t>(inputs[2]);
-            auto* result = pointer_cast<int8_t>(outputs[0]);
-            int8_t factor = int8_t(mFactor);
-            disentangled_kernel_wrapper_v2<int8_t, kDISENTANGLED_TILESIZE, kDISENTANGLED_BLOCKDIMY>(data0, data1, data2,
-                result, dimData0, dimData1, dimData2, dimResult, factor, mSpan, block_optimized, grid_optimized,
-                stream);
-        }
-    }
+
+#if kDISENTANGLED_VERSION == 1
+    nvinfer1::Dims dims0 = inputDesc[0].dims;
+    nvinfer1::Dims dims1 = inputDesc[1].dims;
+    nvinfer1::Dims dims2 = inputDesc[2].dims;
+    dim3 dimData0(dims0.d[0], dims0.d[1], dims0.d[2]);
+    dim3 dimData1(dims1.d[0], dims1.d[1], dims1.d[2]);
+    dim3 dimData2(dims2.d[0], dims2.d[1], dims2.d[2]);
+    dim3 dimResult(dimData0);
+
+    dim3 block_optimized(kDISENTANGLED_TILESIZE_V1, kDISENTANGLED_BLOCKDIMY_V1);
+    dim3 grid_optimized((dimResult.z - 1) / kDISENTANGLED_TILESIZE_V1 + 1,
+        (dimResult.y - 1) / kDISENTANGLED_TILESIZE_V1 + 1, dimResult.x);
+
+    if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
+    {
+        auto const* data0 = pointer_const_cast<float>(inputs[0]);
+        auto const* data1 = pointer_const_cast<float>(inputs[1]);
+        auto const* data2 = pointer_const_cast<float>(inputs[2]);
+        auto* result = pointer_cast<float>(outputs[0]);
+        disentangled_kernel_wrapper<float, kDISENTANGLED_TILESIZE_V1, kDISENTANGLED_BLOCKDIMY_V1>(data0, data1, data2,
+            result, dimData0, dimData1, dimData2, dimResult, mFactor, mSpan, block_optimized, grid_optimized, stream);
+    }
+    else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
+    {
+        auto const* data0 = pointer_const_cast<__half>(inputs[0]);
+        auto const* data1 = pointer_const_cast<__half>(inputs[1]);
+        auto const* data2 = pointer_const_cast<__half>(inputs[2]);
+        auto* result = pointer_cast<__half>(outputs[0]);
+        __half factor = __float2half(mFactor);
+        disentangled_kernel_wrapper<__half, kDISENTANGLED_TILESIZE_V1, kDISENTANGLED_BLOCKDIMY_V1>(data0, data1, data2,
+            result, dimData0, dimData1, dimData2, dimResult, factor, mSpan, block_optimized, grid_optimized, stream);
+    }
+    else if (inputDesc[0].type == nvinfer1::DataType::kINT8)
+    {
+        auto const* data0 = pointer_const_cast<int8_t>(inputs[0]);
+        auto const* data1 = pointer_const_cast<int8_t>(inputs[1]);
+        auto const* data2 = pointer_const_cast<int8_t>(inputs[2]);
+        auto* result = pointer_cast<int8_t>(outputs[0]);
+        int8_t factor = int8_t(mFactor);
+        disentangled_kernel_wrapper<int8_t, kDISENTANGLED_TILESIZE_V1, kDISENTANGLED_BLOCKDIMY_V1>(data0, data1, data2,
+            result, dimData0, dimData1, dimData2, dimResult, factor, mSpan, block_optimized, grid_optimized, stream);
+    }
+#elif kDISENTANGLED_VERSION == 2
+    nvinfer1::Dims dims0 = inputDesc[0].dims;
+    nvinfer1::Dims dims1 = inputDesc[1].dims;
+    nvinfer1::Dims dims2 = inputDesc[2].dims;
+    dim3 dimData0(dims0.d[0], dims0.d[1], dims0.d[2]);
+    dim3 dimData1(dims1.d[0], dims1.d[1], dims1.d[2]);
+    dim3 dimData2(dims2.d[0], dims2.d[1], dims2.d[2]);
+    dim3 dimResult(dimData0);
+
+    dim3 block_optimized(kDISENTANGLED_TILESIZE_V2, kDISENTANGLED_BLOCKDIMY_V2);
+    dim3 grid_optimized((dimResult.z - 1) / kDISENTANGLED_TILESIZE_V2 + 1,
+        (dimResult.y - 1) / kDISENTANGLED_TILESIZE_V2 + 1, dimResult.x);
+
+    if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
+    {
+        auto const* data0 = pointer_const_cast<float>(inputs[0]);
+        auto const* data1 = pointer_const_cast<float>(inputs[1]);
+        auto const* data2 = pointer_const_cast<float>(inputs[2]);
+        auto* result = pointer_cast<float>(outputs[0]);
+        disentangled_kernel_wrapper<float, kDISENTANGLED_TILESIZE_V2, kDISENTANGLED_BLOCKDIMY_V2>(data0, data1, data2,
+            result, dimData0, dimData1, dimData2, dimResult, mFactor, mSpan, block_optimized, grid_optimized, stream);
+    }
+    else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
+    {
+        auto const* data0 = pointer_const_cast<__half>(inputs[0]);
+        auto const* data1 = pointer_const_cast<__half>(inputs[1]);
+        auto const* data2 = pointer_const_cast<__half>(inputs[2]);
+        auto* result = pointer_cast<__half>(outputs[0]);
+        __half factor = __float2half(mFactor);
+        disentangled_kernel_wrapper<__half, kDISENTANGLED_TILESIZE_V2, kDISENTANGLED_BLOCKDIMY_V2>(data0, data1, data2,
+            result, dimData0, dimData1, dimData2, dimResult, factor, mSpan, block_optimized, grid_optimized, stream);
+    }
+    else if (inputDesc[0].type == nvinfer1::DataType::kINT8)
+    {
+        auto const* data0 = pointer_const_cast<int8_t>(inputs[0]);
+        auto const* data1 = pointer_const_cast<int8_t>(inputs[1]);
+        auto const* data2 = pointer_const_cast<int8_t>(inputs[2]);
+        auto* result = pointer_cast<int8_t>(outputs[0]);
+        int8_t factor = int8_t(mFactor);
+        disentangled_kernel_wrapper<int8_t, kDISENTANGLED_TILESIZE_V2, kDISENTANGLED_BLOCKDIMY_V2>(data0, data1, data2,
+            result, dimData0, dimData1, dimData2, dimResult, factor, mSpan, block_optimized, grid_optimized, stream);
+    }
+#endif
 
     return cudaPeekAtLastError();
 }
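In both versions the launch grid is sized with the `(n - 1) / tileSize + 1` idiom, which is integer ceiling division: the tile count is rounded up so a partially filled edge tile still receives a thread block. A minimal, self-contained check of the idiom (helper name and values are ours, for illustration only):

```cpp
#include <cassert>
#include <cstdint>

// Integer ceiling division, as used for grid sizing in enqueue() above.
// Assumes n >= 1, which holds for tensor dimensions.
constexpr int32_t ceilDiv(int32_t n, int32_t tileSize)
{
    return (n - 1) / tileSize + 1;
}

int main()
{
    assert(ceilDiv(64, 32) == 2); // exact fit: two full tiles
    assert(ceilDiv(65, 32) == 3); // one extra, partially filled tile
    assert(ceilDiv(1, 32) == 1);  // always at least one block
    return 0;
}
```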
@@ -276,70 +283,34 @@ IPluginV2DynamicExt* DisentangledAttentionPlugin::clone() const noexcept
 void DisentangledAttentionPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
     nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept
 {
-    if (kDISENTANGLED_VERSION == 1)
-    {
-        // inputs
-        PLUGIN_ASSERT(nbInputs == 4); // 4 inputs
-
-        // check for valid input dimensions
-        PLUGIN_ASSERT(in[0].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[1].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[2].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[3].desc.dims.nbDims == 3);
-
-        // check BN (batch_size * num_heads) dimension consistency
-        PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[1].desc.dims.d[0]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[2].desc.dims.d[0]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[3].desc.dims.d[0]);
-
-        // check S (sequence_length) dimension consistency
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[1].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[2].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[3].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[1].desc.dims.d[1] == in[1].desc.dims.d[2]);
-        PLUGIN_ASSERT(in[3].desc.dims.d[1] == in[3].desc.dims.d[2]);
-
-        // check K (2 * span) dimension consistency for in[0] and in[2]
-        PLUGIN_ASSERT(in[0].desc.dims.d[2] == 2 * mSpan);
-        PLUGIN_ASSERT(in[2].desc.dims.d[2] == 2 * mSpan);
-
-        // Outputs (same dimension as in[1])
-        PLUGIN_ASSERT(nbOutputs == 1);
-        PLUGIN_ASSERT(out[0].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[1].desc.dims.d[0] == out[0].desc.dims.d[0]);
-        PLUGIN_ASSERT(in[1].desc.dims.d[1] == out[0].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[1].desc.dims.d[2] == out[0].desc.dims.d[2]);
-    }
-    else if (kDISENTANGLED_VERSION == 2)
-    {
-        // inputs
-        PLUGIN_ASSERT(nbInputs == 3); // 3 inputs
-
-        // check for valid input dimensions
-        PLUGIN_ASSERT(in[0].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[1].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[2].desc.dims.nbDims == 3);
-
-        // check BN (batch_size * num_heads) dimension consistency
-        PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[1].desc.dims.d[0]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[2].desc.dims.d[0]);
-
-        // check S (sequence_length) dimension consistency
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[1].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[2].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[0].desc.dims.d[2]);
-
-        // check K (2 * span) dimension consistency for in[1] and in[2]
-        PLUGIN_ASSERT(in[1].desc.dims.d[2] == 2 * mSpan);
-        PLUGIN_ASSERT(in[2].desc.dims.d[2] == 2 * mSpan);
-
-        // Outputs (same dimension as in[0])
-        PLUGIN_ASSERT(nbOutputs == 1);
-        PLUGIN_ASSERT(out[0].desc.dims.nbDims == 3);
-        PLUGIN_ASSERT(in[0].desc.dims.d[0] == out[0].desc.dims.d[0]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[1] == out[0].desc.dims.d[1]);
-        PLUGIN_ASSERT(in[0].desc.dims.d[2] == out[0].desc.dims.d[2]);
-    }
+
+    // inputs
+    PLUGIN_ASSERT(nbInputs == 3); // 3 inputs
+
+    // check for valid input dimensions
+    PLUGIN_ASSERT(in[0].desc.dims.nbDims == 3);
+    PLUGIN_ASSERT(in[1].desc.dims.nbDims == 3);
+    PLUGIN_ASSERT(in[2].desc.dims.nbDims == 3);
+
+    // check BN (batch_size * num_heads) dimension consistency
+    PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[1].desc.dims.d[0]);
+    PLUGIN_ASSERT(in[0].desc.dims.d[0] == in[2].desc.dims.d[0]);
+
+    // check S (sequence_length) dimension consistency
+    PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[1].desc.dims.d[1]);
+    PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[2].desc.dims.d[1]);
+    PLUGIN_ASSERT(in[0].desc.dims.d[1] == in[0].desc.dims.d[2]);
+
+    // check K (2 * span) dimension consistency for in[1] and in[2]
+    PLUGIN_ASSERT(in[1].desc.dims.d[2] == 2 * mSpan);
+    PLUGIN_ASSERT(in[2].desc.dims.d[2] == 2 * mSpan);
+
+    // Outputs (same dimension as in[0])
+    PLUGIN_ASSERT(nbOutputs == 1);
+    PLUGIN_ASSERT(out[0].desc.dims.nbDims == 3);
+    PLUGIN_ASSERT(in[0].desc.dims.d[0] == out[0].desc.dims.d[0]);
+    PLUGIN_ASSERT(in[0].desc.dims.d[1] == out[0].desc.dims.d[1]);
+    PLUGIN_ASSERT(in[0].desc.dims.d[2] == out[0].desc.dims.d[2]);
 }
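Taken together, the assertions encode the shape contract below (BN = batch_size × num_heads, S = sequence_length, K = 2 × span); the interpretation of each tensor follows the plugin README and is our gloss, not part of this diff:

```cpp
// data0  [BN, S, S]   content-to-content scores   (Qc * Kc^T)
// data1  [BN, S, K]   content-to-position scores  (Qc * Kr^T)
// data2  [BN, S, K]   position-to-content scores  (Kc * Qr^T)
// output [BN, S, S]   fused, scaled attention scores
```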

nvinfer1::DataType DisentangledAttentionPlugin::getOutputDataType(
28 changes: 12 additions & 16 deletions plugin/disentangledAttentionPlugin/disentangledAttentionPlugin.h
@@ -18,14 +18,13 @@
 #ifndef TRT_DISENTANGLED_ATTENTION_PLUGIN_H
 #define TRT_DISENTANGLED_ATTENTION_PLUGIN_H
 
-#include "serialize.hpp"
+#include "NvInferPlugin.h"
 #include "plugin.h"
-#include <cudnn.h>
-#include <vector>
+#include "serialize.hpp"
+#include <cstdint>
 #include <iostream>
 #include <string>
-#include <cstdint>
-#include "NvInferPlugin.h"
+#include <vector>
 
 // One of the preferred ways of making TensorRT to be able to see
 // our custom layer requires extending IPluginV2 and IPluginCreator classes.
@@ -37,17 +36,16 @@ namespace plugin
 
 // using namespace nvinfer1;
 
-constexpr int32_t kDISENTANGLED_VERSION = 2;
-constexpr int32_t kDISENTANGLED_TILESIZE = 32;
-constexpr int32_t kDISENTANGLED_BLOCKDIMY = 8;
-
-template <typename TDataType>
-void disentangled_kernel_wrapper_v1(TDataType const* data1, int32_t const* index1, TDataType const* data2,
-    int32_t const* index2, TDataType* result, dim3 dimData1, dim3 dimIndex1, dim3 dimData2, dim3 dimIndex2,
-    dim3 dimResult, dim3 block, dim3 grid, cudaStream_t stream);
+#define kDISENTANGLED_VERSION 2
+// Version 1: regular relative position index
+// Version 2: log bucket relative position index
+constexpr int32_t kDISENTANGLED_TILESIZE_V1 = 32;
+constexpr int32_t kDISENTANGLED_BLOCKDIMY_V1 = 8;
+constexpr int32_t kDISENTANGLED_TILESIZE_V2 = 64;
+constexpr int32_t kDISENTANGLED_BLOCKDIMY_V2 = 4;
 
 template <typename TDataType, int32_t tTileSize, int32_t tBlockDimY>
-void disentangled_kernel_wrapper_v2(TDataType const* data0, TDataType const* data1, TDataType const* data2,
+void disentangled_kernel_wrapper(TDataType const* data0, TDataType const* data1, TDataType const* data2,
     TDataType* result, dim3 dimData0, dim3 dimData1, dim3 dimData2, dim3 dimResult, TDataType factor, int32_t span,
     dim3 block, dim3 grid, cudaStream_t stream);
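One detail worth noting: `kDISENTANGLED_VERSION` changes from a `constexpr` variable to a `#define`. The version switch in `enqueue` now happens in the preprocessor (`#if` / `#elif`), and the preprocessor cannot read C++ constants; a non-macro identifier in an `#if` expression is simply replaced by `0`. A minimal illustration:

```cpp
#define kDISENTANGLED_VERSION 2 // visible to the preprocessor

#if kDISENTANGLED_VERSION == 2  // resolved before compilation proper
// ... only the V2 branch is compiled at all ...
#endif
```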

@@ -122,8 +120,6 @@ class DisentangledAttentionPlugin final : public nvinfer1::IPluginV2DynamicExt
     // attributes
     int32_t mSpan;
     float mFactor;
-
-    cudnnHandle_t _cudnn_handle;
 };
 
 class DisentangledAttentionPluginCreator : public nvinfer1::IPluginCreator
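For completeness, this is roughly how the plugin could be instantiated through the TensorRT plugin registry once the plugin library is loaded (e.g., after `initLibNvInferPlugins`). The field names follow the `mSpan` and `mFactor` attributes above, but the version string and attribute values here are placeholder assumptions, not taken from this commit:

```cpp
#include "NvInferPlugin.h"

nvinfer1::IPluginV2* createDisentangledAttention()
{
    // Plugin name per the README; version string "1" is an assumption.
    auto* creator = getPluginRegistry()->getPluginCreator("DisentangledAttention_TRT", "1");
    int32_t span = 256;  // placeholder: relative distance span (K = 2 * span)
    float factor = 0.1f; // placeholder: scaling applied to the summed scores
    nvinfer1::PluginField fields[] = {
        {"span", &span, nvinfer1::PluginFieldType::kINT32, 1},
        {"factor", &factor, nvinfer1::PluginFieldType::kFLOAT32, 1},
    };
    nvinfer1::PluginFieldCollection fc{2, fields};
    return creator->createPlugin("disentangled_attention", &fc);
}
```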