diff --git a/src/runtime_src/hip/api/hip_context.cpp b/src/runtime_src/hip/api/hip_context.cpp index 2a7781fc050..d2bed7710e0 100644 --- a/src/runtime_src/hip/api/hip_context.cpp +++ b/src/runtime_src/hip/api/hip_context.cpp @@ -1,9 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (C) 2023-2024 Advanced Micro Device, Inc. All rights reserved. -#include "hip/config.h" -#include "hip/hip_runtime_api.h" - #include "hip/core/common.h" #include "hip/core/context.h" #include "hip/core/device.h" @@ -23,27 +20,24 @@ static context_handle hip_ctx_create(unsigned int flags, hipDevice_t device) { auto hip_dev = device_cache.get(static_cast(device)); - if (!hip_dev) - throw xrt_core::system_error(hipErrorInvalidValue, "device requested is not available"); + throw_invalid_value_if(!hip_dev, "device requested is not available"); hip_dev->set_flags(flags); auto hip_ctx = std::make_shared(hip_dev); tls_objs.ctx_stack.push(hip_ctx); // make it current - context_cache.add(hip_ctx.get(), std::move(hip_ctx)); - return hip_ctx.get(); + tls_objs.dev_hdl = device; + // insert handle in ctx map and return handle + return insert_in_map(context_cache, std::move(hip_ctx)); } static void hip_ctx_destroy(hipCtx_t ctx) { auto handle = reinterpret_cast(ctx); - if (!handle) { - throw xrt_core::system_error(hipErrorInvalidValue, "device requested is not available"); - } + throw_invalid_value_if(!handle, "device requested is not available"); auto hip_ctx = context_cache.get(handle); - if (!hip_ctx) - throw xrt_core::system_error(hipErrorInvalidValue, "context handle not found"); + throw_invalid_value_if(!hip_ctx, "context handle not found"); // Need to remove the ctx of calling thread if its the top one if (!tls_objs.ctx_stack.empty() && tls_objs.ctx_stack.top().lock() == hip_ctx) { @@ -57,8 +51,7 @@ static device_handle hip_ctx_get_device() { auto ctx = get_current_context(); - if (!ctx) - throw xrt_core::system_error(hipErrorInvalidValue, "Error retrieving context"); + throw_invalid_value_if(!ctx, "Error retrieving context"); return ctx->get_dev_id(); } @@ -74,8 +67,10 @@ hip_ctx_set_current(hipCtx_t ctx) auto handle = reinterpret_cast(ctx); auto hip_ctx = context_cache.get(handle); - if (hip_ctx) + if (hip_ctx) { tls_objs.ctx_stack.push(hip_ctx); + tls_objs.dev_hdl = hip_ctx->get_dev_id(); + } } // remove primary ctx as active @@ -85,8 +80,7 @@ hip_device_primary_ctx_release(hipDevice_t dev) { auto dev_hdl = static_cast(dev); auto hip_dev = device_cache.get(dev_hdl); - if (!hip_dev) - throw xrt_core::system_error(hipErrorInvalidDevice, "Invalid device"); + throw_invalid_device_if(!hip_dev, "Invalid device"); auto ctx = hip_dev->get_pri_ctx(); if (!ctx) @@ -97,21 +91,19 @@ hip_device_primary_ctx_release(hipDevice_t dev) auto ctx_hdl = reinterpret_cast(std::hash{}(std::this_thread::get_id())); context_cache.remove(ctx_hdl); - if (tls_objs.pri_ctx_info.active && tls_objs.pri_ctx_info.dev_hdl == dev_hdl) { + if (tls_objs.pri_ctx_info.active && tls_objs.dev_hdl == dev_hdl) { tls_objs.pri_ctx_info.active = false; - tls_objs.pri_ctx_info.dev_hdl = UINT32_MAX; tls_objs.pri_ctx_info.ctx_hdl = nullptr; } } // create primary context on given device if not already present // else increment reference count -static context_handle -hip_device_primary_ctx_retain(hipDevice_t dev) +context_handle +hip_device_primary_ctx_retain(device_handle dev) { auto hip_dev = device_cache.get(dev); - if (!hip_dev) - throw xrt_core::system_error(hipErrorInvalidDevice, "Invalid device"); + throw_invalid_device_if(!hip_dev, "Invalid device"); auto hip_ctx = hip_dev->get_pri_ctx(); // create primary context @@ -123,13 +115,12 @@ hip_device_primary_ctx_retain(hipDevice_t dev) // unqiue handle, using thread id here as primary context is unique per thread auto ctx_hdl = reinterpret_cast(std::hash{}(std::this_thread::get_id())); - auto handle = hip_ctx.get(); context_cache.add(ctx_hdl, std::move(hip_ctx)); tls_objs.pri_ctx_info.active = true; tls_objs.pri_ctx_info.ctx_hdl = ctx_hdl; - tls_objs.pri_ctx_info.dev_hdl = dev; - return handle; + tls_objs.dev_hdl = dev; + return ctx_hdl; } } // xrt::core::hip @@ -139,8 +130,7 @@ hipError_t hipCtxCreate(hipCtx_t* ctx, unsigned int flags, hipDevice_t device) { try { - if (!ctx) - throw xrt_core::system_error(hipErrorInvalidValue, "ctx passed is nullptr"); + throw_invalid_value_if(!ctx, "ctx passed is nullptr"); auto handle = xrt::core::hip::hip_ctx_create(flags, device); *ctx = reinterpret_cast(handle); @@ -177,8 +167,7 @@ hipError_t hipCtxGetDevice(hipDevice_t* device) { try { - if (!device) - throw xrt_core::system_error(hipErrorInvalidValue, "device passed is nullptr"); + throw_invalid_value_if(!device, "device passed is nullptr"); *device = xrt::core::hip::hip_ctx_get_device(); return hipSuccess; @@ -214,8 +203,7 @@ hipError_t hipDevicePrimaryCtxRetain(hipCtx_t* pctx, hipDevice_t dev) { try { - if (!pctx) - throw xrt_core::system_error(hipErrorInvalidValue, "nullptr passed"); + throw_invalid_value_if(!pctx, "nullptr passed"); auto handle = xrt::core::hip::hip_device_primary_ctx_retain(dev); *pctx = reinterpret_cast(handle); @@ -247,3 +235,4 @@ hipDevicePrimaryCtxRelease(hipDevice_t dev) } return hipErrorUnknown; } + diff --git a/src/runtime_src/hip/api/hip_device.cpp b/src/runtime_src/hip/api/hip_device.cpp index 4fffc84cac9..e133d0f546d 100644 --- a/src/runtime_src/hip/api/hip_device.cpp +++ b/src/runtime_src/hip/api/hip_device.cpp @@ -1,12 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (C) 2023-2024 Advanced Micro Device, Inc. All rights reserved. -#include "core/common/error.h" #include "core/include/experimental/xrt_system.h" -#include "hip/config.h" -#include "hip/hip_runtime_api.h" - +#include "hip/core/common.h" #include "hip/core/device.h" #include @@ -50,16 +47,18 @@ device_init() auto dev = std::make_shared(i); device_cache.add(i, std::move(dev)); } + // make first device as default device + if (dev_count > 0) + tls_objs.dev_hdl = static_cast(0); } static void hip_init(unsigned int flags) { // Flags should be zero as per Hip doc - if (flags != 0) - throw xrt_core::system_error(hipErrorInvalidValue, "non zero flags passed to hipinit"); + throw_invalid_value_if(flags != 0, "non zero flags passed to hipinit"); - // call device_init function, device enumeration might not have happened + // call device_init function, device enumeration might not have happened // at library load because of some exception // std::once_flag ensures init is called only once std::call_once(device_init_flag, xrt::core::hip::device_init); @@ -71,19 +70,23 @@ hip_get_device_count() // Get device count auto count = xrt::core::hip::device_cache.size(); - if (count < 1) - throw xrt_core::system_error(hipErrorNoDevice, "No valid device available"); + throw_if(count < 1, hipErrorNoDevice, "No valid device available"); return count; } +inline bool +check(int dev_id) +{ + return (dev_id < 0 || device_cache.count(static_cast(dev_id)) == 0); +} + // Returns a handle to compute device // Throws on error static int hip_device_get(int ordinal) { - if (ordinal < 0 || device_cache.count(static_cast(ordinal)) == 0) - throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available"); + throw_invalid_device_if(check(ordinal), "device requested is not available"); return ordinal; } @@ -91,8 +94,7 @@ hip_device_get(int ordinal) static std::string hip_device_get_name(hipDevice_t device) { - if (device < 0 || xrt::core::hip::device_cache.count(static_cast(device)) == 0) - throw xrt_core::system_error(hipErrorInvalidDevice, " - device requested is not available"); + throw_invalid_device_if(check(device), "device requested is not available"); throw std::runtime_error("Not implemented"); } @@ -100,8 +102,7 @@ hip_device_get_name(hipDevice_t device) static hipDeviceProp_t hip_get_device_properties(hipDevice_t device) { - if (device < 0 || xrt::core::hip::device_cache.count(static_cast(device)) == 0) - throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available"); + throw_invalid_device_if(check(device), "device requested is not available"); throw std::runtime_error("Not implemented"); } @@ -109,8 +110,7 @@ hip_get_device_properties(hipDevice_t device) static hipUUID hip_device_get_uuid(hipDevice_t device) { - if (device < 0 || xrt::core::hip::device_cache.count(static_cast(device)) == 0) - throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available"); + throw_invalid_device_if(check(device), "device requested is not available"); throw std::runtime_error("Not implemented"); } @@ -118,8 +118,7 @@ hip_device_get_uuid(hipDevice_t device) static int hip_device_get_attribute(hipDeviceAttribute_t attr, int device) { - if (device < 0 || xrt::core::hip::device_cache.count(static_cast(device)) == 0) - throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available"); + throw_invalid_device_if(check(device), "device requested is not available"); throw std::runtime_error("Not implemented"); } @@ -148,8 +147,7 @@ hipError_t hipGetDeviceCount(int* count) { try { - if (!count) - throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr"); + throw_invalid_value_if(!count, "arg passed is nullptr"); *count = xrt::core::hip::hip_get_device_count(); return hipSuccess; @@ -168,8 +166,7 @@ hipError_t hipDeviceGet(hipDevice_t* device, int ordinal) { try { - if (!device) - throw xrt_core::system_error(hipErrorInvalidValue, "device is nullptr"); + throw_invalid_value_if(!device, "device is nullptr"); *device = xrt::core::hip::hip_device_get(ordinal); return hipSuccess; @@ -188,8 +185,7 @@ hipError_t hipDeviceGetName(char* name, int len, hipDevice_t device) { try { - if (!name || len <= 0) - throw xrt_core::system_error(hipErrorInvalidValue, "invalid arg"); + throw_invalid_value_if((!name || len <= 0), "invalid arg"); auto name_str = xrt::core::hip::hip_device_get_name(device); // Only copy partial name if size of `dest` is smaller than size of `src` including @@ -213,8 +209,7 @@ hipError_t hipGetDeviceProperties(hipDeviceProp_t* props, hipDevice_t device) { try { - if (!props) - throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr"); + throw_invalid_value_if(!props, "arg passed is nullptr"); *props = xrt::core::hip::hip_get_device_properties(device); return hipSuccess; @@ -233,8 +228,7 @@ hipError_t hipDeviceGetUuid(hipUUID* uuid, hipDevice_t device) { try { - if (!uuid) - throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr"); + throw_invalid_value_if(!uuid, "arg passed is nullptr"); *uuid = xrt::core::hip::hip_device_get_uuid(device); return hipSuccess; @@ -253,8 +247,7 @@ hipError_t hipDeviceGetAttribute(int* pi, hipDeviceAttribute_t attr, int device) { try { - if (!pi) - throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr"); + throw_invalid_value_if(!pi, "arg passed is nullptr"); *pi = xrt::core::hip::hip_device_get_attribute(attr, device); return hipSuccess; diff --git a/src/runtime_src/hip/api/hip_module.cpp b/src/runtime_src/hip/api/hip_module.cpp index ec9b2e6b555..b4df50a4781 100644 --- a/src/runtime_src/hip/api/hip_module.cpp +++ b/src/runtime_src/hip/api/hip_module.cpp @@ -1,64 +1,73 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (C) 2023-2024 Advanced Micro Device, Inc. All rights reserved. -#include "core/common/error.h" - -#include "hip/config.h" -#include "hip/hip_runtime_api.h" - +#include "hip/core/common.h" #include "hip/core/module.h" namespace xrt::core::hip { static void hip_module_launch_kernel(hipFunction_t f, uint32_t gridDimX, uint32_t gridDimY, - uint32_t gridDimZ, uint32_t blockDimX, uint32_t blockDimY, - uint32_t blockDimZ, uint32_t sharedMemBytes, hipStream_t hStream, - void** kernelParams, void** extra) + uint32_t gridDimZ, uint32_t blockDimX, uint32_t blockDimY, + uint32_t blockDimZ, uint32_t sharedMemBytes, hipStream_t hStream, + void** kernelParams, void** extra) { - if (!f) - throw xrt_core::system_error(hipErrorInvalidResourceHandle, "function is nullptr"); + throw_invalid_resource_if(!f, "function is nullptr"); + + auto func_hdl = reinterpret_cast(f); + auto hip_mod = module_cache.get(static_cast(func_hdl)->get_module()); + throw_invalid_resource_if(!hip_mod, "module associated with function is unloaded"); + + auto hip_func = hip_mod->get_function(func_hdl); + throw_invalid_resource_if(!hip_func, "invalid function passed"); throw std::runtime_error("Not implemented"); } -static hipFunction_t +static function_handle hip_module_get_function(hipModule_t hmod, const char* name) { - if (!name || strlen(name) == 0) - throw xrt_core::system_error(hipErrorInvalidValue, "name is invalid"); + throw_invalid_value_if((!name || strlen(name) == 0), "name is invalid"); - if (!hmod) - throw xrt_core::system_error(hipErrorInvalidResourceHandle, "module is nullptr"); + throw_invalid_resource_if(!hmod, "module is nullptr"); - throw std::runtime_error("Not implemented"); + auto mod_hdl = reinterpret_cast(hmod); + auto hip_mod = module_cache.get(mod_hdl); + throw_invalid_resource_if(!hip_mod, "module not available"); + + // create function obj and store in map maintained by module + return hip_mod->add_function(std::make_shared(mod_hdl, std::string(name))); } -static void -hip_module_load_data_ex(hipModule_t* module, const void* image, unsigned int numOptions, - hipJitOption* options, void** optionsValues) +static module_handle +create_module(const void* image) { - if (!module) - throw xrt_core::system_error(hipErrorInvalidResourceHandle, "module is nullptr"); - - throw std::runtime_error("Not implemented"); + auto ctx = get_current_context(); + // create module and store it in module map + return insert_in_map(module_cache, std::make_shared(ctx, const_cast(image))); } -static void -hip_module_load_data(hipModule_t* module, const void* image) +static module_handle +hip_module_load_data_ex(const void* image, unsigned int /*numOptions*/, + hipJitOption* /*options*/, void** /*optionsValues*/) { - if (!module) - throw xrt_core::system_error(hipErrorInvalidResourceHandle, "module is nullptr"); + // Jit options are ignored for now + return create_module(image); +} - throw std::runtime_error("Not implemented"); +// image is mapped address of program to be loaded +static module_handle +hip_module_load_data(const void* image) +{ + return create_module(image); } static void hip_module_unload(hipModule_t hmod) { - if (!hmod) - throw xrt_core::system_error(hipErrorInvalidResourceHandle, "module is nullptr"); + throw_invalid_resource_if(!hmod, "module is nullptr"); - throw std::runtime_error("Not implemented"); + auto handle = reinterpret_cast(hmod); + module_cache.remove(handle); } static void @@ -78,7 +87,7 @@ hipModuleLaunchKernel(hipFunction_t f, uint32_t gridDimX, uint32_t gridDimY, { try { xrt::core::hip::hip_module_launch_kernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, - blockDimZ, sharedMemBytes, hStream, kernelParams, extra); + blockDimZ, sharedMemBytes, hStream, kernelParams, extra); return hipSuccess; } catch (const xrt_core::system_error& ex) { @@ -95,10 +104,10 @@ hipError_t hipModuleGetFunction(hipFunction_t* hfunc, hipModule_t hmod, const char* name) { try { - if (!hfunc) - throw xrt_core::system_error(hipErrorInvalidHandle, "function passed is nullptr"); + throw_invalid_handle_if(!hfunc, "function passed is nullptr"); - *hfunc = xrt::core::hip::hip_module_get_function(hmod, name); + auto handle = xrt::core::hip::hip_module_get_function(hmod, name); + *hfunc = reinterpret_cast(handle); return hipSuccess; } catch (const xrt_core::system_error& ex) { @@ -116,7 +125,11 @@ hipModuleLoadDataEx(hipModule_t* module, const void* image, unsigned int numOpti hipJitOption* options, void** optionsValues) { try { - xrt::core::hip::hip_module_load_data_ex(module, image, numOptions, options, optionsValues); + throw_invalid_resource_if(!module, "module is nullptr"); + + auto handle = xrt::core::hip:: + hip_module_load_data_ex(image, numOptions, options, optionsValues); + *module = reinterpret_cast(handle); return hipSuccess; } catch (const xrt_core::system_error& ex) { @@ -133,7 +146,10 @@ hipError_t hipModuleLoadData(hipModule_t* module, const void* image) { try { - xrt::core::hip::hip_module_load_data(module, image); + throw_invalid_resource_if(!module, "module is nullptr"); + + auto handle = xrt::core::hip::hip_module_load_data(image); + *module = reinterpret_cast(handle); return hipSuccess; } catch (const xrt_core::system_error& ex) { @@ -179,3 +195,4 @@ hipFuncSetAttribute(const void* func, hipFuncAttribute attr, int value) } return hipErrorUnknown; } + diff --git a/src/runtime_src/hip/api/hip_stream.cpp b/src/runtime_src/hip/api/hip_stream.cpp index 05e611c2cd4..16f377aab3f 100644 --- a/src/runtime_src/hip/api/hip_stream.cpp +++ b/src/runtime_src/hip/api/hip_stream.cpp @@ -1,11 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (C) 2023-2024 Advanced Micro Device, Inc. All rights reserved. -#include "core/common/error.h" - -#include "hip/config.h" -#include "hip/hip_runtime_api.h" - +#include "hip/core/common.h" #include "hip/core/stream.h" namespace xrt::core::hip { @@ -18,8 +14,7 @@ hip_stream_create_with_flags(unsigned int flags) static void hip_stream_destroy(hipStream_t stream) { - if (!stream) - throw xrt_core::system_error(hipErrorInvalidHandle, "stream is nullptr"); + throw_invalid_handle_if(!stream, "stream is nullptr"); throw std::runtime_error("Not implemented"); } @@ -33,8 +28,7 @@ hip_stream_synchronize(hipStream_t stream) static void hip_stream_wait_event(hipStream_t stream, hipEvent_t event, unsigned int flags) { - if (!event) - throw xrt_core::system_error(hipErrorInvalidHandle, "event is nullptr"); + throw_invalid_handle_if(!event, "event is nullptr"); throw std::runtime_error("Not implemented"); } @@ -46,8 +40,7 @@ hipError_t hipStreamCreateWithFlags(hipStream_t* stream, unsigned int flags) { try { - if (!stream) - throw xrt_core::system_error(hipErrorInvalidValue, "stream passed is nullptr"); + throw_invalid_value_if(!stream, "stream passed is nullptr"); *stream = xrt::core::hip::hip_stream_create_with_flags(flags); return hipSuccess; @@ -112,3 +105,4 @@ hipStreamWaitEvent(hipStream_t stream, hipEvent_t event, unsigned int flags) } return hipErrorUnknown; } + diff --git a/src/runtime_src/hip/core/CMakeLists.txt b/src/runtime_src/hip/core/CMakeLists.txt index 9d83304514f..7c8edd11427 100644 --- a/src/runtime_src/hip/core/CMakeLists.txt +++ b/src/runtime_src/hip/core/CMakeLists.txt @@ -4,6 +4,8 @@ add_library(hip_core_library_objects OBJECT context.cpp device.cpp memory.cpp + module.cpp + stream.cpp ) target_include_directories(hip_core_library_objects diff --git a/src/runtime_src/hip/core/common.h b/src/runtime_src/hip/core/common.h index 83e59876084..b12ff6c669f 100644 --- a/src/runtime_src/hip/core/common.h +++ b/src/runtime_src/hip/core/common.h @@ -3,6 +3,11 @@ #ifndef xrthip_common_h #define xrthip_common_h +#include "core/common/error.h" + +#include "hip/config.h" +#include "hip/hip_runtime_api.h" + #include "context.h" #include "device.h" @@ -13,17 +18,63 @@ namespace xrt::core::hip { struct ctx_info { context_handle ctx_hdl{nullptr}; - device_handle dev_hdl{std::numeric_limits::max()}; bool active{false}; }; // thread local hip objects struct hip_tls_objs { + device_handle dev_hdl{std::numeric_limits::max()}; std::stack> ctx_stack; ctx_info pri_ctx_info; }; extern thread_local hip_tls_objs tls_objs; + +// generic function for adding shared_ptr to handle_map +// {key , value} -> {shared_ptr.get(), shared_ptr} +// returns void* (handle returned to application) +template +inline void* +insert_in_map(map& m, value&& v) +{ + auto handle = v.get(); + m.add(handle, std::move(v)); + return handle; +} } // xrt::core::hip +namespace { +// common functions for throwing hip errors +inline void +throw_if(bool check, hipError_t err, const std::string& err_msg) +{ + if (check) + throw xrt_core::system_error(err, err_msg); +} + +inline void +throw_invalid_value_if(bool check, const std::string& err_msg) +{ + throw_if(check, hipErrorInvalidValue, err_msg); +} + +inline void +throw_invalid_handle_if(bool check, const std::string& err_msg) +{ + throw_if(check, hipErrorInvalidHandle, err_msg); +} + +inline void +throw_invalid_device_if(bool check, const std::string& err_msg) +{ + throw_if(check, hipErrorInvalidDevice, err_msg); +} + +inline void +throw_invalid_resource_if(bool check, const std::string& err_msg) +{ + throw_if(check, hipErrorInvalidResourceHandle, err_msg); +} +} #endif + diff --git a/src/runtime_src/hip/core/context.cpp b/src/runtime_src/hip/core/context.cpp index 44e9b2e79ad..875acb8f6d7 100644 --- a/src/runtime_src/hip/core/context.cpp +++ b/src/runtime_src/hip/core/context.cpp @@ -20,7 +20,8 @@ thread_local hip_tls_objs tls_objs; // returns current context // if primary context is active it is current // else returns top of ctx stack -// this function can return null if no context is active +// this function returns primary ctx on active device if +// no context is active std::shared_ptr get_current_context() { @@ -37,6 +38,12 @@ get_current_context() tls_objs.ctx_stack.pop(); } - return ctx; -} + if (ctx) + return ctx; + + // if no active ctx, create primary ctx on active device + auto ctx_hdl = hip_device_primary_ctx_retain(tls_objs.dev_hdl); + return context_cache.get(ctx_hdl); } +} // xrt::core::hip + diff --git a/src/runtime_src/hip/core/context.h b/src/runtime_src/hip/core/context.h index 6ae85c9f8c1..e7145fcd514 100644 --- a/src/runtime_src/hip/core/context.h +++ b/src/runtime_src/hip/core/context.h @@ -15,6 +15,8 @@ class context std::shared_ptr m_device; public: + context() = default; + context(std::shared_ptr device); uint32_t @@ -22,12 +24,22 @@ class context { return m_device->get_device_id(); } + + const xrt::device& + get_xrt_device() const + { + return m_device->get_xrt_device(); + } }; std::shared_ptr get_current_context(); +context_handle +hip_device_primary_ctx_retain(device_handle dev); + extern xrt_core::handle_map> context_cache; } // xrt::core::hip #endif + diff --git a/src/runtime_src/hip/core/device.cpp b/src/runtime_src/hip/core/device.cpp index e79d26f852f..c47b14d4238 100644 --- a/src/runtime_src/hip/core/device.cpp +++ b/src/runtime_src/hip/core/device.cpp @@ -13,3 +13,4 @@ device(uint32_t device_id) , m_xrt_device{device_id} {} } + diff --git a/src/runtime_src/hip/core/device.h b/src/runtime_src/hip/core/device.h index 6537f8a4c93..e0f630e3de0 100644 --- a/src/runtime_src/hip/core/device.h +++ b/src/runtime_src/hip/core/device.h @@ -25,6 +25,8 @@ class device std::weak_ptr pri_ctx; public: + device() = default; + explicit device(uint32_t device_id); @@ -64,3 +66,4 @@ extern xrt_core::handle_map> device_cache } // xrt::core::hip #endif + diff --git a/src/runtime_src/hip/core/module.cpp b/src/runtime_src/hip/core/module.cpp index 24cd7041603..c0e4db9495d 100644 --- a/src/runtime_src/hip/core/module.cpp +++ b/src/runtime_src/hip/core/module.cpp @@ -1,14 +1,60 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (C) 2024 Advanced Micro Device, Inc. All rights reserved. +#include "hip/config.h" +#include "hip/hip_runtime_api.h" + #include "module.h" namespace xrt::core::hip { +void module:: -module(std::shared_ptr ctx) +create_hw_context() +{ + auto xrt_dev = m_ctx->get_xrt_device(); + auto uuid = xrt_dev.register_xclbin(m_xclbin); + m_hw_ctx = xrt::hw_context{xrt_dev, uuid}; +} + +module:: +module(std::shared_ptr ctx, const std::string& file_name) : m_ctx{std::move(ctx)} -{} +{ + m_xclbin = xrt::xclbin{file_name}; + create_hw_context(); +} + +module:: +module(std::shared_ptr ctx, void* image) + : m_ctx{std::move(ctx)} +{ + // we trust pointer sent by application and treat + // it as xclbin data. Application can crash/seg fault + // when improper data is passed + m_xclbin = xrt::xclbin{static_cast(image)}; + create_hw_context(); +} -// Global map of streams +xrt::kernel +module:: +create_kernel(std::string& name) +{ + return xrt::kernel{m_hw_ctx, name}; +} + +function:: +function(module_handle mod_hdl, std::string&& name) + : m_module{static_cast(mod_hdl)} + , m_func_name{std::move(name)} +{ + if (!module_cache.count(mod_hdl)) + throw xrt_core::system_error(hipErrorInvalidResourceHandle, "module not available"); + + + m_kernel = m_module->create_kernel(m_func_name); +} + +// Global map of modules xrt_core::handle_map> module_cache; } + diff --git a/src/runtime_src/hip/core/module.h b/src/runtime_src/hip/core/module.h index 57c4b9085c4..8e0423fdaab 100644 --- a/src/runtime_src/hip/core/module.h +++ b/src/runtime_src/hip/core/module.h @@ -3,22 +3,72 @@ #ifndef xrthip_module_h #define xrthip_module_h +#include "common.h" #include "context.h" +#include "xrt/xrt_hw_context.h" +#include "xrt/xrt_kernel.h" namespace xrt::core::hip { // module_handle - opaque module handle using module_handle = void*; +// function_handle - opaque function handle +using function_handle = void*; + +// forward declaration +class function; + class module { std::shared_ptr m_ctx; + xrt::xclbin m_xclbin; + xrt::hw_context m_hw_ctx; + xrt_core::handle_map> function_cache; public: - module(std::shared_ptr ctx); + module() = default; + module(std::shared_ptr ctx, const std::string& file_name); + module(std::shared_ptr ctx, void* image); + + void + create_hw_context(); + + xrt::kernel + create_kernel(std::string& name); + + function_handle + add_function(std::shared_ptr&& f) + { + return insert_in_map(function_cache, f); + } + + std::shared_ptr + get_function(function_handle handle) + { + return function_cache.get(handle); + } +}; + +class function +{ + module* m_module; + std::string m_func_name; + xrt::kernel m_kernel; + +public: + function() = default; + function(module_handle mod_hdl, std::string&& name); + + module* + get_module() + { + return m_module; + } }; extern xrt_core::handle_map> module_cache; } // xrt::core::hip #endif + diff --git a/src/runtime_src/hip/core/stream.cpp b/src/runtime_src/hip/core/stream.cpp index d10e1171391..787eba24e10 100644 --- a/src/runtime_src/hip/core/stream.cpp +++ b/src/runtime_src/hip/core/stream.cpp @@ -12,3 +12,4 @@ stream(std::shared_ptr ctx) // Global map of streams xrt_core::handle_map> stream_cache; } + diff --git a/src/runtime_src/hip/core/stream.h b/src/runtime_src/hip/core/stream.h index 1991c9be138..8877018c41f 100644 --- a/src/runtime_src/hip/core/stream.h +++ b/src/runtime_src/hip/core/stream.h @@ -15,6 +15,7 @@ class stream std::shared_ptr m_ctx; public: + stream() = default; stream(std::shared_ptr ctx); }; @@ -22,3 +23,4 @@ extern xrt_core::handle_map> stream_cache } // xrt::core::hip #endif +