Skip to content

Commit

Permalink
EP context for custom op (microsoft#16454)
Browse files Browse the repository at this point in the history
Implement the infrastructure that allows EP resources to be surfaced to custom ops.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Aug 16, 2023
1 parent 7b9d1f1 commit 3dd2c1b
Show file tree
Hide file tree
Showing 26 changed files with 669 additions and 247 deletions.
37 changes: 29 additions & 8 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1450,19 +1450,40 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()

if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")

set(custom_op_src_patterns
"${TEST_SRC_DIR}/testdata/custom_op_library/*.h"
"${TEST_SRC_DIR}/testdata/custom_op_library/*.cc"
"${TEST_SRC_DIR}/testdata/custom_op_library/cpu/cpu_ops.*"
)

set(custom_op_lib_include ${REPO_ROOT}/include)
set(custom_op_lib_option)
set(custom_op_lib_link ${GSL_TARGET})

if (onnxruntime_USE_CUDA)
onnxruntime_add_shared_library(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
list(APPEND custom_op_src_patterns
"${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu"
"${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*")
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
if (HAS_QSPECTRE)
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
endif()
else()
onnxruntime_add_shared_library(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
endif()

target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include)
target_link_libraries(custom_op_library PRIVATE ${GSL_TARGET})
if (onnxruntime_USE_ROCM)
  # ROCm build: add the HIP sources, include path and compile definitions
  # for the test custom-op library.
  list(APPEND custom_op_src_patterns
       "${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/rocm_ops.hip"
       "${TEST_SRC_DIR}/testdata/custom_op_library/rocm/rocm_ops.*")
  list(APPEND custom_op_lib_include ${onnxruntime_ROCM_HOME}/include)
  # Keep every define as its own list element. target_compile_options() hands each
  # element to the compiler as ONE argument, so a single space-separated string would
  # be passed as one malformed flag instead of two separate definitions.
  list(APPEND custom_op_lib_option "-D__HIP_PLATFORM_AMD__=1" "-D__HIP_PLATFORM_HCC__=1")
endif()

file(GLOB custom_op_src ${custom_op_src_patterns})
onnxruntime_add_shared_library(custom_op_library ${custom_op_src})
target_compile_options(custom_op_library PRIVATE ${custom_op_lib_option})
target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include ${custom_op_lib_include})
target_link_libraries(custom_op_library PRIVATE ${GSL_TARGET} ${custom_op_lib_link})

if(UNIX)
if (APPLE)
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/stream_handles.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class Stream {
}
}

// Returns an opaque pointer to a resource owned by this stream, selected by a
// resource version and a resource id. The base implementation ignores both
// arguments and always returns nullptr; derived streams override this to expose
// their own resources.
virtual void* GetResource(int /*version*/, int /*id*/) const {
return nullptr;
}

private:
StreamHandle handle_;
const OrtDevice& device_;
Expand Down
51 changes: 51 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#define ORT_CUDA_CTX

#include "cuda_resource.h"
#include "core/providers/custom_op_context.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cudnn.h>

namespace Ort {

namespace Custom {

struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cuda_stream = reinterpret_cast<cudaStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cudnn_handle = reinterpret_cast<cudnnHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
}
};

} // namespace Custom
} // namespace Ort
12 changes: 12 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Include guard: this header is pulled in by both cuda_context.h and
// cuda_stream_handle.cc, and the enum below must not be defined twice in one
// translation unit.
#pragma once

#include "core/providers/resource.h"

// Version passed alongside a CudaResource id when querying a resource.
// (The spelling "RESOUCE" is kept as-is because existing code refers to this name.)
#define ORT_CUDA_RESOUCE_VERSION 1

// Ids of the resources exposed by the CUDA EP; numbering starts at
// cuda_resource_offset and continues sequentially.
enum CudaResource : int {
  cuda_stream_t = cuda_resource_offset,
  cudnn_handle_t,
  cublas_handle_t
};
13 changes: 13 additions & 0 deletions include/onnxruntime/core/providers/custom_op_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <core/session/onnxruntime_cxx_api.h>

// Base interface through which a custom op reaches EP-specific resources.
// Concrete contexts derive from it and override Init to fill in their handles
// from the kernel context.
struct CustomOpContext {
  CustomOpContext() = default;
  virtual ~CustomOpContext() = default;

  // Default implementation does nothing.
  virtual void Init(const OrtKernelContext& /*kernel_ctx*/) {}
};
14 changes: 14 additions & 0 deletions include/onnxruntime/core/providers/resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

