Commit

[mosaic_gpu] The profiler now uses FFI calls for creating events and computing elapsed time

PiperOrigin-RevId: 695798787
superbobry authored and Google-ML-Automation committed Nov 12, 2024
1 parent 1221da8 commit d304025
Showing 4 changed files with 145 additions and 96 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
@@ -722,6 +722,7 @@ py_library(
":jax",
":mlir",
"//jax/_src/lib",
"//jax/extend:ffi",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:func_dialect",
114 changes: 56 additions & 58 deletions jax/experimental/mosaic/gpu/profiler.py
@@ -14,16 +14,15 @@
# ==============================================================================

import contextlib
import ctypes
import functools
import itertools
import json
import math
from typing import Callable, ParamSpec, TypeVar
import warnings

import jax
from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jax.extend import ffi
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
@@ -34,72 +33,71 @@

from .utils import * # noqa: F403


try:
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib

xla_client.register_custom_call_target(
"mosaic_gpu_record_event",
mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(),
platform="CUDA",
)
except ImportError:
pass
else:
for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
xla_client.register_custom_call_target(
name, handler, platform="CUDA", api_version=1
)

# ruff: noqa: F405
# mypy: ignore-errors

T = TypeVar("T")
P = ParamSpec("P")

record_event_p = jax.core.Primitive("record_event")
record_event_p.multiple_results = True

@record_event_p.def_abstract_eval
def _record_event_abstract_eval(*args, event):
del event # Unused.
return args

@functools.partial(mlir.register_lowering, record_event_p, platform="cuda")
def _record_event_lowering_rule(ctx, *args, event):
ptr_bytes = ctypes.cast(event, ctypes.c_void_p).value.to_bytes(
8, byteorder="little"
) # pytype: disable=attribute-error
op = mlir.custom_call(
"mosaic_gpu_record_event",
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands=args,
backend_config=ptr_bytes,
operand_output_aliases={i: i for i in range(len(args))},
)
return op.results

def _record_event(args, event):
def _event_record(args, *, copy_before):
flat_args, treedef = jax.tree.flatten(args)
return jax.tree.unflatten(
treedef, record_event_p.bind(*flat_args, event=event)
)

def measure(f, *args, **kwargs):
# TODO(apaszke): Raise if this is called under jit.
start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create()
end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create()
try:

@jax.jit
def run(*args, **kwargs):
flat_args, treedef = jax.tree.flatten((args, kwargs))
flat_args = _record_event(flat_args, start_event)
args, kwargs = jax.tree.unflatten(treedef, flat_args)
return _record_event(f(*args, **kwargs), end_event)

jax.block_until_ready(run(*args, **kwargs)) # Warmup.
results = jax.block_until_ready(run(*args, **kwargs))
elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed(
start_event, end_event
event, *flat_outs = ffi.ffi_call(
"mgpu_event_record",
result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args),
input_output_aliases={i: i + 1 for i in range(len(flat_args))},
)(*flat_args, copy_before=copy_before)
return event, treedef.unflatten(flat_outs)


def _event_elapsed(start_event, end_event):
return ffi.ffi_call(
"mgpu_event_elapsed",
result_shape_dtypes=jax.core.ShapedArray((), jnp.float32),
)(start_event, end_event)


def measure(
f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> tuple[T, float]:
"""Measures the time it takes to execute the function on the GPU.
Args:
f: The function to measure. It must accept at least one argument and return
at least one output to be measurable.
*args: The arguments to pass to ``f``.
**kwargs: The keyword arguments to pass to ``f``.
Returns:
The return value of ``f`` and the elapsed time in milliseconds.
"""
if not (args or kwargs):
# We require at least one argument and at least one output to ensure
# that there is a data dependency between `_event_record` calls in
# the resulting HLO program.
raise ValueError("Can only measure functions with arguments")

@jax.jit
def run(*args, **kwargs):
start_event, (args, kwargs) = _event_record(
(args, kwargs), copy_before=True
)
finally:
mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(start_event)
mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(end_event)
return results, elapsed
end_event, outs = _event_record(f(*args, **kwargs), copy_before=False)
if jax.tree.structure(outs).num_leaves == 0:
raise ValueError("Can only measure functions with at least one output")
return outs, _event_elapsed(start_event, end_event)

outs, elapsed = run(*args, **kwargs)
return outs, float(elapsed)


class ProfilerSpec:
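For context, here is a minimal usage sketch of the reworked measure API (not part of the diff; the input shape and the measured function are hypothetical, and a CUDA device is assumed since the FFI targets are only registered for the CUDA platform):

import jax.numpy as jnp
from jax.experimental.mosaic.gpu import profiler

x = jnp.ones((1024, 1024), jnp.float32)  # Hypothetical input.
# measure returns the function's outputs and the elapsed time in milliseconds.
out, elapsed_ms = profiler.measure(jnp.sin, x)
print(elapsed_ms)

As the docstring and the ValueError checks above note, the measured function must take at least one argument and produce at least one output, so that the two mgpu_event_record calls are ordered by data dependencies in the compiled program.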
4 changes: 3 additions & 1 deletion jaxlib/mosaic/gpu/BUILD
@@ -185,9 +185,11 @@ pybind_extension(
deps = [
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/service:custom_call_status",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cudart",
],
)
122 changes: 85 additions & 37 deletions jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -13,19 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

#include "nanobind/nanobind.h"
#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/service/custom_call_status.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace jax::cuda {
namespace {

namespace ffi = xla::ffi;
namespace nb = nanobind;

static std::string ToString(CUresult result) {
const char* error_name;
if (cuGetErrorName(result, &error_name)) {
@@ -38,45 +43,88 @@ static std::string ToString(CUresult result) {
return absl::StrCat(error_name, ": ", error_string);
}

void EventRecordCall(void* stream, void** buffers, char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto* event = reinterpret_cast<gpuEvent_t**>(opaque);
if (auto res = gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream));
res) {
auto message = absl::StrCat("Failed to record event: ", ToString(res));
XlaCustomCallStatusSetFailure(status, message.c_str(), message.size());
}
// Ensure it is safe to store gpuEvent_t in a uint64_t buffer.
static_assert(sizeof(gpuEvent_t) <= sizeof(uint64_t));

static const auto* kEventRecord =
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Attr<bool>("copy_before")
.RemainingArgs()
.Ret<ffi::BufferR0<ffi::U64>>() // event
.RemainingRets()
.To([](gpuStream_t stream, bool copy_before,
auto remaining_args, auto ret, auto remaining_rets) {
static auto* event = new gpuEvent_t;
if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT);
res) {
return ffi::Error::Internal(
absl::StrCat("Failed to create event: ", ToString(res)));
}
auto do_copy = [&]() {
gpuMemcpyAsync(ret->untyped_data(), event,
sizeof(gpuEvent_t), gpuMemcpyHostToDevice, stream);
};
if (copy_before) {
do_copy();
}
if (auto res = gpuEventRecord(*event, stream); res) {
return ffi::Error::Internal(
absl::StrCat("Failed to record event: ", ToString(res)));
}
if (!copy_before) {
do_copy();
}
return ffi::Error::Success();
})
.release();

XLA_FFI_Error* EventRecord(XLA_FFI_CallFrame* call_frame) {
return kEventRecord->Call(call_frame);
}

static const auto* kEventElapsed =
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Arg<ffi::BufferR0<ffi::U64>>() // start_event
.Arg<ffi::BufferR0<ffi::U64>>() // end_event
.Ret<ffi::BufferR0<ffi::F32>>() // elapsed_ms
.To([](gpuStream_t stream, auto start, auto end, auto out) {
gpuStreamSynchronize(stream);
auto start_event = std::make_unique<gpuEvent_t>();
auto end_event = std::make_unique<gpuEvent_t>();
auto cleanup = absl::MakeCleanup([&]() {
gpuEventDestroy(*start_event);
gpuEventDestroy(*end_event);
});
gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t),
gpuMemcpyDeviceToHost);
gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t),
gpuMemcpyDeviceToHost);
float elapsed;
if (auto res =
gpuEventElapsedTime(&elapsed, *start_event, *end_event);
res) {
return ffi::Error::Internal(absl::StrCat(
"Failed to get elapsed time between events: ", ToString(res)));
}
gpuMemcpy(out->untyped_data(), &elapsed, sizeof(float),
gpuMemcpyHostToDevice);
return ffi::Error::Success();
})
.release();

XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) {
return kEventElapsed->Call(call_frame);
}

NB_MODULE(_mosaic_gpu_ext, m) {
m.def("_gpu_event_create", []() {
gpuEvent_t* event = new gpuEvent_t();
if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) {
throw std::runtime_error(
absl::StrCat("Failed to create event: ", ToString(res)));
}
return reinterpret_cast<uintptr_t>(event);
});
m.def("_gpu_event_destroy", [](uintptr_t event) {
if (auto res = gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
res) {
throw std::runtime_error(
absl::StrCat("Failed to destroy event: ", ToString(res)));
}
});
m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) {
float elapsed_ms = -1;
if (auto res = gpuEventElapsedTime(
&elapsed_ms, *reinterpret_cast<gpuEvent_t*>(start_event),
*reinterpret_cast<gpuEvent_t*>(end_event));
res) {
throw std::runtime_error(absl::StrCat(
"Failed to get elapsed time between events: ", ToString(res)));
}
return elapsed_ms;
m.def("registrations", []() {
return nb::make_tuple(
nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)),
nb::make_tuple("mgpu_event_elapsed", EncapsulateFunction(EventElapsed))
);
});
m.def("_record_event_capsule",
[]() { return EncapsulateFunction(EventRecordCall); });
m.def("_sync_all_devices", []() {
int devices = 0;
if (cudaGetDeviceCount(&devices) != gpuSuccess) {
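Not part of this commit, but for illustration: a minimal sketch of the XLA FFI binding pattern used by the two handlers above, assuming the same includes and the ffi namespace alias from this file. The hypothetical mgpu_stream_sync target binds no buffers and only synchronizes the stream it is launched on.

static const auto* kStreamSync =
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<gpuStream_t>>()
        .To([](gpuStream_t stream) {
          // Block until all work previously enqueued on the stream completes.
          if (auto res = gpuStreamSynchronize(stream); res) {
            return ffi::Error::Internal("Failed to synchronize stream");
          }
          return ffi::Error::Success();
        })
        .release();

XLA_FFI_Error* StreamSync(XLA_FFI_CallFrame* call_frame) {
  return kStreamSync->Call(call_frame);
}

Exposing such a handler would only require appending another ("mgpu_stream_sync", EncapsulateFunction(StreamSync)) pair to the tuple returned by registrations().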
