Fix/gpt early stop (NVIDIA#584)
* fix: fix bug of early stopping of gpt
byshiue authored May 1, 2023
1 parent 19b2956 commit c6e8f60
Showing 8 changed files with 121 additions and 73 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -213,6 +213,9 @@ In the experiments of decoding, we updated the following parameters:

### Changelog

May 2023
- Fix bugs of generation early stopping

January 2023
- Support GPT MoE
- Support FP8 for Bert and GPT (**Experimental**)
34 changes: 25 additions & 9 deletions src/fastertransformer/kernels/gpt_kernels.cu
@@ -640,6 +640,7 @@ __global__ void generate_dups_indices(int* batch_to_compact,
int* compact_size,
const int* shared_contexts,
const size_t batch_size,
const size_t beam_width,
const size_t input_seq_len)
{
const int padded_batchsize = blockDim.x * ((batch_size + blockDim.x - 1) / blockDim.x);
@@ -649,20 +650,23 @@ __global__ void generate_dups_indices(int* batch_to_compact,
__shared__ int scan_offset;

int scan = 0;
for (int batch = threadIdx.x; batch < padded_batchsize; batch += blockDim.x) {
bool masked = (batch >= batch_size);
bool first_iter = batch < blockDim.x;
for (int seq_idx = threadIdx.x; seq_idx < padded_batchsize; seq_idx += blockDim.x) {
bool masked = (seq_idx >= batch_size);
bool first_iter = seq_idx < blockDim.x;

int is_first_occur = masked ? 0 : shared_contexts[batch] == batch;
int is_first_occur = masked ? 0 : shared_contexts[seq_idx] == seq_idx;
BlockScan(temp_storage).ExclusiveSum(is_first_occur, scan);

if (!masked && is_first_occur) {
int compact_idx = scan + (first_iter ? 0 : scan_offset);
// Context rep. writes initial index
batch_to_compact[batch] = compact_idx;
compact_to_batch[compact_idx] = batch;
batch_to_compact[seq_idx * beam_width] = compact_idx;
// input ids are tiled in context part
compact_to_batch[compact_idx] = seq_idx * beam_width;
}

__syncthreads();

if (threadIdx.x == blockDim.x - 1) {
scan_offset = scan + is_first_occur + (first_iter ? 0 : scan_offset);
}
@@ -671,8 +675,15 @@

if (!masked && !is_first_occur) {
// Fill the rest of batch_to_compact based on what rep. wrote
const int src_idx = batch_to_compact[shared_contexts[batch]];
batch_to_compact[batch] = src_idx;
const int src_idx = batch_to_compact[shared_contexts[seq_idx] * beam_width];
batch_to_compact[seq_idx * beam_width] = src_idx;
}

if (!masked) {
// set same compact idx for beams
for (int beam_id = 1; beam_id < beam_width; ++beam_id) {
batch_to_compact[seq_idx * beam_width + beam_id] = batch_to_compact[seq_idx * beam_width];
}
}
}
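With beam search, each batch entry occupies beam_width consecutive rows of the decoder batch, so the kernel now writes batch_to_compact at stride beam_width and copies the representative's compact index to every beam. The host-side C++ sketch below is not part of the commit; it mirrors the mapping the updated kernel produces, assuming shared_contexts[i] already holds the smallest batch index whose prompt equals prompt i (as computed by find_context_dups):

#include <vector>

// Reference sketch of the beam-aware compact-batch mapping.
void generate_dups_indices_ref(std::vector<int>&       batch_to_compact,  // batch_size * beam_width entries
                               std::vector<int>&       compact_to_batch,  // at most batch_size entries
                               int&                    compact_size,
                               const std::vector<int>& shared_contexts,   // batch_size entries
                               int                     batch_size,
                               int                     beam_width)
{
    compact_size = 0;
    for (int seq_idx = 0; seq_idx < batch_size; ++seq_idx) {
        if (shared_contexts[seq_idx] == seq_idx) {
            // First occurrence of this prompt: it becomes one row of the compact batch.
            batch_to_compact[seq_idx * beam_width] = compact_size;
            // Input ids are tiled per beam, so the compact row points at the first beam.
            compact_to_batch[compact_size] = seq_idx * beam_width;
            ++compact_size;
        }
    }
    for (int seq_idx = 0; seq_idx < batch_size; ++seq_idx) {
        if (shared_contexts[seq_idx] != seq_idx) {
            // Duplicates reuse the compact row written by their representative.
            batch_to_compact[seq_idx * beam_width] = batch_to_compact[shared_contexts[seq_idx] * beam_width];
        }
        // Every beam of a batch entry shares the same compact row.
        for (int beam_id = 1; beam_id < beam_width; ++beam_id) {
            batch_to_compact[seq_idx * beam_width + beam_id] = batch_to_compact[seq_idx * beam_width];
        }
    }
}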

@@ -696,14 +707,17 @@ void invokeFindContextDups(int* shared_contexts,
int* compact_size,
const int* input_ids,
const size_t batch_size,
const size_t beam_width,
const size_t input_seq_len,
cudaStream_t stream)
{
dim3 block{512};
dim3 grid{((int)batch_size + block.x - 1) / block.x};
// set shared_context[i] = i
init_shared_contexts<<<grid, block, 0, stream>>>(shared_contexts, batch_size);

grid = dim3{(unsigned int)(batch_size * (batch_size - 1)) / 2};
// set shared_contexts[i] = j, where j = min{k, such that input_ids[k] == input_ids[i]}
if (input_seq_len <= 128) {
block = 128;
find_context_dups<128><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
@@ -713,8 +727,10 @@ void invokeFindContextDups(int* shared_contexts,
find_context_dups<256><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
}

// set batch_to_compact[i] = j, where j is the position of input_ids[i] in the compact_batch
// set compact_to_batch[i] = j, where j is such that compact_to_batch[i] = input_ids[j]
generate_dups_indices<<<1, DUPS_INDICES_BLOCK_SIZE, 0, stream>>>(
batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, input_seq_len);
batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, beam_width, input_seq_len);
}

template<typename T>
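The comments added to invokeFindContextDups summarize the duplicate-detection pipeline: init_shared_contexts marks every prompt as its own representative, find_context_dups lowers shared_contexts[i] to the smallest batch index with an identical prompt, and generate_dups_indices turns that into the compact-batch mapping. A sequential C++ sketch of the first two steps, illustrative only and assuming input_ids stores one prompt of input_seq_len tokens per batch entry in row-major order:

#include <vector>

void find_context_dups_ref(std::vector<int>&       shared_contexts,  // batch_size entries
                           const std::vector<int>& input_ids,        // batch_size * input_seq_len tokens
                           int                     batch_size,
                           int                     input_seq_len)
{
    // init_shared_contexts: shared_contexts[i] = i.
    for (int i = 0; i < batch_size; ++i) {
        shared_contexts[i] = i;
    }
    // find_context_dups: one comparison per pair (i, j) with i < j, keeping the smallest match.
    for (int i = 0; i < batch_size; ++i) {
        for (int j = i + 1; j < batch_size; ++j) {
            bool equal = true;
            for (int t = 0; t < input_seq_len && equal; ++t) {
                equal = input_ids[i * input_seq_len + t] == input_ids[j * input_seq_len + t];
            }
            if (equal && i < shared_contexts[j]) {
                shared_contexts[j] = i;  // j's prompt duplicates i's; i is the earlier representative
            }
        }
    }
}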
1 change: 1 addition & 0 deletions src/fastertransformer/kernels/gpt_kernels.h
@@ -127,6 +127,7 @@ void invokeFindContextDups(int* shared_contexts,
int* compact_size,
const int* input_ids,
const size_t batch_size,
const size_t beam_width,
const size_t input_seq_len,
cudaStream_t stream = 0);

2 changes: 1 addition & 1 deletion src/fastertransformer/kernels/stop_criteria_kernels.cu
@@ -150,7 +150,7 @@ void invokeLengthCriterion(bool* finished,

length_criterion<<<grid, block, 0, stream>>>(
finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
while (((volatile size_t*)h_pinned_finished_sum_)[0] == -1) {};
while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
sync_check_cuda_error();

*should_stop = h_pinned_finished_sum_[0] == batch_size * beam_width;
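The cast matters because h_pinned_finished_sum_ appears to be a single pinned int: reading it through a volatile size_t* on a 64-bit host compares eight bytes (four of them adjacent memory) against size_t(-1), so the spin loop can fall through before the kernel has written its result, and the stale value makes should_stop come out wrong, breaking early stopping. A minimal host-side sketch of the handshake, under the assumption that the buffer is a pinned int set to -1 before the launch and overwritten by length_criterion with the number of finished sequences:

#include <cstddef>

bool should_stop_after_length_criterion(int* h_pinned_finished_sum, std::size_t batch_size, std::size_t beam_width)
{
    // Spin until the kernel overwrites the -1 sentinel through the pinned pointer;
    // the volatile int* read matches the 4-byte width of the flag.
    while (((volatile int*)h_pinned_finished_sum)[0] == -1) {}

    // Stop generation early only when every sequence of every beam has finished.
    return h_pinned_finished_sum[0] == static_cast<int>(batch_size * beam_width);
}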