Skip to content

Commit

Permalink
Updating various tests to the latest changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik committed Jul 30, 2024
1 parent 7f5d847 commit e900692
Show file tree
Hide file tree
Showing 26 changed files with 69 additions and 61 deletions.
1 change: 1 addition & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void createTorchToIREEPipeline(
TorchInput::createConvertTMTensorToLinalgExtPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToArithPass());
pm.addPass(torch::createConvertTorchConversionToMLProgramPass());
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/LLVMCPU/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
[
"materialize_homogeneous_encodings.mlir",
"smoketest_embedded.mlir",
"smoketest_system.mlir",
],
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ iree_lit_test_suite(
NAME
lit
SRCS
"materialize_homogeneous_encodings.mlir"
"smoketest_embedded.mlir"
"smoketest_system.mlir"
TOOLS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s

#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
%1 = affine.apply #map()[%0#0, %dim]
%2 = affine.apply #map()[%0#1, %dim_0]
%padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
%3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
%4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
util.return %4 : tensor<?x?xf32>
}
}
// CHECK-LABEL: util.func public @lhs_encoding
// CHECK: tensor.pack
// CHECK: tensor.unpack
8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class ROCMTargetDevice final : public TargetDevice {
targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets(
context, "rocm", configAttr, executableTargetAttrs);

return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("rocm"),
return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"),
configAttr, executableTargetAttrs);
}

