Add ROI support to NumpyReader GPU (NVIDIA#3034)
Signed-off-by: Joaquin Anton <[email protected]>
jantonguirao authored Jun 17, 2021
1 parent 0ebcf4e commit 39de51b
Showing 14 changed files with 276 additions and 282 deletions.
3 changes: 2 additions & 1 deletion dali/operators/reader/CMakeLists.txt
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/numpy_reader_op.cc")

if (BUILD_CUFILE)
list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/numpy_reader_gpu_op.cc")
list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/numpy_reader_gpu_op_impl.cu")
endif()

list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/sequence_reader_op.cc")
72 changes: 1 addition & 71 deletions dali/operators/reader/loader/file_loader.h
@@ -34,18 +34,6 @@

namespace dali {

namespace detail {

template <typename T, typename result = decltype(std::declval<T &>().data)>
inline std::true_type HasData(T *);
inline std::false_type HasData(...);

/// @brief Inherits `true_type`, if `T::data` exists
template <typename T>
struct has_data : decltype(HasData((T*)0)) {}; // NOLINT

} // namespace detail

struct FileWrapper {
Tensor<CPUBackend> data;
std::string filename;
@@ -112,65 +100,6 @@ class FileLoader : public Loader<Backend, Target> {
copy_read_data_ = dont_use_mmap_ || !mmap_reserver_.CanShareMappedData();
}

std::enable_if_t<detail::has_data<Target>::value, void>
PrepareEmpty(Target &target) override {
PrepareEmptyTensor(target.data);
target.filename.clear();
}

std::enable_if_t<detail::has_data<Target>::value, void>
ReadSample(Target &target) override {
auto filename = files_[current_index_++];

// handle wrap-around
MoveToNextShard(current_index_);

// metadata info
DALIMeta meta;
meta.SetSourceInfo(filename);
meta.SetSkipSample(false);

// if data is cached, skip loading
if (ShouldSkipImage(filename)) {
meta.SetSkipSample(true);
target.data.Reset();
target.data.SetMeta(meta);
target.data.set_type(TypeInfo::Create<uint8_t>());
target.data.Resize({0});
target.filename = "";
return;
}

auto current_file = InputStream::Open(filesystem::join_path(file_root_, filename),
read_ahead_, !copy_read_data_);
Index file_size = current_file->Size();

if (copy_read_data_) {
if (target.data.shares_data()) {
target.data.Reset();
}
target.data.Resize({file_size});
// copy the data
Index ret = current_file->Read(target.data.template mutable_data<uint8_t>(), file_size);
DALI_ENFORCE(ret == file_size, make_string("Failed to read file: ", filename));
} else {
auto p = current_file->Get(file_size);
DALI_ENFORCE(p != nullptr, make_string("Failed to read file: ", filename));
// Wrap the raw data in the Tensor object.
target.data.ShareData(p, file_size, {file_size});
target.data.set_type(TypeInfo::Create<uint8_t>());
}

// close the file handle
current_file->Close();

// set metadata
target.data.SetMeta(meta);

// set string
target.filename = filesystem::join_path(file_root_, filename);
}

protected:
Index SizeImpl() override {
return static_cast<Index>(files_.size());
@@ -246,6 +175,7 @@ class FileLoader : public Loader<Backend, Target> {
typename InputStream::MappingReserver mmap_reserver_;
};


} // namespace dali

#endif // DALI_OPERATORS_READER_LOADER_FILE_LOADER_H_
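
Note: the deleted detail::has_data trait is a textbook expression-SFINAE member detector — the first HasData overload participates in overload resolution only when T::data is a valid expression. A minimal self-contained sketch of the same idiom (the WithData/WithoutData types are illustrative, not DALI code):

#include <type_traits>
#include <utility>

namespace detail {

// The default template argument is only well-formed when T has a member
// named `data`; otherwise this overload is discarded (SFINAE) and the
// variadic fallback is chosen instead.
template <typename T, typename = decltype(std::declval<T&>().data)>
std::true_type HasData(T*);
std::false_type HasData(...);

// Inherits true_type if T::data exists, false_type otherwise.
template <typename T>
struct has_data : decltype(HasData((T*)nullptr)) {};

}  // namespace detail

struct WithData { int data; };
struct WithoutData {};

static_assert(detail::has_data<WithData>::value, "data member detected");
static_assert(!detail::has_data<WithoutData>::value, "no data member");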
3 changes: 1 addition & 2 deletions dali/operators/reader/loader/loader.h
@@ -1,4 +1,4 @@
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -149,7 +149,6 @@ class Loader {
initial_buffer_filled_ = true;
}

int samples_to_choose_from = initial_buffer_fill_;
if (shards_.front().start == shards_.front().end) {
// If the reader has depleted samples from the given shard, but shards are not equal
// and we need to pad samples inside batch (even create a whole new dummy batch) using padding
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/numpy_loader.cc
@@ -282,7 +282,7 @@ void NumpyLoader::ReadSample(NumpyFileWrapper& target) {
target.filename = file_root_ + "/" + filename;

// set meta
target.fortan_order = parse_target.fortran_order;
target.fortran_order = parse_target.fortran_order;
}

} // namespace dali
8 changes: 4 additions & 4 deletions dali/operators/reader/loader/numpy_loader.h
@@ -57,17 +57,17 @@ class NumpyParseTarget{
struct NumpyFileWrapper {
Tensor<CPUBackend> data;
std::string filename;
bool fortan_order;
bool fortran_order;

const TypeInfo& type() const {
const TypeInfo& get_type() const {
return data.type();
}

const TensorShape<>& shape() const {
const TensorShape<>& get_shape() const {
return data.shape();
}

const DALIMeta& meta() const {
const DALIMeta& get_meta() const {
return data.GetMeta();
}
};
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/numpy_loader_gpu.cc
@@ -101,7 +101,7 @@ void NumpyLoaderGPU::ReadSample(NumpyFileWrapperGPU& target) {

target.type = parse_target.type_info;
target.shape = parse_target.shape;
target.fortan_order = parse_target.fortran_order;
target.fortran_order = parse_target.fortran_order;
};

target.read_sample_f = [this, filename, &target] (void *buffer, Index file_offset,
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/numpy_loader_gpu.h
@@ -40,7 +40,7 @@ namespace dali {

struct NumpyFileWrapperGPU {
std::string filename;
bool fortan_order;
bool fortran_order;
TensorShape<> shape;
TypeInfo type;
DALIMeta meta;
114 changes: 14 additions & 100 deletions dali/operators/reader/numpy_reader_gpu_op.cc
@@ -13,11 +13,8 @@
// limitations under the License.

#include <string>

#include "dali/pipeline/data/views.h"
#include "dali/kernels/transpose/transpose_gpu.h"
#include "dali/core/static_switch.h"
#include "dali/operators/reader/numpy_reader_gpu_op.h"
#include "dali/pipeline/data/views.h"

namespace dali {

@@ -38,31 +35,37 @@ void NumpyReaderGPU::Prefetch() {
thread_pool_.RunAll();

// resize the current batch
std::vector<TensorShape<>> tmp_shapes;
auto ref_type = curr_batch[0]->get_type();
auto ref_shape = curr_batch[0]->get_shape();
TensorListShape<> tmp_shapes(curr_batch.size(), ref_shape.sample_dim());
for (size_t data_idx = 0; data_idx < curr_batch.size(); ++data_idx) {
auto &sample = curr_batch[data_idx];
DALI_ENFORCE(ref_type == sample->get_type(), make_string("Inconsistent data! "
"The data produced by the reader has inconsistent type:\n"
"type of [", data_idx, "] is ", sample->get_type().id(), " whereas\n"
"type of [0] is ", ref_type.id()));

DALI_ENFORCE(ref_shape.size() == sample->get_shape().size(), make_string("Inconsistent data! "
"The data produced by the reader has inconsistent dimensionality:\n"
"[", data_idx, "] has ", sample->get_shape().size(), " dimensions whereas\n"
"[0] has ", ref_shape.size(), " dimensions."));
tmp_shapes.push_back(sample->get_shape());
DALI_ENFORCE(
ref_shape.sample_dim() == sample->get_shape().sample_dim(),
make_string(
"Inconsistent data! The data produced by the reader has inconsistent dimensionality:\n"
"[",
data_idx, "] has ", sample->get_shape().sample_dim(),
" dimensions whereas\n"
"[0] has ",
ref_shape.sample_dim(), " dimensions."));
tmp_shapes.set_tensor_shape(data_idx, sample->get_shape());
}

curr_tensor_list.Resize(TensorListShape<>(tmp_shapes), ref_type);
curr_tensor_list.Resize(tmp_shapes, ref_type);

size_t chunk_size = static_cast<size_t>( \
div_ceil(static_cast<uint64_t>(curr_tensor_list.nbytes()),
static_cast<uint64_t>(thread_pool_.NumThreads())));

// read the data
for (size_t data_idx = 0; data_idx < curr_tensor_list.ntensor(); ++data_idx) {
curr_tensor_list.SetMeta(data_idx, curr_batch[data_idx]->get_meta());
size_t image_bytes = static_cast<size_t>(volume(curr_tensor_list.tensor_shape(data_idx))
* curr_tensor_list.type().size());
uint8_t* dst_ptr = static_cast<uint8_t*>(curr_tensor_list.raw_mutable_tensor(data_idx));
@@ -87,95 +90,6 @@
}
}
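
Note: the chunk_size computed in Prefetch above splits the batch's total bytes evenly across the thread pool, rounding up so the chunks cover the whole buffer. A standalone illustration of that arithmetic, assuming a div_ceil helper with the usual ceiling-division meaning (plain C++, hypothetical sizes):

#include <cstdint>
#include <cstdio>

// Ceiling division: smallest q such that q * denom >= num.
inline uint64_t div_ceil(uint64_t num, uint64_t denom) {
  return (num + denom - 1) / denom;
}

int main() {
  uint64_t nbytes = 1000003;      // hypothetical batch size in bytes
  uint64_t num_threads = 8;       // hypothetical thread-pool size
  uint64_t chunk_size = div_ceil(nbytes, num_threads);  // 125001
  // chunk_size * num_threads >= nbytes, so every byte lands in some
  // chunk; the reader clamps the final chunk to the remaining bytes.
  std::printf("chunk=%llu covers=%llu\n",
              (unsigned long long)chunk_size,
              (unsigned long long)(chunk_size * num_threads));
  return 0;
}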

void PermuteHelper(const TensorShape<> &plain_shapes, std::vector<int64_t> &perm_shape,
std::vector<int> &perm) {
int n_dims = plain_shapes.size();
if (perm.empty()) {
perm.resize(n_dims);
for (int i = 0; i < n_dims; ++i) {
perm[i] = n_dims - i - 1;
}
}
for (int i = 0; i < n_dims; ++i) {
perm_shape[i] = plain_shapes[perm[i]];
}
}

void NumpyReaderGPU::RunImpl(DeviceWorkspace &ws) {
TensorListShape<> shape(max_batch_size_);
// use vector for temporarily storing shapes
std::vector<TensorShape<>> tmp_shapes;
std::vector<TensorShape<>> transpose_shapes;
std::vector<int> perm;
std::vector<int64_t> perm_shape;

perm.reserve(GetSampleShape(0).size());
perm_shape.resize(GetSampleShape(0).size());

for (int sample_idx = 0; sample_idx < max_batch_size_; sample_idx++) {
const auto& target = GetSample(sample_idx);
auto plain_shape = GetSampleShape(sample_idx);
if (target.fortan_order) {
PermuteHelper(plain_shape, perm_shape, perm);
tmp_shapes.push_back(perm_shape);
transpose_shapes.push_back(plain_shape);
} else {
tmp_shapes.push_back(plain_shape);
}
}
auto ref_type = GetSampleType(0);
shape = TensorListShape<>(tmp_shapes);
ws.Output<GPUBackend>(0).Resize(shape, ref_type);

auto &image_output = ws.Output<GPUBackend>(0);

SmallVector<int64_t, 256> copy_sizes;
copy_sizes.reserve(max_batch_size_);
SmallVector<const void *, 256> copy_from;
copy_from.reserve(max_batch_size_);
SmallVector<void *, 256> copy_to;
copy_to.reserve(max_batch_size_);

SmallVector<const void *, 256> transpose_from;
transpose_from.reserve(max_batch_size_);
SmallVector<void *, 256> transpose_to;
transpose_to.reserve(max_batch_size_);


for (int data_idx = 0; data_idx < max_batch_size_; ++data_idx) {
const auto& target = GetSample(data_idx);
if (target.fortan_order) {
transpose_from.push_back(GetSampleRawData(data_idx));
transpose_to.push_back(image_output.raw_mutable_tensor(data_idx));
} else {
copy_from.push_back(GetSampleRawData(data_idx));
copy_to.push_back(image_output.raw_mutable_tensor(data_idx));
copy_sizes.push_back(shape.tensor_size(data_idx));
}
image_output.SetMeta(data_idx, target.get_meta());
}

if (transpose_from.empty() && !copy_sizes.empty()) {
std::swap(image_output, prefetched_batch_tensors_[curr_batch_consumer_]);
} else {
// use copy kernel for plan samples
if (!copy_sizes.empty()) {
ref_type.template Copy<GPUBackend, GPUBackend>(copy_to.data(), copy_from.data(),
copy_sizes.data(), copy_sizes.size(),
ws.stream(), true);
}

// transpose remaining samples
if (!transpose_from.empty()) {
kernels::KernelContext ctx;
ctx.gpu.stream = ws.stream();
kmgr_.Setup<TransposeKernel>(0, ctx, TensorListShape<>(transpose_shapes), make_span(perm),
ref_type.size());
kmgr_.Run<TransposeKernel>(0, 0, ctx, transpose_to.data(), transpose_from.data());
}
}
}

DALI_REGISTER_OPERATOR(readers__Numpy, NumpyReaderGPU, GPU);

// Deprecated alias
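
Note: the RunImpl removed above (the build now also compiles numpy_reader_gpu_op_impl.cu, where this logic presumably lives on) split the batch into plain samples, which were batch-copied, and Fortran-order samples, which were transposed with reversed axes. A standalone sketch of the removed PermuteHelper's shape logic, with std::vector standing in for DALI's TensorShape<>:

#include <cstdint>
#include <vector>

// For a Fortran-order (column-major) array exposed as row-major, the
// axes are reversed: perm = {n-1, ..., 1, 0}.
void PermuteHelper(const std::vector<int64_t>& plain_shape,
                   std::vector<int64_t>& perm_shape,
                   std::vector<int>& perm) {
  int n_dims = static_cast<int>(plain_shape.size());
  if (perm.empty()) {
    perm.resize(n_dims);
    for (int i = 0; i < n_dims; ++i)
      perm[i] = n_dims - i - 1;
  }
  perm_shape.resize(n_dims);
  for (int i = 0; i < n_dims; ++i)
    perm_shape[i] = plain_shape[perm[i]];  // output extent i = input axis perm[i]
}

// E.g. a file stored as (2, 3, 4) in Fortran order comes out as a
// (4, 3, 2) row-major tensor once the transpose kernel has run.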