Skip to content

Commit

Permalink
Fix sync in Resize operator family (NVIDIA#4990)
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <[email protected]>
  • Loading branch information
jantonguirao authored Aug 11, 2023
1 parent 4589b6a commit 2352071
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions dali/kernels/imgproc/resample/resampling_filters.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "dali/kernels/imgproc/resample/resampling_windows.h"
#include "dali/core/mm/memory.h"
#include "dali/core/span.h"
#include "dali/core/cuda_stream_pool.h"

namespace dali {
namespace kernels {
Expand Down Expand Up @@ -103,10 +104,13 @@ void InitFilters(ResamplingFilters &filters) {
filters[3].rescale(4);

if (need_staging) {
auto filter_data_gpu = mm::alloc_raw_unique<float, mm::memory_kind::device>(total_size);
CUDA_CALL(cudaMemcpy(filter_data_gpu.get(), filters.filter_data.get(),
total_size * sizeof(float), cudaMemcpyHostToDevice));
ptrdiff_t diff = filter_data_gpu.get() - filters.filter_data.get();
auto cuda_stream = CUDAStreamPool::instance().Get();
auto filter_data_gpu = mm::alloc_raw_async_unique<float, mm::memory_kind::device>(
total_size, cuda_stream, mm::host_sync);
CUDA_CALL(cudaMemcpyAsync(filter_data_gpu.get(), filters.filter_data.get(),
total_size * sizeof(float), cudaMemcpyHostToDevice, cuda_stream));
CUDA_CALL(cudaStreamSynchronize(cuda_stream));
ptrdiff_t diff = filter_data_gpu.get() - filters.filter_data.get();
filters.filter_data = std::move(filter_data_gpu);
for (auto &f : filters.filters)
f.coeffs += diff;
Expand Down

0 comments on commit 2352071

Please sign in to comment.