Skip to content

Commit

Permalink
Merge pull request NVIDIA#708 from senior-zero/fix-main/github/cdp_wr…
Browse files Browse the repository at this point in the history
…apper_ctk_11.5

Fix CDP test wrapper for CTK 11.5
  • Loading branch information
gevtushenko authored Jun 2, 2023
2 parents caffe6c + e74bf12 commit 9ebcf69
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions test/catch2_test_cdp_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,6 @@
#error Test file should contain %PARAM% TEST_CDP cdp 0:1
#endif

#define DECLARE_CDP_INVOCABLE(API, WRAPPED_API_NAME) \
namespace \
{ \
struct WRAPPED_API_NAME##_invocable_t \
{ \
template <class... Ts> \
CUB_RUNTIME_FUNCTION cudaError_t operator()(std::uint8_t *d_temp_storage, \
std::size_t &temp_storage_bytes, \
Ts... args) \
{ \
return API(d_temp_storage, temp_storage_bytes, args...); \
} \
}; \
}

#if TEST_CDP == 1
template <class ActionT, class... Args>
__global__ void device_side_api_launch_kernel(std::uint8_t *d_temp_storage,
Expand Down Expand Up @@ -119,6 +104,26 @@ void device_side_api_launch(ActionT action, Args... args)

#define cdp_launch device_side_api_launch

#define DECLARE_CDP_INVOCABLE(API, WRAPPED_API_NAME) \
struct WRAPPED_API_NAME##_device_invocable_t \
{ \
template <class... Ts> \
CUB_RUNTIME_FUNCTION cudaError_t operator()(std::uint8_t *d_temp_storage, \
std::size_t &temp_storage_bytes, \
Ts... args) \
{ \
return API(d_temp_storage, temp_storage_bytes, args...); \
} \
};

#define DECLARE_CDP_WRAPPER(API, WRAPPED_API_NAME) \
DECLARE_CDP_INVOCABLE(API, WRAPPED_API_NAME); \
template <class... As> \
static void WRAPPED_API_NAME(As... args) \
{ \
cdp_launch(WRAPPED_API_NAME##_device_invocable_t{}, args...); \
}

#else

template <class ActionT, class... Args>
Expand All @@ -142,15 +147,24 @@ void host_side_api_launch(ActionT action, Args... args)

#define cdp_launch host_side_api_launch

#endif
#define DECLARE_CDP_INVOCABLE(API, WRAPPED_API_NAME) \
struct WRAPPED_API_NAME##_host_invocable_t \
{ \
template <class... Ts> \
CUB_RUNTIME_FUNCTION cudaError_t operator()(std::uint8_t *d_temp_storage, \
std::size_t &temp_storage_bytes, \
Ts... args) \
{ \
return API(d_temp_storage, temp_storage_bytes, args...); \
} \
};

#define DECLARE_CDP_WRAPPER(API, WRAPPED_API_NAME) \
DECLARE_CDP_INVOCABLE(API, WRAPPED_API_NAME); \
namespace \
{ \
template <class... As> \
void WRAPPED_API_NAME(As... args) \
static void WRAPPED_API_NAME(As... args) \
{ \
cdp_launch(WRAPPED_API_NAME##_invocable_t{}, args...); \
} \
}
cdp_launch(WRAPPED_API_NAME##_host_invocable_t{}, args...); \
}

#endif

0 comments on commit 9ebcf69

Please sign in to comment.