Expand All @@ -238,7 +238,7 @@ class ROCMTargetBackend final : public TargetBackend {
public:
ROCMTargetBackend(const ROCmOptions &options) : options(options) {}

std::string getLegacyDefaultDeviceID() const override { return "rocm"; }
std::string getLegacyDefaultDeviceID() const override { return "hip"; }

void getDefaultExecutableTargets(
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
Expand Down Expand Up @@ -702,8 +702,8 @@ struct ROCMSession final
: PluginSession<ROCMSession, ROCmOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
// #hal.device.target<"rocm", ...
targets.add("rocm",
// #hal.device.target<"hip", ...
targets.add("hip",
[&]() { return std::make_shared<ROCMTargetDevice>(options); });
}
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module attributes {
hal.device.targets = [
#hal.device.target<"rocm", [
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
Expand Down Expand Up @@ -46,7 +46,7 @@ stream.executable public @add_dispatch_0 {
#loc = loc(unknown)
module attributes {
hal.device.targets = [
#hal.device.target<"rocm", [
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
Expand Down
8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941

// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
[
"materialize_homogeneous_encodings.mlir",
"smoketest.mlir",
],
include = ["*.mlir"],
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ iree_lit_test_suite(
NAME
lit
SRCS
"materialize_homogeneous_encodings.mlir"
"smoketest.mlir"
TOOLS
FileCheck
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,5 @@
// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s

#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
%1 = affine.apply #map()[%0#0, %dim]
%2 = affine.apply #map()[%0#1, %dim_0]
%padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
%3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
%4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
util.return %4 : tensor<?x?xf32>
}
}
// CHECK-LABEL: util.func public @lhs_encoding
// CHECK: tensor.pack
// CHECK: tensor.unpack

// -----

#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,10 @@ ChangeStatus ValueProducerAffinityPVS::updateValue(Value value,
if (auto affinityOp =
dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
result.getDefiningOp())) {
auto &opPVS = solver.getElementFor<OpAffinityPVS>(
*this, Position::forOperation(result.getOwner()),
DFX::Resolution::OPTIONAL);
auto &opPVS = solver.getOrCreateElementFor<OpAffinityPVS>(
Position::forOperation(result.getOwner()), *this,
DFX::Resolution::OPTIONAL, /*forceUpdate=*/false,
/*updateAfterInit=*/false);
LLVM_DEBUG({
llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
value.printAsOperand(llvm::dbgs(), solver.getAsmState());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ iree_lit_test_suite(
"global_loop_invariant_code_motion.mlir",
"hoist_into_globals.mlir",
"infer_numeric_narrowing.mlir",
"materialize_homogeneous_encodings.mlir",
"optimize_numerics.mlir",
"propagate_linalg_transpose.mlir",
"raise_special_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ iree_lit_test_suite(
"global_loop_invariant_code_motion.mlir"
"hoist_into_globals.mlir"
"infer_numeric_narrowing.mlir"
"materialize_homogeneous_encodings.mlir"
"optimize_numerics.mlir"
"propagate_linalg_transpose.mlir"
"raise_special_ops.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def SD3_CLIP_COMMON_RUN_FLAGS(
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
]

###############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def SD3_MMDIT_COMMON_RUN_FLAGS(
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-rocm-waves-per-eu=2",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
]

###############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def SD3_VAE_COMMON_RUN_FLAGS(
"--iree-flow-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
]

###############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def SDXL_CLIP_COMMON_RUN_FLAGS(
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def SDXL_UNET_COMMON_RUN_FLAGS(
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-rocm-waves-per-eu=2",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def SDXL_VAE_COMMON_RUN_FLAGS(
"--iree-flow-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
Expand Down
1 change: 0 additions & 1 deletion runtime/src/iree/modules/check/test/success.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ func.func @floats() {
%p8 = arith.addf %p7, %cp1 : tensor<f32>
%p9 = arith.addf %p8, %cp1 : tensor<f32>
%approximately_1 = arith.addf %p9, %cp1 : tensor<f32>

check.expect_almost_eq(%approximately_1, %c1) : tensor<f32>
return
}
2 changes: 1 addition & 1 deletion samples/simple_embedding/device_vmvx_sync.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ iree_status_t create_sample_device(iree_allocator_t host_allocator,
iree_vm_instance_release(instance);

// Use the default host allocator for buffer allocations.
iree_string_view_t identifier = iree_make_cstring_view("vmvx");
iree_string_view_t identifier = iree_make_cstring_view("local-sync");
iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_create_heap(identifier, host_allocator,
Expand Down
2 changes: 1 addition & 1 deletion samples/static_library/static_library_demo.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator,
&library_loader);

// Use the default host allocator for buffer allocations.
iree_string_view_t identifier = iree_make_cstring_view("sync");
iree_string_view_t identifier = iree_make_cstring_view("local-sync");
iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_create_heap(identifier, host_allocator,
Expand Down
7 changes: 5 additions & 2 deletions tools/testing/e2e/iree-e2e-conv2d-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,17 @@ int main(int argc, char** argv) {
return EXIT_FAILURE;
}

// Run the tests. Note that some modules may be compiled for other platforms
// and not have the required architectures for execution within them - to keep
// the test runner dumber we gracefully fail those cases by returning success.
iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
iree_allocator_system(), conv2d_test_module_create);
int exit_code = EXIT_SUCCESS;
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
bool is_unavailable = iree_status_is_unavailable(status);
bool is_device_unavailable = iree_status_is_not_found(status);
iree_status_free(status);
exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}

IREE_TRACE_APP_EXIT(exit_code);
Expand Down
7 changes: 5 additions & 2 deletions tools/testing/e2e/iree-e2e-matmul-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -725,14 +725,17 @@ int main(int argc, char** argv) {
return EXIT_FAILURE;
}

// Run the tests. Note that some modules may be compiled for other platforms
// and not have the required architectures for execution within them - to keep
// the test runner dumber we gracefully fail those cases by returning success.
iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
iree_allocator_system(), matmul_test_module_create);
int exit_code = EXIT_SUCCESS;
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
bool is_unavailable = iree_status_is_unavailable(status);
bool is_device_unavailable = iree_status_is_not_found(status);
iree_status_free(status);
exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}

IREE_TRACE_APP_EXIT(exit_code);
Expand Down
2 changes: 1 addition & 1 deletion tools/testing/e2e/test_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ iree_status_t iree_test_utils_check_module_requirements(
return iree_make_status(
// The error status matters. We distinguish "feature not supported"
// which is a normal thing to happen from actual errors.
IREE_STATUS_UNAVAILABLE,
IREE_STATUS_NOT_FOUND,
"target device does not have the required feature '%.*s'",
(int)required_feature.size, required_feature.data);
}
Expand Down
2 changes: 1 addition & 1 deletion tools/testing/e2e/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ iree_status_t iree_test_utils_run_all_test_functions(
iree_allocator_t host_allocator);

// Returns OK if there are declared requirements on |module| and they are all
// met and otherwise UNAVAILABLE indicating that the module should not be run.
// met and otherwise NOT_FOUND indicating that the module should not be run.
iree_status_t iree_test_utils_check_module_requirements(
iree_vm_module_t* module);

Expand Down

0 comments on commit e900692

Please sign in to comment.