Skip to content

Commit

Permalink
Initial implementation of Hip module and function apis (Xilinx#7964)
Browse files Browse the repository at this point in the history
* Initial implementation of module and function hip objects

Signed-off-by: rbramand <[email protected]>

* Add function for throwing error for reusability

Signed-off-by: rbramand <[email protected]>

* Address comments on PR

Signed-off-by: rbramand <[email protected]>

* Remove boiler plate code

Signed-off-by: rbramand <[email protected]>

---------

Signed-off-by: rbramand <[email protected]>
Co-authored-by: rbramand <[email protected]>
  • Loading branch information
rbramand-xilinx and rbramand authored Feb 28, 2024
1 parent 253d71d commit 6e6f7d5
Show file tree
Hide file tree
Showing 14 changed files with 287 additions and 119 deletions.
53 changes: 21 additions & 32 deletions src/runtime_src/hip/api/hip_context.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023-2024 Advanced Micro Device, Inc. All rights reserved.

#include "hip/config.h"
#include "hip/hip_runtime_api.h"

#include "hip/core/common.h"
#include "hip/core/context.h"
#include "hip/core/device.h"
Expand All @@ -23,27 +20,24 @@ static context_handle
hip_ctx_create(unsigned int flags, hipDevice_t device)
{
auto hip_dev = device_cache.get(static_cast<device_handle>(device));
if (!hip_dev)
throw xrt_core::system_error(hipErrorInvalidValue, "device requested is not available");
throw_invalid_value_if(!hip_dev, "device requested is not available");

hip_dev->set_flags(flags);
auto hip_ctx = std::make_shared<context>(hip_dev);
tls_objs.ctx_stack.push(hip_ctx); // make it current
context_cache.add(hip_ctx.get(), std::move(hip_ctx));
return hip_ctx.get();
tls_objs.dev_hdl = device;
// insert handle in ctx map and return handle
return insert_in_map(context_cache, std::move(hip_ctx));
}

static void
hip_ctx_destroy(hipCtx_t ctx)
{
auto handle = reinterpret_cast<context_handle>(ctx);
if (!handle) {
throw xrt_core::system_error(hipErrorInvalidValue, "device requested is not available");
}
throw_invalid_value_if(!handle, "device requested is not available");

auto hip_ctx = context_cache.get(handle);
if (!hip_ctx)
throw xrt_core::system_error(hipErrorInvalidValue, "context handle not found");
throw_invalid_value_if(!hip_ctx, "context handle not found");

// Need to remove the ctx of calling thread if its the top one
if (!tls_objs.ctx_stack.empty() && tls_objs.ctx_stack.top().lock() == hip_ctx) {
Expand All @@ -57,8 +51,7 @@ static device_handle
hip_ctx_get_device()
{
auto ctx = get_current_context();
if (!ctx)
throw xrt_core::system_error(hipErrorInvalidValue, "Error retrieving context");
throw_invalid_value_if(!ctx, "Error retrieving context");

return ctx->get_dev_id();
}
Expand All @@ -74,8 +67,10 @@ hip_ctx_set_current(hipCtx_t ctx)

auto handle = reinterpret_cast<context_handle>(ctx);
auto hip_ctx = context_cache.get(handle);
if (hip_ctx)
if (hip_ctx) {
tls_objs.ctx_stack.push(hip_ctx);
tls_objs.dev_hdl = hip_ctx->get_dev_id();
}
}

// remove primary ctx as active
Expand All @@ -85,8 +80,7 @@ hip_device_primary_ctx_release(hipDevice_t dev)
{
auto dev_hdl = static_cast<device_handle>(dev);
auto hip_dev = device_cache.get(dev_hdl);
if (!hip_dev)
throw xrt_core::system_error(hipErrorInvalidDevice, "Invalid device");
throw_invalid_device_if(!hip_dev, "Invalid device");

auto ctx = hip_dev->get_pri_ctx();
if (!ctx)
Expand All @@ -97,21 +91,19 @@ hip_device_primary_ctx_release(hipDevice_t dev)
auto ctx_hdl =
reinterpret_cast<context_handle>(std::hash<std::thread::id>{}(std::this_thread::get_id()));
context_cache.remove(ctx_hdl);
if (tls_objs.pri_ctx_info.active && tls_objs.pri_ctx_info.dev_hdl == dev_hdl) {
if (tls_objs.pri_ctx_info.active && tls_objs.dev_hdl == dev_hdl) {
tls_objs.pri_ctx_info.active = false;
tls_objs.pri_ctx_info.dev_hdl = UINT32_MAX;
tls_objs.pri_ctx_info.ctx_hdl = nullptr;
}
}

// create primary context on given device if not already present
// else increment reference count
static context_handle
hip_device_primary_ctx_retain(hipDevice_t dev)
context_handle
hip_device_primary_ctx_retain(device_handle dev)
{
auto hip_dev = device_cache.get(dev);
if (!hip_dev)
throw xrt_core::system_error(hipErrorInvalidDevice, "Invalid device");
throw_invalid_device_if(!hip_dev, "Invalid device");

auto hip_ctx = hip_dev->get_pri_ctx();
// create primary context
Expand All @@ -123,13 +115,12 @@ hip_device_primary_ctx_retain(hipDevice_t dev)
// unqiue handle, using thread id here as primary context is unique per thread
auto ctx_hdl =
reinterpret_cast<context_handle>(std::hash<std::thread::id>{}(std::this_thread::get_id()));
auto handle = hip_ctx.get();
context_cache.add(ctx_hdl, std::move(hip_ctx));

tls_objs.pri_ctx_info.active = true;
tls_objs.pri_ctx_info.ctx_hdl = ctx_hdl;
tls_objs.pri_ctx_info.dev_hdl = dev;
return handle;
tls_objs.dev_hdl = dev;
return ctx_hdl;
}
} // xrt::core::hip

Expand All @@ -139,8 +130,7 @@ hipError_t
hipCtxCreate(hipCtx_t* ctx, unsigned int flags, hipDevice_t device)
{
try {
if (!ctx)
throw xrt_core::system_error(hipErrorInvalidValue, "ctx passed is nullptr");
throw_invalid_value_if(!ctx, "ctx passed is nullptr");

auto handle = xrt::core::hip::hip_ctx_create(flags, device);
*ctx = reinterpret_cast<hipCtx_t>(handle);
Expand Down Expand Up @@ -177,8 +167,7 @@ hipError_t
hipCtxGetDevice(hipDevice_t* device)
{
try {
if (!device)
throw xrt_core::system_error(hipErrorInvalidValue, "device passed is nullptr");
throw_invalid_value_if(!device, "device passed is nullptr");

*device = xrt::core::hip::hip_ctx_get_device();
return hipSuccess;
Expand Down Expand Up @@ -214,8 +203,7 @@ hipError_t
hipDevicePrimaryCtxRetain(hipCtx_t* pctx, hipDevice_t dev)
{
try {
if (!pctx)
throw xrt_core::system_error(hipErrorInvalidValue, "nullptr passed");
throw_invalid_value_if(!pctx, "nullptr passed");

auto handle = xrt::core::hip::hip_device_primary_ctx_retain(dev);
*pctx = reinterpret_cast<hipCtx_t>(handle);
Expand Down Expand Up @@ -247,3 +235,4 @@ hipDevicePrimaryCtxRelease(hipDevice_t dev)
}
return hipErrorUnknown;
}

55 changes: 24 additions & 31 deletions src/runtime_src/hip/api/hip_device.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023-2024 Advanced Micro Device, Inc. All rights reserved.

#include "core/common/error.h"
#include "core/include/experimental/xrt_system.h"

#include "hip/config.h"
#include "hip/hip_runtime_api.h"

#include "hip/core/common.h"
#include "hip/core/device.h"

#include <cstring>
Expand Down Expand Up @@ -50,16 +47,18 @@ device_init()
auto dev = std::make_shared<xrt::core::hip::device>(i);
device_cache.add(i, std::move(dev));
}
// make first device as default device
if (dev_count > 0)
tls_objs.dev_hdl = static_cast<device_handle>(0);
}

static void
hip_init(unsigned int flags)
{
// Flags should be zero as per Hip doc
if (flags != 0)
throw xrt_core::system_error(hipErrorInvalidValue, "non zero flags passed to hipinit");
throw_invalid_value_if(flags != 0, "non zero flags passed to hipinit");

// call device_init function, device enumeration might not have happened
// call device_init function, device enumeration might not have happened
// at library load because of some exception
// std::once_flag ensures init is called only once
std::call_once(device_init_flag, xrt::core::hip::device_init);
Expand All @@ -71,55 +70,55 @@ hip_get_device_count()
// Get device count
auto count = xrt::core::hip::device_cache.size();

if (count < 1)
throw xrt_core::system_error(hipErrorNoDevice, "No valid device available");
throw_if(count < 1, hipErrorNoDevice, "No valid device available");

return count;
}

inline bool
check(int dev_id)
{
return (dev_id < 0 || device_cache.count(static_cast<device_handle>(dev_id)) == 0);
}

// Returns a handle to compute device
// Throws on error
static int
hip_device_get(int ordinal)
{
if (ordinal < 0 || device_cache.count(static_cast<device_handle>(ordinal)) == 0)
throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available");
throw_invalid_device_if(check(ordinal), "device requested is not available");

return ordinal;
}

static std::string
hip_device_get_name(hipDevice_t device)
{
if (device < 0 || xrt::core::hip::device_cache.count(static_cast<xrt::core::hip::device_handle>(device)) == 0)
throw xrt_core::system_error(hipErrorInvalidDevice, " - device requested is not available");
throw_invalid_device_if(check(device), "device requested is not available");

throw std::runtime_error("Not implemented");
}

static hipDeviceProp_t
hip_get_device_properties(hipDevice_t device)
{
if (device < 0 || xrt::core::hip::device_cache.count(static_cast<xrt::core::hip::device_handle>(device)) == 0)
throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available");
throw_invalid_device_if(check(device), "device requested is not available");

throw std::runtime_error("Not implemented");
}

static hipUUID
hip_device_get_uuid(hipDevice_t device)
{
if (device < 0 || xrt::core::hip::device_cache.count(static_cast<xrt::core::hip::device_handle>(device)) == 0)
throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available");
throw_invalid_device_if(check(device), "device requested is not available");

throw std::runtime_error("Not implemented");
}

static int
hip_device_get_attribute(hipDeviceAttribute_t attr, int device)
{
if (device < 0 || xrt::core::hip::device_cache.count(static_cast<xrt::core::hip::device_handle>(device)) == 0)
throw xrt_core::system_error(hipErrorInvalidDevice, "device requested is not available");
throw_invalid_device_if(check(device), "device requested is not available");

throw std::runtime_error("Not implemented");
}
Expand Down Expand Up @@ -148,8 +147,7 @@ hipError_t
hipGetDeviceCount(int* count)
{
try {
if (!count)
throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr");
throw_invalid_value_if(!count, "arg passed is nullptr");

*count = xrt::core::hip::hip_get_device_count();
return hipSuccess;
Expand All @@ -168,8 +166,7 @@ hipError_t
hipDeviceGet(hipDevice_t* device, int ordinal)
{
try {
if (!device)
throw xrt_core::system_error(hipErrorInvalidValue, "device is nullptr");
throw_invalid_value_if(!device, "device is nullptr");

*device = xrt::core::hip::hip_device_get(ordinal);
return hipSuccess;
Expand All @@ -188,8 +185,7 @@ hipError_t
hipDeviceGetName(char* name, int len, hipDevice_t device)
{
try {
if (!name || len <= 0)
throw xrt_core::system_error(hipErrorInvalidValue, "invalid arg");
throw_invalid_value_if((!name || len <= 0), "invalid arg");

auto name_str = xrt::core::hip::hip_device_get_name(device);
// Only copy partial name if size of `dest` is smaller than size of `src` including
Expand All @@ -213,8 +209,7 @@ hipError_t
hipGetDeviceProperties(hipDeviceProp_t* props, hipDevice_t device)
{
try {
if (!props)
throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr");
throw_invalid_value_if(!props, "arg passed is nullptr");

*props = xrt::core::hip::hip_get_device_properties(device);
return hipSuccess;
Expand All @@ -233,8 +228,7 @@ hipError_t
hipDeviceGetUuid(hipUUID* uuid, hipDevice_t device)
{
try {
if (!uuid)
throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr");
throw_invalid_value_if(!uuid, "arg passed is nullptr");

*uuid = xrt::core::hip::hip_device_get_uuid(device);
return hipSuccess;
Expand All @@ -253,8 +247,7 @@ hipError_t
hipDeviceGetAttribute(int* pi, hipDeviceAttribute_t attr, int device)
{
try {
if (!pi)
throw xrt_core::system_error(hipErrorInvalidValue, "arg passed is nullptr");
throw_invalid_value_if(!pi, "arg passed is nullptr");

*pi = xrt::core::hip::hip_device_get_attribute(attr, device);
return hipSuccess;
Expand Down
Loading

0 comments on commit 6e6f7d5

Please sign in to comment.