Skip to content

Commit

Permalink
Add specializations of Operator class for all backends (NVIDIA#934)
Browse files Browse the repository at this point in the history
Add RunImpl for HostWorkspace in CPU Ops.

Add using Operator::RunImpl so as not to hide it when overriding only one overload in CPU Ops.

Refactor before changing CPU Ops to batch processing.

Signed-off-by: Krzysztof Lecki [email protected]
  • Loading branch information
klecki authored Jun 6, 2019
1 parent e47781a commit ae7c34e
Show file tree
Hide file tree
Showing 23 changed files with 120 additions and 38 deletions.
1 change: 1 addition & 0 deletions dali/pipeline/operators/color/color_twist.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class ColorTwistBase : public Operator<Backend> {
const int C_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

private:
void IdentityMatrix(float * matrix) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ColorSpaceConversion : public Operator<Backend> {
void RunImpl(Workspace<Backend> *ws, const int idx) override;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

const DALIImageType input_type_;
const DALIImageType output_type_;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/crop/crop.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Crop : public SliceBase<Backend>, protected CropAttr {
using SliceBase<Backend>::input_type_;
using SliceBase<Backend>::output_type_;
USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
std::size_t C_;

void SetupSample(int data_idx, DALITensorLayout layout, const vector<Index> &shape) {
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/crop/slice_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class SliceBase : public Operator<Backend> {
kernels::ScratchpadAllocator, std::vector<kernels::ScratchpadAllocator>>::type scratch_alloc_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/detection/box_encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class BoxEncoder<CPUBackend>: public Operator<CPUBackend> {

protected:
void RunImpl(Workspace<CPUBackend> *ws, const int idx) override;
using Operator<CPUBackend>::RunImpl;

private:
const float criteria_;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/detection/random_crop.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class SSDRandomCrop : public Operator<Backend> {
DISABLE_COPY_MOVE_ASSIGN(SSDRandomCrop);

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

protected:
void RunImpl(Workspace<Backend> * ws, const int idx) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class DisplacementFilter<CPUBackend, Displacement, per_channel_transform>
}

USE_OPERATOR_MEMBERS();
using Operator<CPUBackend>::RunImpl;

private:
// TODO(klecki) We could probably interpolate with something other than float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ class DisplacementFilter<GPUBackend, Displacement,
}

USE_OPERATOR_MEMBERS();
using Operator<GPUBackend>::RunImpl;

private:
static const size_t nDims = 3;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/fused/crop_mirror_normalize.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class CropMirrorNormalize : public Operator<Backend>, protected CropAttr {
std::vector<std::pair<int, int>> per_sample_dimensions_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/fused/normalize_permute.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class NormalizePermute : public Operator<Backend> {
vector<Dims> output_shape_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down
2 changes: 2 additions & 0 deletions dali/pipeline/operators/fused/resize_crop_mirror.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ class ResizeCropMirror : public Operator<CPUBackend>, protected ResizeCropMirror
vector<vector<uint8>> tl_workspace_;
vector<TransformMeta> per_thread_meta_;
USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
using Operator<Backend>::SetupSharedSampleParams;
};

/**
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/geometric/bb_flip.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class BbFlip<CPUBackend> : public Operator<CPUBackend> {

protected:
void RunImpl(SampleWorkspace *ws, const int idx) override;
using Operator<CPUBackend>::RunImpl;

private:
/**
Expand Down
135 changes: 97 additions & 38 deletions dali/pipeline/operators/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,17 @@ inline void CheckInputLayouts(const DeviceWorkspace *ws, const OpSpec &spec) {
*/
class DLL_PUBLIC OperatorBase {
public:
DLL_PUBLIC inline explicit OperatorBase(const OpSpec &spec) :
spec_(spec), num_threads_(spec.GetArgument<int>("num_threads")),
batch_size_(spec.GetArgument<int>("batch_size")),
input_sets_(spec.GetArgument<int>("num_input_sets")),
default_cuda_stream_priority_(spec.GetArgument<int>("default_cuda_stream_priority")) {
DLL_PUBLIC inline explicit OperatorBase(const OpSpec &spec)
: spec_(spec),
num_threads_(spec.GetArgument<int>("num_threads")),
batch_size_(spec.GetArgument<int>("batch_size")),
input_sets_(spec.GetArgument<int>("num_input_sets")),
default_cuda_stream_priority_(spec.GetArgument<int>("default_cuda_stream_priority")) {
DALI_ENFORCE(num_threads_ > 0, "Invalid value for argument num_threads.");
DALI_ENFORCE(batch_size_ > 0, "Invalid value for argument batch_size.");
}

DLL_PUBLIC virtual inline ~OperatorBase() noexcept(false)
{}
DLL_PUBLIC virtual inline ~OperatorBase() noexcept(false) {}

/**
* @brief Executes the operator on a single sample on the CPU.
Expand Down Expand Up @@ -159,68 +159,127 @@ class DLL_PUBLIC OperatorBase {
* name (the first arg to the registration macro).
*/
template <typename Backend>
class Operator : public OperatorBase {
class Operator : public OperatorBase {};

template <>
class Operator<SupportBackend> : public OperatorBase {
 public:
  inline explicit Operator(const OpSpec &spec) : OperatorBase(spec) {}

  inline ~Operator() noexcept(false) override {}

  using OperatorBase::Run;

  /// Runs the support operator: validates input layouts, performs shared
  /// parameter setup once, then invokes RunImpl once per input set.
  void Run(SupportWorkspace *ws) override {
    CheckInputLayouts(ws, spec_);
    SetupSharedSampleParams(ws);
    for (int set_idx = 0; set_idx < input_sets_; ++set_idx) {
      RunImpl(ws, set_idx);
    }
  }

  /**
   * @brief Per-input-set implementation of the operator; derived support
   * operators must provide this.
   */
  virtual void RunImpl(SupportWorkspace *ws, int idx = 0) = 0;

  /**
   * @brief Hook for setup shared across samples; no-op by default.
   */
  virtual void SetupSharedSampleParams(SupportWorkspace *ws) {}
};

template <>
class Operator<CPUBackend> : public OperatorBase {
public:
inline explicit Operator(const OpSpec &spec) :
OperatorBase(spec)
{}
inline explicit Operator(const OpSpec &spec) : OperatorBase(spec) {}

inline ~Operator() noexcept(false) override
{}
inline ~Operator() noexcept(false) override {}

using OperatorBase::Run;
void Run(Workspace<Backend> *ws) override {
void Run(SampleWorkspace *ws) override {
CheckInputLayouts(ws, spec_);
SetupSharedSampleParams(ws);
for (int i = 0; i < input_sets_; ++i) {
if (std::is_same<Backend, GPUBackend>::value) {
// Before we start working on the next input set, we need
// to wait until the last one is finished. Otherwise for some ops
// we risk overwriting data used by the kernel called for previous
// image. Doing it for all ops is a compromise between performance
// (which should not be greatly affected) and robustness (guarding
// against this potential problem for newly added ops)
SyncHelper(i, ws);
}
RunImpl(ws, i);
}
}

/**
* @brief Legacy implementation of CPU operator using per-sample approach
*
* Usage of this API will be deprecated.
*/
virtual void RunImpl(SampleWorkspace *ws, int idx = 0) {}

/**
* @brief Implementation of the operator - to be implemented by derived ops.
*/
virtual void RunImpl(HostWorkspace *ws, int idx = 0) {
DALI_ENFORCE(false, "Not implemented yet");
}

/**
* @brief Shared param setup. Legacy implementation for per-sample approach
*
* Usage of this API will be deprecated.
*/
virtual void SetupSharedSampleParams(SampleWorkspace *ws) {}

/**
* @brief Shared param setup
*/
virtual void SetupSharedSampleParams(Workspace<Backend> *ws) {}
virtual void SetupSharedSampleParams(HostWorkspace *ws) {}
};

template <>
class Operator<GPUBackend> : public OperatorBase {
public:
inline explicit Operator(const OpSpec &spec) : OperatorBase(spec) {}

inline ~Operator() noexcept(false) override {}

using OperatorBase::Run;
void Run(DeviceWorkspace *ws) override {
CheckInputLayouts(ws, spec_);
SetupSharedSampleParams(ws);
for (int i = 0; i < input_sets_; ++i) {
// Before we start working on the next input set, we need
// to wait until the last one is finished. Otherwise for some ops
// we risk overwriting data used by the kernel called for previous
// image. Doing it for all ops is a compromise between performance
// (which should not be greatly affected) and robustness (guarding
// against this potential problem for newly added ops)
SyncHelper(i, ws);
RunImpl(ws, i);
}
}

/**
* @brief Implementation of the operator - to be
* implemented by derived ops.
*/
virtual void RunImpl(Workspace<Backend> *ws, int idx = 0) = 0;
virtual void RunImpl(DeviceWorkspace *ws, int idx = 0) = 0;

/**
* @brief Shared param setup
*/
virtual void SetupSharedSampleParams(DeviceWorkspace *ws) {}

private:
  // SFINAE for Run is not possible as we want it to be virtual
template <typename B = Backend>
typename std::enable_if<std::is_same<B, GPUBackend>::value>::type
SyncHelper(int i, Workspace<B> *ws) {
void SyncHelper(int i, DeviceWorkspace *ws) {
if (i != 0) {
CUDA_CALL(cudaStreamSynchronize(ws->stream()));
}
}

template <typename B = Backend>
typename std::enable_if<!std::is_same<B, GPUBackend>::value>::type
SyncHelper(int /*unused*/, Workspace<B> */*unused*/) {}
};

template<>
class Operator<MixedBackend> : public OperatorBase {
public:
inline explicit Operator(const OpSpec &spec) :
OperatorBase(spec)
{}
inline explicit Operator(const OpSpec &spec) : OperatorBase(spec) {}

inline ~Operator() noexcept(false) override
{}
inline ~Operator() noexcept(false) override {}

using OperatorBase::Run;
void Run(MixedWorkspace *ws) override = 0;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/paste/bbox_paste.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class BBoxPaste : public Operator<Backend> {
void RunImpl(Workspace<Backend> *ws, const int idx) override;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/paste/paste.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Paste : public Operator<Backend> {
Tensor<GPUBackend> input_ptrs_gpu_, output_ptrs_gpu_, in_out_dims_paste_yx_gpu_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/python_function/python_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class PythonFunctionImpl : public Operator<Backend> {
void RunImpl(Workspace<Backend> *ws, const int idx) override;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

py::object python_function;
};
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/resize/random_resized_crop.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class RandomResizedCrop : public Operator<Backend>
DISABLE_COPY_MOVE_ASSIGN(RandomResizedCrop);

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

protected:
void RunImpl(Workspace<Backend> * ws, const int idx) override;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/resize/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class Resize : public Operator<Backend>
}

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
bool save_attrs_;
int outputs_per_idx_;
};
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/sequence/element_extract.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ElementExtract : public Operator<Backend> {
void RunImpl(Workspace<Backend> *ws, const int idx) override;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

private:
std::vector<int> element_map_;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/support/random/coin_flip.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CoinFlip : public Operator<SupportBackend> {
DISABLE_COPY_MOVE_ASSIGN(CoinFlip);

USE_OPERATOR_MEMBERS();
using Operator<SupportBackend>::RunImpl;

protected:
void RunImpl(Workspace<SupportBackend> * ws, const int idx) override;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/support/random/uniform.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Uniform : public Operator<SupportBackend> {
DISABLE_COPY_MOVE_ASSIGN(Uniform);

USE_OPERATOR_MEMBERS();
using Operator<SupportBackend>::RunImpl;

protected:
void RunImpl(Workspace<SupportBackend> * ws, const int idx) override;
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/transpose/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Transpose : public Operator<Backend> {
Dims previous_iter_shape_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/operators/util/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Cast : public Operator<Backend> {
DALIDataType output_type_;

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;
};

} // namespace dali
Expand Down

0 comments on commit ae7c34e

Please sign in to comment.