Skip to content

Commit

Permalink
[metal] Choose the proper msl version according to the device capabil…
Browse files Browse the repository at this point in the history
…ity (taichi-dev#7506)

Issue: #

### Brief Summary
Subgroups are only supported in Metal 2.1 and up, we upgrade the default
msl version to 2.1.0 to enable subgroup functionalities, otherwise,
metal crashes when launching subgroups-involved kernels.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
YuCrazing and pre-commit-ci[bot] authored Mar 11, 2023
1 parent e93e076 commit fcd4676
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
3 changes: 2 additions & 1 deletion taichi/rhi/metal/metal_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class MetalPipeline final : public Pipeline {

static MetalPipeline *create(const MetalDevice &device,
const uint32_t *spv_data,
size_t spv_size);
size_t spv_size,
const std::string &name);
void destroy();

inline MTLComputePipelineState_id mtl_compute_pipeline_state() const {
Expand Down
37 changes: 31 additions & 6 deletions taichi/rhi/metal/metal_device.mm
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,40 @@
workgroup_size_(workgroup_size) {}
MetalPipeline::~MetalPipeline() { destroy(); }
MetalPipeline *MetalPipeline::create(const MetalDevice &device,
const uint32_t *spv_data,
size_t spv_size) {
const uint32_t *spv_data, size_t spv_size,
const std::string &name) {
RHI_ASSERT((size_t)spv_data % sizeof(uint32_t) == 0);
RHI_ASSERT(spv_size % sizeof(uint32_t) == 0);
spirv_cross::CompilerMSL compiler(spv_data, spv_size / sizeof(uint32_t));
spirv_cross::CompilerMSL::Options options{};
options.enable_decoration_binding = true;

// Choose a proper msl version according to the device capability.
DeviceCapabilityConfig caps = device.get_caps();
bool feature_simd_scoped_permute_operations =
caps.contains(DeviceCapability::spirv_has_subgroup_vote) ||
caps.contains(DeviceCapability::spirv_has_subgroup_ballot);
bool feature_simd_scoped_reduction_operations =
caps.contains(DeviceCapability::spirv_has_subgroup_arithmetic);

if (feature_simd_scoped_permute_operations ||
feature_simd_scoped_reduction_operations) {
// Subgroups are only supported in Metal 2.1 and up.
options.set_msl_version(2, 1, 0);
}

compiler.set_msl_options(options);
std::string msl = compiler.compile();

std::string msl = "";
try {
msl = compiler.compile();
} catch (const spirv_cross::CompilerError &e) {
std::array<char, 4096> msgbuf;
snprintf(msgbuf.data(), msgbuf.size(), "(spirv-cross compiler) %s: %s",
name.c_str(), e.what());
RHI_LOG_ERROR(msgbuf.data());
return nullptr;
}

MTLLibrary_id mtl_library = nil;
{
Expand All @@ -96,7 +121,7 @@
{
NSString *entry_name_ns = [[NSString alloc] initWithUTF8String:"main0"];
mtl_function = [mtl_library newFunctionWithName:entry_name_ns];
if (mtl_library == nil) {
if (mtl_function == nil) {
// FIXME: (penguinliong) Specify the actual entry name after we compile
// directly to MSL in codegen.
RHI_LOG_ERROR(
Expand Down Expand Up @@ -721,8 +746,8 @@ MTLTextureUsage usage2mtl(ImageAllocUsage usage) {
PipelineCache *cache) noexcept {
RHI_ASSERT(src.type == PipelineSourceType::spirv_binary);
try {
*out_pipeline =
MetalPipeline::create(*this, (const uint32_t *)src.data, src.size);
*out_pipeline = MetalPipeline::create(*this, (const uint32_t *)src.data,
src.size, name);
} catch (const std::exception &e) {
return RhiResult::error;
}
Expand Down

0 comments on commit fcd4676

Please sign in to comment.