// Base values for per-EP resource ids. Each EP numbers its resources starting at
// its own offset so that ids from different EPs do not collide.
enum ResourceOffset {
  // Offsets for EPs that ship with ORT; further ORT EPs go below rocm.
  cpu_resource_offset = 0,
  cuda_resource_offset = 10000,
  dml_resource_offset = 20000,
  rocm_resource_offset = 30000,
  // Offsets for customized EPs start here.
  custom_ep_resource_offset = 10000000,
};
49 changes: 49 additions & 0 deletions include/onnxruntime/core/providers/rocm/rocm_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Include guard, matching cuda_context.h: without it a second inclusion would
// redefine RocmContext in the same translation unit.
#pragma once

#define ORT_ROCM_CTX

#include "rocm_resource.h"
#include "core/providers/custom_op_context.h"
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rocblas/rocblas.h>

namespace Ort {

namespace Custom {

// Bundles the ROCm EP resources (HIP stream, MIOpen handle, rocBLAS handle) fetched
// from a kernel context through KernelContext_GetResource.
struct RocmContext : public CustomOpContext {
  hipStream_t hip_stream = {};
  miopenHandle_t miopen_handle = {};
  rocblas_handle rblas_handle = {};

  // Populates all three handles from `kernel_ctx`, in the order stream, MIOpen, rocBLAS.
  // Raises ORT_RUNTIME_EXCEPTION through ORT_CXX_API_THROW as soon as one lookup
  // returns a non-null status; handles fetched before the failure stay assigned.
  void Init(const OrtKernelContext& kernel_ctx) override {
    const auto& ort_api = Ort::GetApi();

    // Fetches a single resource. A non-null OrtStatus returned by the C API must be
    // released by the caller, so it is released here before the error is raised;
    // otherwise every failed lookup would leak the status object.
    const auto fetch = [&ort_api, &kernel_ctx](RocmResource id, const char* failure_message) -> void* {
      void* resource = {};
      OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, id, &resource);
      if (status) {
        ort_api.ReleaseStatus(status);
        ORT_CXX_API_THROW(failure_message, OrtErrorCode::ORT_RUNTIME_EXCEPTION);
      }
      return resource;
    };

    hip_stream = reinterpret_cast<hipStream_t>(
        fetch(RocmResource::hip_stream_t, "failed to fetch hip stream"));
    miopen_handle = reinterpret_cast<miopenHandle_t>(
        fetch(RocmResource::miopen_handle_t, "failed to fetch miopen handle"));
    rblas_handle = reinterpret_cast<rocblas_handle>(
        fetch(RocmResource::rocblas_handle_t, "failed to fetch rocblas handle"));
  }
};

}  // namespace Custom
}  // namespace Ort
12 changes: 12 additions & 0 deletions include/onnxruntime/core/providers/rocm/rocm_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Include guard: this header is pulled in by both rocm_context.h and
// rocm_stream_handle.cc, and the enum below must not be defined twice in one
// translation unit.
#pragma once

#include "core/providers/resource.h"

// Version passed alongside a RocmResource id when querying a resource.
// (The spelling "RESOUCE" is kept as-is because existing code refers to this name.)
#define ORT_ROCM_RESOUCE_VERSION 1

// Ids of the resources exposed by the ROCm EP; numbering starts at
// rocm_resource_offset and continues sequentially.
enum RocmResource : int {
  hip_stream_t = rocm_resource_offset,
  miopen_handle_t,
  rocblas_handle_t
};
15 changes: 13 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4308,8 +4308,6 @@ struct OrtApi {
*/
void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input);

/// @}

/** \brief Create an allocator with specific type and register it with the ::OrtEnv
* This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator
* Enables sharing the allocator between multiple sessions that use the same env instance.
Expand Down Expand Up @@ -4398,6 +4396,19 @@ struct OrtApi {
* \since Version 1.16.
*/
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);

/**
* Get an EP resource.
* E.g. a cuda stream or a cublas handle
*
* \param context - Kernel context
* \param resouce_version - Version of the resource (e.g. ORT_CUDA_RESOUCE_VERSION)
* \param resource_id - Type of resource (e.g. an enumerator of CudaResource or RocmResource)
* \param resource - A pointer to returned resource; the value is an opaque pointer that the
*                   caller casts to the concrete handle type
*
* \since Version 1.16.
*/
ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource);
};

