fix: malloc host opt
byshiue committed Feb 6, 2023
1 parent 652e61e commit 90cc5a7
Showing 17 changed files with 92 additions and 102 deletions.
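
Every file in this commit makes the same substitution: pinned host scratch buffers that were previously allocated with cudaMallocHost() and released with cudaFreeHost() are now owned by the model's allocator, through the extra host-memory argument on reMalloc() and free(). The sketch below shows the before/after shape of the change using a hypothetical SomeModule (the class name is only for illustration); the trailing reMalloc flags are assumed to mean "zero-initialize" and "allocate pinned host memory", which is inferred from the hunks rather than stated on this page.

    // Old pattern: the CUDA runtime owns the pinned host word.
    //     check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
    //     ...
    //     check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));

    // New pattern: the allocator owns it, like every device buffer.
    template<typename T>
    void SomeModule<T>::allocateBuffer()
    {
        h_pinned_token_num_ptr_ =
            (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
        is_allocate_buffer_ = true;
    }

    template<typename T>
    void SomeModule<T>::freeBuffer()
    {
        if (is_allocate_buffer_) {
            // The second argument marks the pointer as host (pinned) memory.
            allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
            is_allocate_buffer_ = false;
        }
    }

Routing pinned allocations through the allocator keeps them under the same tracking and reuse bookkeeping as device buffers, which appears to be the point of this cleanup.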
6 changes: 2 additions & 4 deletions src/fastertransformer/layers/DynamicDecodeLayer.cc
@@ -29,17 +29,15 @@ template<typename T>
 void DynamicDecodeLayer<T>::allocateBuffer()
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    check_cuda_error(cudaMallocHost((void**)&h_pinned_finished_sum_, sizeof(int)));
+    h_pinned_finished_sum_ = (int*)allocator_->reMalloc(h_pinned_finished_sum_, sizeof(int), true, true);
     return;
 }
 
 template<typename T>
 void DynamicDecodeLayer<T>::freeBuffer()
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (h_pinned_finished_sum_ != nullptr) {
-        check_cuda_error(cudaFreeHost(h_pinned_finished_sum_));
-    }
+    allocator_->free((void**)(&h_pinned_finished_sum_), true);
     return;
 }
 
10 changes: 4 additions & 6 deletions src/fastertransformer/models/bart/BartEncoder.cc
@@ -218,7 +218,7 @@ template<typename T>
 void BartEncoder<T>::allocateBuffer()
 {
     if (is_allocate_buffer_ == false) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
+        h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
         padding_offset_ =
             (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false);
         trt_mha_padding_offset_ =
@@ -256,10 +256,8 @@ template<typename T>
 void BartEncoder<T>::allocateBuffer(size_t batch_size, size_t seq_len)
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
-    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
+    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false);
 
@@ -292,7 +290,7 @@ template<typename T>
 void BartEncoder<T>::freeBuffer()
 {
     if (is_allocate_buffer_) {
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
         allocator_->free((void**)(&padding_offset_));
         allocator_->free((void**)(&trt_mha_padding_offset_));
 
8 changes: 3 additions & 5 deletions src/fastertransformer/models/bert/Bert.cc
@@ -217,10 +217,8 @@ template<typename T>
 void Bert<T>::allocateBuffer(size_t batch_size, size_t seq_len)
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
-    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
+    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false);
 
@@ -250,7 +248,7 @@ void Bert<T>::freeBuffer()
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
         allocator_->free((void**)(&padding_offset_));
         allocator_->free((void**)(&trt_mha_padding_offset_));
 
12 changes: 3 additions & 9 deletions src/fastertransformer/models/bert_fp8/BertFP8.cc
@@ -127,11 +127,8 @@ template<typename T1, typename T2>
 void BertFP8<T1, T2>::allocateBuffer(size_t batch_size, size_t seq_len)
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-        is_allocate_buffer_ = true;
-    }
-    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
+    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false);
 
@@ -161,10 +158,7 @@ void BertFP8<T1, T2>::allocateBuffer(size_t batch_size, size_t seq_len)
 template<typename T1, typename T2>
 void BertFP8<T1, T2>::freeBuffer()
 {
-    if (is_allocate_buffer_) {
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
-        is_allocate_buffer_ = false;
-    }
+    allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
     allocator_->free((void**)(&padding_offset_));
     allocator_->free((void**)(&trt_mha_padding_offset_));
 
4 changes: 2 additions & 2 deletions src/fastertransformer/models/bert_int8/BertINT8.cc
@@ -126,7 +126,7 @@ template<typename T>
 void BertINT8<T>::allocateBuffer()
 {
     if (is_allocate_buffer_ == false) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
+        h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
         padding_offset_ =
             (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false);
         trt_mha_padding_offset_ =
@@ -147,7 +147,7 @@ template<typename T>
 void BertINT8<T>::freeBuffer()
 {
     if (is_allocate_buffer_ == true) {
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
         allocator_->free((void**)(&padding_offset_));
         allocator_->free((void**)(&trt_mha_padding_offset_));
 
8 changes: 3 additions & 5 deletions src/fastertransformer/models/deberta/Deberta.cc
@@ -204,10 +204,8 @@ template<typename T>
 void Deberta<T>::allocateBuffer(size_t batch_size, size_t seq_len)
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
-    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
+    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
     attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * batch_size * seq_len * seq_len, false);
 
     deberta_emb_buf_ =
@@ -241,7 +239,7 @@ void Deberta<T>::freeBuffer()
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
         allocator_->free((void**)(&padding_offset_));
         allocator_->free((void**)(&attention_mask_));
         allocator_->free((void**)(&deberta_emb_buf_));
7 changes: 2 additions & 5 deletions src/fastertransformer/models/gptj/GptJ.cc
@@ -151,10 +151,7 @@ void GptJ<T>::allocateBuffer(
         context_decoder_output_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false));
     output_log_probs_buf_ =
         (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false));
-
-    if (generation_should_stop_ == nullptr) {
-        cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool));
-    }
+    generation_should_stop_ = (bool*)(allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true));
 
     is_allocate_buffer_ = true;
 }
@@ -206,7 +203,7 @@ void GptJ<T>::freeBuffer()
     allocator_->free((void**)(&context_decoder_output_buf_));
     allocator_->free((void**)(&output_log_probs_buf_));
 
-    cudaFreeHost(generation_should_stop_);
+    allocator_->free((void**)(&generation_should_stop_), true);
 
     is_allocate_buffer_ = false;
 }
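The GptJ hunks above (and the matching GptNeoX and ParallelGpt hunks below) also drop the lazy-initialization guard that protected the old cudaMallocHost() call. A plausible reading, not confirmed on this page, is that reMalloc() reuses an allocation it already tracks, so repeating the call on every allocateBuffer() is safe and the null check becomes redundant:

    // Old pattern: guard needed, because a second cudaMallocHost() would leak the first buffer.
    //     if (generation_should_stop_ == nullptr) {
    //         cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool));
    //     }
    // New pattern: assumed idempotent, so no guard; the allocator reuses or replaces the tracked pointer.
    generation_should_stop_ = (bool*)(allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true));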
6 changes: 2 additions & 4 deletions src/fastertransformer/models/gptj/GptJContextDecoder.cc
@@ -80,9 +80,7 @@ void GptJContextDecoder<T>::allocateBuffer(size_t batch_size, size_t seq_len)
         allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false));
     decoder_layer_output_ = reinterpret_cast<T*>(
         allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false));
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
     padding_offset_ =
         reinterpret_cast<int*>(allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false));
     cu_seqlens_ = reinterpret_cast<int*>(allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false));
@@ -97,7 +95,7 @@ void GptJContextDecoder<T>::freeBuffer()
     allocator_->free((void**)(&self_attn_output_));
     allocator_->free((void**)(&ffn_output_));
     allocator_->free((void**)(&decoder_layer_output_));
-    check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+    allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
     allocator_->free((void**)(&padding_offset_));
     allocator_->free((void**)(&cu_seqlens_));
     is_allocate_buffer_ = false;
6 changes: 2 additions & 4 deletions src/fastertransformer/models/gptneox/GptNeoX.cc
@@ -148,9 +148,7 @@ void GptNeoX<T>::allocateBuffer(
     output_log_probs_buf_ =
         (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false));
 
-    if (generation_should_stop_ == nullptr) {
-        cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool));
-    }
+    generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true);
 
     is_allocate_buffer_ = true;
 }
@@ -201,7 +199,7 @@ void GptNeoX<T>::freeBuffer()
     allocator_->free((void**)(&context_decoder_output_buf_));
     allocator_->free((void**)(&output_log_probs_buf_));
 
-    cudaFreeHost(generation_should_stop_);
+    allocator_->free((void**)(&generation_should_stop_), true);
 
     is_allocate_buffer_ = false;
 }
6 changes: 2 additions & 4 deletions src/fastertransformer/models/gptneox/GptNeoXContextDecoder.cc
@@ -80,9 +80,7 @@ void GptNeoXContextDecoder<T>::allocateBuffer(size_t batch_size, size_t seq_len)
         allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false));
     decoder_layer_output_ = reinterpret_cast<T*>(
         allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false));
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
     padding_offset_ =
         reinterpret_cast<int*>(allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false));
     cu_seqlens_ = reinterpret_cast<int*>(allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false));
@@ -97,7 +95,7 @@ void GptNeoXContextDecoder<T>::freeBuffer()
     allocator_->free((void**)(&self_attn_output_));
     allocator_->free((void**)(&ffn_output_));
     allocator_->free((void**)(&decoder_layer_output_));
-    check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+    allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
     allocator_->free((void**)(&padding_offset_));
     allocator_->free((void**)(&cu_seqlens_));
     is_allocate_buffer_ = false;
7 changes: 2 additions & 5 deletions src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
@@ -189,10 +189,7 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
         compact_idx_ = shared_contexts_idx_ + 2 * batch_size;
         compact_size_ = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false);
     }
-
-    if (generation_should_stop_ == nullptr) {
-        cudaMallocHost(&generation_should_stop_, 1 * sizeof(bool));
-    }
+    generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true);
     tiled_total_padding_count_ =
         (int*)allocator_->reMalloc(tiled_total_padding_count_, batchxbeam * sizeof(int), false);
 
@@ -257,7 +254,7 @@ void ParallelGpt<T>::freeBuffer()
     allocator_->free((void**)(&lp_nccl_logits_buf_));
     allocator_->free((void**)(&lp_logprob_buf_));
 
-    cudaFreeHost(generation_should_stop_);
+    allocator_->free((void**)(&generation_should_stop_), true);
 
     if (shared_contexts_ratio_ > 0.0f) {
         allocator_->free((void**)(&shared_contexts_idx_));
src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
@@ -113,9 +113,7 @@ void ParallelGptContextDecoder<T>::allocateBuffer(size_t batch_size, size_t seq_
         ffn_intermediate_dynamic_scale_ = reinterpret_cast<float*>(
             allocator_->reMalloc(ffn_intermediate_dynamic_scale_, sizeof(float) * batch_size * seq_len, true));
     }
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
     padding_offset_ =
         reinterpret_cast<int*>(allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false));
     cu_seqlens_ = reinterpret_cast<int*>(allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false));
@@ -160,7 +158,7 @@ void ParallelGptContextDecoder<T>::freeBuffer()
         allocator_->free((void**)(&adapter_fc2_result_));
     }
     allocator_->free((void**)(&decoder_layer_output_));
-    check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+    allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
     allocator_->free((void**)(&padding_offset_));
     allocator_->free((void**)(&cu_seqlens_));
 
10 changes: 4 additions & 6 deletions src/fastertransformer/models/t5/T5Encoder.cc
@@ -259,7 +259,7 @@ template<typename T>
 void T5Encoder<T>::allocateBuffer()
 {
     if (is_allocate_buffer_ == false) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
+        h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
         padding_offset_ =
             (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false);
         trt_mha_padding_offset_ =
@@ -306,10 +306,8 @@ template<typename T>
 void T5Encoder<T>::allocateBuffer(size_t batch_size, size_t seq_len)
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
-    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
+    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false);
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false);
 
@@ -359,7 +357,7 @@ template<typename T>
 void T5Encoder<T>::freeBuffer()
 {
     if (is_allocate_buffer_) {
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
         allocator_->free((void**)(&padding_offset_));
         allocator_->free((void**)(&trt_mha_padding_offset_));
 
8 changes: 3 additions & 5 deletions src/fastertransformer/models/vit/ViT.cc
@@ -196,7 +196,7 @@ void ViTTransformer<T>::allocateBuffer()
         (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false);
     padding_offset_ =
         (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false);
-    check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
 
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * max_batch_size_ + 1), false);
@@ -229,9 +229,7 @@ void ViTTransformer<T>::allocateBuffer(size_t batch_size)
     REMALLOC(embed_buf_3_, sizeof(T) * batch_size * max_seq_len_ * embed_dim_);
     REMALLOC(mask_buf_, sizeof(T) * batch_size * max_seq_len_ * max_seq_len_);
     REMALLOC(padding_offset_, sizeof(int) * batch_size * max_seq_len_);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
     REMALLOC(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1));
     REMALLOC(seq_len_vec_, sizeof(int) * batch_size);
     resetBatch(batch_size);
@@ -253,7 +251,7 @@ void ViTTransformer<T>::freeBuffer()
         allocator_->free((void**)(&trt_mha_padding_offset_));
         allocator_->free((void**)(&seq_len_vec_));
         allocator_->free((void**)(&padding_offset_));
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
         is_allocate_buffer_ = false;
     }
 }
8 changes: 3 additions & 5 deletions src/fastertransformer/models/vit_int8/ViTINT8.cc
@@ -186,7 +186,7 @@ void ViTTransformerINT8<T>::allocateBuffer()
         (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * max_batch_size_ * max_seq_len_ * max_seq_len_, false);
     padding_offset_ =
         (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * max_batch_size_ * max_seq_len_, false);
-    check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
 
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * max_batch_size_ + 1), false);
@@ -222,9 +222,7 @@ void ViTTransformerINT8<T>::allocateBuffer(size_t batch_size)
     embed_buf_4_ = (T*)allocator_->reMalloc(embed_buf_4_, sizeof(T) * batch_size * max_seq_len_ * embed_dim_, false);
     mask_buf_ = (T*)allocator_->reMalloc(mask_buf_, sizeof(T) * batch_size * max_seq_len_ * max_seq_len_, false);
     REMALLOC(padding_offset_, sizeof(int) * batch_size * max_seq_len_);
-    if (!is_allocate_buffer_) {
-        check_cuda_error(cudaMallocHost((void**)&h_pinned_token_num_ptr_, sizeof(size_t)));
-    }
+    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
     trt_mha_padding_offset_ =
         (int*)allocator_->reMalloc(trt_mha_padding_offset_, sizeof(int) * (2 * batch_size + 1), false);
     seq_len_vec_ = (int*)allocator_->reMalloc(seq_len_vec_, sizeof(int) * batch_size, false);
@@ -249,7 +247,7 @@ void ViTTransformerINT8<T>::freeBuffer()
         allocator_->free((void**)(&trt_mha_padding_offset_));
         allocator_->free((void**)(&seq_len_vec_));
         allocator_->free((void**)(&padding_offset_));
-        check_cuda_error(cudaFreeHost(h_pinned_token_num_ptr_));
+        allocator_->free((void**)(&h_pinned_token_num_ptr_), true);
 
         is_allocate_buffer_ = false;
     }
4 changes: 2 additions & 2 deletions src/fastertransformer/models/wenet/WenetEncoder.cc
@@ -28,7 +28,7 @@ void WenetEncoder<T>::initialize()
     check_cuda_error(cudaStreamCreate(&stream2_));
     check_cuda_error(cudaEventCreate(&stream_finished_));
     check_cuda_error(cudaEventCreate(&stream2_finished_));
-    check_cuda_error(cudaMallocHost((void**)&h_var_token_num_, sizeof(size_t)));
+    h_var_token_num_ = (size_t*)allocator_->reMalloc(h_var_token_num_, sizeof(size_t), true, true);
 
     attention_layer_ = new RelPositionAttentionLayer<T>(0,
                                                         0,
@@ -170,7 +170,7 @@ WenetEncoder<T>::~WenetEncoder()
     delete ffn_layer_;
     delete conformer_conv_layer_;
 
-    check_cuda_error(cudaFreeHost(h_var_token_num_));
+    allocator_->free((void**)(&h_var_token_num_), true);
     check_cuda_error(cudaEventDestroy(stream2_finished_));
     check_cuda_error(cudaEventDestroy(stream_finished_));
     check_cuda_error(cudaStreamDestroy(stream2_));