forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cuda] Port over existing semaphore impl (iree-org#14325)
The current semaphore implementation actually does nothing. Still port over it so that we can have a full implementation passing various end-to-end tests to be based on. This makes having a proper implementation of semaphores easier later as we can verify correctness immediately afterwards. Progress towards iree-org#13245
- Loading branch information
1 parent
d01a83c
commit 04beef2
Showing
4 changed files
with
148 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "experimental/cuda2/nop_semaphore.h" | ||
|
||
#include <stddef.h> | ||
|
||
#include "iree/base/api.h" | ||
#include "iree/hal/utils/semaphore_base.h" | ||
|
||
typedef struct iree_hal_cuda2_semaphore_t { | ||
iree_hal_semaphore_t base; | ||
iree_allocator_t host_allocator; | ||
iree_atomic_int64_t value; | ||
} iree_hal_cuda2_semaphore_t; | ||
|
||
static const iree_hal_semaphore_vtable_t iree_hal_cuda2_semaphore_vtable; | ||
|
||
static iree_hal_cuda2_semaphore_t* iree_hal_cuda2_semaphore_cast( | ||
iree_hal_semaphore_t* base_value) { | ||
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_semaphore_vtable); | ||
return (iree_hal_cuda2_semaphore_t*)base_value; | ||
} | ||
|
||
iree_status_t iree_hal_cuda2_semaphore_create( | ||
uint64_t initial_value, iree_allocator_t host_allocator, | ||
iree_hal_semaphore_t** out_semaphore) { | ||
IREE_ASSERT_ARGUMENT(out_semaphore); | ||
IREE_TRACE_ZONE_BEGIN(z0); | ||
|
||
iree_hal_cuda2_semaphore_t* semaphore = NULL; | ||
iree_status_t status = iree_allocator_malloc( | ||
host_allocator, sizeof(*semaphore), (void**)&semaphore); | ||
if (iree_status_is_ok(status)) { | ||
iree_hal_semaphore_initialize(&iree_hal_cuda2_semaphore_vtable, | ||
&semaphore->base); | ||
semaphore->host_allocator = host_allocator; | ||
iree_atomic_store_int64(&semaphore->value, initial_value, | ||
iree_memory_order_release); | ||
*out_semaphore = &semaphore->base; | ||
} | ||
|
||
IREE_TRACE_ZONE_END(z0); | ||
return status; | ||
} | ||
|
||
static void iree_hal_cuda2_semaphore_destroy( | ||
iree_hal_semaphore_t* base_semaphore) { | ||
iree_hal_cuda2_semaphore_t* semaphore = | ||
iree_hal_cuda2_semaphore_cast(base_semaphore); | ||
iree_allocator_t host_allocator = semaphore->host_allocator; | ||
IREE_TRACE_ZONE_BEGIN(z0); | ||
|
||
iree_hal_semaphore_deinitialize(&semaphore->base); | ||
iree_allocator_free(host_allocator, semaphore); | ||
|
||
IREE_TRACE_ZONE_END(z0); | ||
} | ||
|
||
static iree_status_t iree_hal_cuda2_semaphore_query( | ||
iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { | ||
iree_hal_cuda2_semaphore_t* semaphore = | ||
iree_hal_cuda2_semaphore_cast(base_semaphore); | ||
// TODO: Support semaphores completely. | ||
*out_value = | ||
iree_atomic_load_int64(&semaphore->value, iree_memory_order_acquire); | ||
return iree_ok_status(); | ||
} | ||
|
||
static iree_status_t iree_hal_cuda2_semaphore_signal( | ||
iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { | ||
iree_hal_cuda2_semaphore_t* semaphore = | ||
iree_hal_cuda2_semaphore_cast(base_semaphore); | ||
// TODO: Support semaphores completely. Return OK currently as everything is | ||
// synchronized for each submit to allow things to run. | ||
iree_atomic_store_int64(&semaphore->value, new_value, | ||
iree_memory_order_release); | ||
iree_hal_semaphore_poll(&semaphore->base); | ||
return iree_ok_status(); | ||
} | ||
|
||
static void iree_hal_cuda2_semaphore_fail(iree_hal_semaphore_t* base_semaphore, | ||
iree_status_t status) { | ||
iree_hal_cuda2_semaphore_t* semaphore = | ||
iree_hal_cuda2_semaphore_cast(base_semaphore); | ||
// TODO: save status and mark timepoint as failed. | ||
iree_status_ignore(status); | ||
iree_hal_semaphore_poll(&semaphore->base); | ||
} | ||
|
||
static iree_status_t iree_hal_cuda2_semaphore_wait( | ||
iree_hal_semaphore_t* base_semaphore, uint64_t value, | ||
iree_timeout_t timeout) { | ||
iree_hal_cuda2_semaphore_t* semaphore = | ||
iree_hal_cuda2_semaphore_cast(base_semaphore); | ||
// TODO: Support semaphores completely. Return OK currently as everything is | ||
// synchronized for each submit to allow things to run. | ||
iree_hal_semaphore_poll(&semaphore->base); | ||
return iree_ok_status(); | ||
} | ||
|
||
static const iree_hal_semaphore_vtable_t iree_hal_cuda2_semaphore_vtable = { | ||
.destroy = iree_hal_cuda2_semaphore_destroy, | ||
.query = iree_hal_cuda2_semaphore_query, | ||
.signal = iree_hal_cuda2_semaphore_signal, | ||
.fail = iree_hal_cuda2_semaphore_fail, | ||
.wait = iree_hal_cuda2_semaphore_wait, | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef EXPERIMENTAL_CUDA2_NOP_SEMAPHORE_H_ | ||
#define EXPERIMENTAL_CUDA2_NOP_SEMAPHORE_H_ | ||
|
||
#include <stdint.h> | ||
|
||
#include "iree/base/api.h" | ||
#include "iree/hal/api.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif // __cplusplus | ||
|
||
// Creates a HAL semaphore for CUDA that does not perform real synchronization. | ||
// This is expected to work with a command buffer that serializes all commands. | ||
iree_status_t iree_hal_cuda2_semaphore_create( | ||
uint64_t initial_value, iree_allocator_t host_allocator, | ||
iree_hal_semaphore_t** out_semaphore); | ||
|
||
#ifdef __cplusplus | ||
} // extern "C" | ||
#endif // __cplusplus | ||
|
||
#endif // EXPERIMENTAL_CUDA2_NOP_SEMAPHORE_H_ |