Skip to content

Commit

Permalink
[jax_triton] Add user-specified name field to serialized format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557415723
  • Loading branch information
chr1sj0nes authored and jax authors committed Aug 16, 2023
1 parent c7e8b81 commit 4ac2bdc
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 54 deletions.
18 changes: 11 additions & 7 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
from jax._src.lax.control_flow import for_loop
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import hlo_helpers
from jax._src.lib import version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.state import AbstractRef
from jax._src.state import discharge
from jax._src.state import primitives as sp
Expand All @@ -47,11 +53,6 @@
from jax.interpreters import partial_eval as pe
from jax.lib import xla_client as xc
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas import primitives
from jax._src.pallas import indexing
from jax._src.pallas import utils as pallas_utils
from jax_triton import triton_lib
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
from jax_triton.triton_lib import get_triton_type
Expand Down Expand Up @@ -1687,12 +1688,15 @@ def pallas_call_lowering(
if triton_params is None:
triton_params = {}
serialized_metadata = triton_params.get("serialized_metadata", b"")

if version >= (0, 4, 15):
kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
else:
kernel_call_proto = kernel_call.to_proto(serialized_metadata)
return hlo_helpers.custom_call(
call_target_name=name,
out_types=out_types,
operands=in_nodes,
backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
backend_config=zlib.compress(kernel_call_proto),
operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in),
result_layouts=triton_lib.avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
Expand Down
16 changes: 16 additions & 0 deletions jaxlib/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ cc_library(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":triton_utils",
"//jaxlib/gpu:triton_cc_proto",
"@xla//xla/service:custom_call_status",
"@xla//xla/stream_executor/cuda:cudart_stub",
Expand All @@ -403,6 +404,20 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
],
)

cc_library(
name = "triton_utils",
srcs = ["//jaxlib/gpu:triton_utils.cc"],
hdrs = ["//jaxlib/gpu:triton_utils.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib/gpu:triton_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@zlib",
],
)
Expand All @@ -426,6 +441,7 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":triton_kernels",
":triton_utils",
"//jaxlib:kernel_pybind11_helpers",
"//jaxlib/gpu:triton_cc_proto",
"@com_google_absl//absl/status:statusor",
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ exports_files(srcs = [
"triton.cc",
"triton_kernels.cc",
"triton_kernels.h",
"triton_utils.cc",
"triton_utils.h",
"vendor.h",
])

Expand Down
36 changes: 18 additions & 18 deletions jaxlib/gpu/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/triton_kernels.h"
#include "jaxlib/gpu/triton_utils.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "pybind11_abseil/status_casters.h" // IWYU pragma: keep
Expand Down Expand Up @@ -80,9 +81,11 @@ PYBIND11_MODULE(_triton, m) {
py::class_<KernelCall>(m, "TritonKernelCall")
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
std::vector<KernelCall::Parameter>>())
.def("to_proto", [](const KernelCall& kernel_call, std::string metadata) {
.def("to_proto", [](const KernelCall& kernel_call, std::string name,
std::string metadata) {
jax_triton::TritonAnyKernelCall proto;
*proto.mutable_kernel_call() = kernel_call.ToProto();
proto.set_name(std::move(name));
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});
Expand All @@ -102,13 +105,14 @@ PYBIND11_MODULE(_triton, m) {
std::move(name), std::move(configs),
std::move(input_output_aliases));
}))
.def("to_proto",
[](const AutotunedKernelCall& kernel_call, std::string metadata) {
jax_triton::TritonAnyKernelCall proto;
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});
.def("to_proto", [](const AutotunedKernelCall& kernel_call,
std::string name, std::string metadata) {
jax_triton::TritonAnyKernelCall proto;
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
proto.set_name(std::move(name));
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});

m.def("get_custom_call",
[] { return EncapsulateFunction(&TritonKernelCall); });
Expand All @@ -123,16 +127,12 @@ PYBIND11_MODULE(_triton, m) {
return major * 10 + minor;
});

m.def(
"get_serialized_metadata",
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
jax_triton::TritonAnyKernelCall proto;
if (!proto.ParseFromString(serialized)) {
return absl::InvalidArgumentError("Failed to parse serialized data.");
}
return py::bytes(proto.metadata());
});
m.def("get_serialized_metadata",
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
JAX_ASSIGN_OR_RETURN(std::string metadata,
GetTritonKernelCallSerializedMetadata(opaque));
return py::bytes(metadata);
});
}

} // namespace jax::JAX_GPU_NAMESPACE
5 changes: 3 additions & 2 deletions jaxlib/gpu/triton.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ syntax = "proto3";
package jax_triton;

