Skip to content

Commit

Permalink
Fix a computational problem of scaledSoftmax.
Browse files Browse the repository at this point in the history
The original implementation results in wrong results of sum of softmax such that the results of BERT models (128 < seq_len < 384 and seq_len > 384) are very large or even 'nan'.
This implementation fix the computational problem such that the results of BERT models (128 < seq_len < 384 and seq_len > 384) become correct.

Signed-off-by: yuanzexi <[email protected]>
  • Loading branch information
yuanzexi authored and rajeevsrao committed Mar 18, 2021
1 parent 88f9cae commit 10a51be
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions plugin/common/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ template <typename T, unsigned TPB>
__device__ inline void scaledSoftmax(
const int ld, const int lastValid, const float rsqrtHeadSize, const T* input, T* output)
{

using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;

Expand All @@ -346,7 +345,7 @@ __device__ inline void scaledSoftmax(
for (int i = threadIdx.x; i < lastValid; i += TPB)
{
const int idx = offset + i;
threadData = input[idx];
threadData = max(static_cast<float>(input[idx]), threadData);
}

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
Expand All @@ -356,16 +355,12 @@ __device__ inline void scaledSoftmax(
}
__syncthreads();

if (lastValid < blockDim.x)
{
if (threadIdx.x >= lastValid)
{
threadData = 0;
}
}
threadData = 0;

for (int i = threadIdx.x; i < lastValid; i += TPB)
{
threadData += exp((threadData - fMax) * w);
const int idx = offset + i;
threadData += exp((static_cast<float>(input[idx]) - fMax) * w);
}

const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
Expand All @@ -379,7 +374,7 @@ __device__ inline void scaledSoftmax(
for (int i = threadIdx.x; i < ld; i += TPB)
{
const int idx = offset + i;
const float val = (i < lastValid) ? exp(float(input[idx]) * w) * rZ : 0.f;
const float val = (i < lastValid) ? exp((static_cast<float>(input[idx]) - fMax) * w) * rZ : 0.f;
output[idx] = T(val);
}
}
Expand Down

0 comments on commit 10a51be

Please sign in to comment.