forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
EP context for custom op (microsoft#16454)
Implement infrastructures to allow EP resources surfaced to custom ops. --------- Co-authored-by: Randy Shuai <[email protected]>
- Loading branch information
1 parent
7b9d1f1
commit 3dd2c1b
Showing
26 changed files
with
669 additions
and
247 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
#pragma once | ||
|
||
#define ORT_CUDA_CTX | ||
|
||
#include "cuda_resource.h" | ||
#include "core/providers/custom_op_context.h" | ||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
#include <cublas_v2.h> | ||
#include <cudnn.h> | ||
|
||
namespace Ort { | ||
|
||
namespace Custom { | ||
|
||
struct CudaContext : public CustomOpContext { | ||
cudaStream_t cuda_stream = {}; | ||
cudnnHandle_t cudnn_handle = {}; | ||
cublasHandle_t cublas_handle = {}; | ||
|
||
void Init(const OrtKernelContext& kernel_ctx) override { | ||
const auto& ort_api = Ort::GetApi(); | ||
void* resource = {}; | ||
OrtStatus* status = nullptr; | ||
|
||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource); | ||
if (status) { | ||
ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION); | ||
} | ||
cuda_stream = reinterpret_cast<cudaStream_t>(resource); | ||
|
||
resource = {}; | ||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource); | ||
if (status) { | ||
ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); | ||
} | ||
cudnn_handle = reinterpret_cast<cudnnHandle_t>(resource); | ||
|
||
resource = {}; | ||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource); | ||
if (status) { | ||
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); | ||
} | ||
cublas_handle = reinterpret_cast<cublasHandle_t>(resource); | ||
} | ||
}; | ||
|
||
} // namespace Custom | ||
} // namespace Ort |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/resource.h" | ||
|
||
#define ORT_CUDA_RESOUCE_VERSION 1 | ||
|
||
enum CudaResource : int { | ||
cuda_stream_t = cuda_resource_offset, | ||
cudnn_handle_t, | ||
cublas_handle_t | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <core/session/onnxruntime_cxx_api.h> | ||
|
||
// CustomOpContext defines an interface allowing a custom op to access ep-specific resources. | ||
struct CustomOpContext { | ||
CustomOpContext() = default; | ||
virtual ~CustomOpContext(){}; | ||
virtual void Init(const OrtKernelContext&){}; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
enum ResourceOffset { | ||
cpu_resource_offset = 0, | ||
cuda_resource_offset = 10000, | ||
dml_resource_offset = 20000, | ||
rocm_resource_offset = 30000, | ||
// offsets for other ort eps | ||
custom_ep_resource_offset = 10000000, | ||
// offsets for customized eps | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#define ORT_ROCM_CTX | ||
|
||
#include "rocm_resource.h" | ||
#include "core/providers/custom_op_context.h" | ||
#include <hip/hip_runtime.h> | ||
#include <miopen/miopen.h> | ||
#include <rocblas/rocblas.h> | ||
|
||
namespace Ort { | ||
|
||
namespace Custom { | ||
|
||
struct RocmContext : public CustomOpContext { | ||
hipStream_t hip_stream = {}; | ||
miopenHandle_t miopen_handle = {}; | ||
rocblas_handle rblas_handle = {}; | ||
|
||
void Init(const OrtKernelContext& kernel_ctx) override { | ||
const auto& ort_api = Ort::GetApi(); | ||
void* resource = {}; | ||
OrtStatus* status = nullptr; | ||
|
||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::hip_stream_t, &resource); | ||
if (status) { | ||
ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION); | ||
} | ||
hip_stream = reinterpret_cast<hipStream_t>(resource); | ||
|
||
resource = {}; | ||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::miopen_handle_t, &resource); | ||
if (status) { | ||
ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); | ||
} | ||
miopen_handle = reinterpret_cast<miopenHandle_t>(resource); | ||
|
||
resource = {}; | ||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::rocblas_handle_t, &resource); | ||
if (status) { | ||
ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); | ||
} | ||
rblas_handle = reinterpret_cast<rocblas_handle>(resource); | ||
} | ||
}; | ||
|
||
} // namespace Custom | ||
} // namespace Ort |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/resource.h" | ||
|
||
#define ORT_ROCM_RESOUCE_VERSION 1 | ||
|
||
enum RocmResource : int { | ||
hip_stream_t = rocm_resource_offset, | ||
miopen_handle_t, | ||
rocblas_handle_t | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.