message TritonKernel {
string kernel_name = 1;
string kernel_name = 1; // Kernel function name within module.
uint32 num_warps = 2;
uint32 shared_mem_bytes = 3;
string ptx = 4;
Expand Down Expand Up @@ -49,7 +49,7 @@ message TritonAutotunedKernelCall {
uint64 buffer_size_bytes = 3;
}

string name = 1;
string name = 1; // Name used in auto-tuning log messages.
repeated Config configs = 2;
repeated InputOutputAlias input_output_aliases = 3;
}
Expand All @@ -60,4 +60,5 @@ message TritonAnyKernelCall {
TritonAutotunedKernelCall autotuned_kernel_call = 2;
}
bytes metadata = 3;
string name = 4; // User assigned name.
}
24 changes: 1 addition & 23 deletions jaxlib/gpu/triton_kernels.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "jaxlib/gpu/triton_kernels.h"

#include <zlib.h>

#include <algorithm>
#include <cstdint>
#include <memory>
Expand All @@ -23,6 +21,7 @@
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/triton_utils.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
#include "xla/stream_executor/gpu/asm_compiler.h"
Expand Down Expand Up @@ -519,25 +518,4 @@ void TritonKernelCall(CUstream stream, void** buffers, const char* opaque,
}
}

absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
std::string data;
uLongf dest_len = 5 * compressed.size();
while (true) {
data.resize(dest_len);
int ret = uncompress(reinterpret_cast<Bytef*>(data.data()), &dest_len,
reinterpret_cast<const Bytef*>(compressed.data()),
compressed.size());
if (ret == Z_OK) {
// `uncompress` overwrites `dest_len` with the uncompressed size.
data.resize(dest_len);
break;
} else if (ret == Z_BUF_ERROR) {
dest_len *= 2; // The string buffer wasn't large enough.
} else {
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
}
}
return data;
}

} // namespace jax::JAX_GPU_NAMESPACE
4 changes: 0 additions & 4 deletions jaxlib/gpu/triton_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
#include <variant>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
Expand Down Expand Up @@ -101,8 +99,6 @@ class AutotunedKernelCall {
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases_;
};

absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed);

} // namespace jax::JAX_GPU_NAMESPACE

#endif // JAXLIB_GPU_TRITON_H_
55 changes: 55 additions & 0 deletions jaxlib/gpu/triton_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "jaxlib/gpu/triton_utils.h"

#include <zlib.h>

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"

namespace jax::JAX_GPU_NAMESPACE {

absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
std::string data;
uLongf dest_len = 5 * compressed.size();
while (true) {
data.resize(dest_len);
int ret = uncompress(reinterpret_cast<Bytef*>(data.data()), &dest_len,
reinterpret_cast<const Bytef*>(compressed.data()),
compressed.size());
if (ret == Z_OK) {
// `uncompress` overwrites `dest_len` with the uncompressed size.
data.resize(dest_len);
break;
} else if (ret == Z_BUF_ERROR) {
dest_len *= 2; // The string buffer wasn't large enough.
} else {
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
}
}
return data;
}

absl::StatusOr<std::string> GetTritonKernelCallName(absl::string_view opaque) {
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
jax_triton::TritonAnyKernelCall proto;
if (!proto.ParseFromString(serialized)) {
return absl::InvalidArgumentError("Failed to parse serialized data.");
}
return proto.name();
}

absl::StatusOr<std::string> GetTritonKernelCallSerializedMetadata(
absl::string_view opaque) {
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
jax_triton::TritonAnyKernelCall proto;
if (!proto.ParseFromString(serialized)) {
return absl::InvalidArgumentError("Failed to parse serialized data.");
}
return proto.metadata();
}

} // namespace jax::JAX_GPU_NAMESPACE
20 changes: 20 additions & 0 deletions jaxlib/gpu/triton_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef JAXLIB_GPU_TRITON_UTILS_H_
#define JAXLIB_GPU_TRITON_UTILS_H_

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "jaxlib/gpu/vendor.h"

namespace jax::JAX_GPU_NAMESPACE {

absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed);
absl::StatusOr<std::string> GetTritonKernelCallName(absl::string_view opaque);
absl::StatusOr<std::string> GetTritonKernelCallSerializedMetadata(
absl::string_view opaque);

} // namespace jax::JAX_GPU_NAMESPACE

#endif // JAXLIB_GPU_TRITON_UTILS_H_

0 comments on commit 4ac2bdc

Please sign in to comment.