/*
Expand Down
54 changes: 54 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,38 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return std::tuple_cat(current, next);
}

// Overload selected when the next kernel argument type is `OrtKernelContext&`:
// binds the argument to the kernel context itself, then builds the remaining
// arguments with unchanged input/output indices.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  std::tuple<T> head{*context};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}

#ifdef ORT_CUDA_CTX
// Overload selected when the next kernel argument type is `const CudaContext&`:
// re-initializes a thread-local CudaContext from the kernel context on every call,
// binds the argument to it, then builds the remaining arguments with unchanged
// input/output indices.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  thread_local CudaContext per_thread_ctx;
  per_thread_ctx.Init(*context);
  std::tuple<T> head{per_thread_ctx};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}
#endif

#ifdef ORT_ROCM_CTX
// Overload selected when the next kernel argument type is `const RocmContext&`:
// re-initializes a thread-local RocmContext from the kernel context on every call,
// binds the argument to it, then builds the remaining arguments with unchanged
// input/output indices.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  thread_local RocmContext per_thread_ctx;
  per_thread_ctx.Init(*context);
  std::tuple<T> head{per_thread_ctx};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}
#endif

#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
Expand Down Expand Up @@ -437,6 +469,28 @@ struct OrtLiteCustomOp : public OrtCustomOp {
ParseArgs<Ts...>(input_types, output_types);
}

// Overload selected when the leading argument type is `OrtKernelContext&`.
// Nothing is appended to input_types or output_types for this argument; parsing
// simply continues with the remaining argument types.
// The `0 <= sizeof...(Ts)` term is always true (the count is unsigned), so selection
// depends only on the is_same check.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}

#ifdef ORT_CUDA_CTX
// Overload selected when the leading argument type is `const CudaContext&`.
// Nothing is appended to input_types or output_types for this argument; parsing
// simply continues with the remaining argument types.
// The `0 <= sizeof...(Ts)` term is always true (the count is unsigned), so selection
// depends only on the is_same check.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#ifdef ORT_ROCM_CTX
// Overload selected when the leading argument type is `const RocmContext&`.
// Nothing is appended to input_types or output_types for this argument; parsing
// simply continues with the remaining argument types.
// The `0 <= sizeof...(Ts)` term is always true (the count is unsigned), so selection
// depends only on the is_same check.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
Expand Down
21 changes: 20 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_resource.h"
#include "core/providers/cuda/cuda_stream_handle.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/common/spin_pause.h"
Expand Down Expand Up @@ -149,6 +149,25 @@ Status CudaStream::CleanUpOnRunEnd() {
return Status::OK();
}

// Looks up a resource of this CUDA stream by id and returns it as an opaque pointer.
// Enforces that `version` does not exceed ORT_CUDA_RESOUCE_VERSION.
// Returns nullptr when `id` is not one of the CudaResource enumerators handled below.
void* CudaStream::GetResource(int version, int id) const {
  ORT_ENFORCE(version <= ORT_CUDA_RESOUCE_VERSION, "resource version unsupported!");

  void* found = nullptr;
  switch (id) {
    case CudaResource::cuda_stream_t:
      found = reinterpret_cast<void*>(GetHandle());
      break;
    case CudaResource::cudnn_handle_t:
      found = reinterpret_cast<void*>(cudnn_handle_);
      break;
    case CudaResource::cublas_handle_t:
      found = reinterpret_cast<void*>(cublas_handle_);
      break;
    default:
      // Unknown id: fall through with nullptr.
      break;
  }
  return found;
}

// CPU Stream command handles
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification) {
static_cast<CudaNotification*>(&notification)->wait_on_device(stream);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ struct CudaStream : Stream {

cublasHandle_t cublas_handle_{};

void* GetResource(int version, int id) const override;

private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
Expand Down
21 changes: 20 additions & 1 deletion onnxruntime/core/providers/rocm/rocm_stream_handle.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "core/providers/rocm/rocm_stream_handle.h"
#include "core/providers/rocm/rocm_common.h"
// #include "core/common/spin_pause.h"
#include "core/providers/rocm/rocm_resource.h"

namespace onnxruntime {

Expand Down Expand Up @@ -129,7 +130,25 @@ Status RocmStream::CleanUpOnRunEnd() {
return Status::OK();
}

// CPU Stream command handles
// Looks up a resource of this ROCm stream by type and returns it as an opaque pointer.
// Enforces that `version` does not exceed ORT_ROCM_RESOUCE_VERSION.
// Returns nullptr when `type` is not one of the RocmResource enumerators handled below.
void* RocmStream::GetResource(int version, int type) const {
  ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!");

  void* found = nullptr;
  switch (type) {
    case RocmResource::hip_stream_t:
      found = reinterpret_cast<void*>(GetHandle());
      break;
    case RocmResource::miopen_handle_t:
      found = reinterpret_cast<void*>(miopen_handle_);
      break;
    case RocmResource::rocblas_handle_t:
      found = reinterpret_cast<void*>(rocblas_handle_);
      break;
    default:
      // Unknown type: fall through with nullptr.
      break;
  }
  return found;
}

// Treats `notification` as a RocmNotification and calls its wait_on_device
// with `stream`.
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification) {
  auto* rocm_notification = static_cast<RocmNotification*>(&notification);
  rocm_notification->wait_on_device(stream);
}
Expand Down
Loading

0 comments on commit 3dd2c1b

Please sign in to comment.