Skip to content

Commit

Permalink
[mosaic_gpu] Check the return code of gpuEventCreate and `gpuEventD…
Browse files Browse the repository at this point in the history
…estroy`

PiperOrigin-RevId: 693260326
  • Loading branch information
superbobry authored and Google-ML-Automation committed Nov 5, 2024
1 parent 63e59c5 commit 34b4787
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
1 change: 1 addition & 0 deletions jaxlib/mosaic/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ pybind_extension(
deps = [
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
Expand Down
42 changes: 32 additions & 10 deletions jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,64 @@ limitations under the License.
==============================================================================*/

#include <cstdint>
#include <stdexcept>
#include <string>

#include "nanobind/nanobind.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/service/custom_call_status.h"

namespace jax::cuda {
namespace {

namespace nb = nanobind;
static std::string ToString(CUresult result) {
const char* error_name;
if (cuGetErrorName(result, &error_name)) {
return absl::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
}
const char* error_string;
if (cuGetErrorString(result, &error_string)) {
return error_name;
}
return absl::StrCat(error_name, ": ", error_string);
}

void EventRecordCall(void* stream, void** buffers, char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto* event = reinterpret_cast<gpuEvent_t**>(opaque);
if (gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream)) !=
gpuSuccess) {
const char message[] = "Failed to record event";
XlaCustomCallStatusSetFailure(status, message, sizeof(message));
if (auto res = gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream));
res) {
auto message = absl::StrCat("Failed to record event: ", ToString(res));
XlaCustomCallStatusSetFailure(status, message.c_str(), message.size());
}
}

NB_MODULE(_mosaic_gpu_ext, m) {
m.def("_gpu_event_create", []() {
gpuEvent_t* event = new gpuEvent_t();
gpuEventCreate(event, GPU_EVENT_DEFAULT);
if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) {
throw std::runtime_error(
absl::StrCat("Failed to create event: ", ToString(res)));
}
return reinterpret_cast<uintptr_t>(event);
});
m.def("_gpu_event_destroy", [](uintptr_t event) {
gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
if (auto res = gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
res) {
throw std::runtime_error(
absl::StrCat("Failed to destroy event: ", ToString(res)));
}
});
m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) {
float elapsed_ms = -1;
if (gpuEventElapsedTime(
if (auto res = gpuEventElapsedTime(
&elapsed_ms, *reinterpret_cast<gpuEvent_t*>(start_event),
*reinterpret_cast<gpuEvent_t*>(end_event)) != gpuSuccess) {
throw std::runtime_error("Failed to get elapsed time between events");
*reinterpret_cast<gpuEvent_t*>(end_event));
res) {
throw std::runtime_error(absl::StrCat(
"Failed to get elapsed time between events: ", ToString(res)));
}
return elapsed_ms;
});
Expand Down

0 comments on commit 34b4787

Please sign in to comment.