Skip to content

Commit

Permalink
[aot] Enable Metal AOT test (taichi-dev#7461)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PENGUINLIONG and pre-commit-ci[bot] authored Mar 10, 2023
1 parent db9fdb1 commit e8babe3
Show file tree
Hide file tree
Showing 21 changed files with 272 additions and 9 deletions.
3 changes: 3 additions & 0 deletions c_api/src/taichi_metal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class MetalRuntime : public GfxRuntime {
taichi::lang::gfx::GfxRuntime &get_gfx_runtime() override;

taichi::lang::metal::MetalDevice &get_mtl();
virtual TiImage allocate_image(
const taichi::lang::ImageParams &params) override final;
virtual void free_image(TiImage image) override final;
};

} // namespace capi
Expand Down
11 changes: 11 additions & 0 deletions c_api/src/taichi_metal_impl.mm
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@
return *mtl_device_;
}

TiImage MetalRuntime::allocate_image(const taichi::lang::ImageParams &params) {
taichi::lang::DeviceAllocation devalloc =
get_gfx_runtime().create_image(params);
return devalloc2devimg(*this, devalloc);
}
void MetalRuntime::free_image(TiImage image) {
taichi::lang::DeviceAllocation devimg = devimg2devalloc(*this, image);
get_mtl().destroy_image(devimg);
get_gfx_runtime().untrack_image(devimg);
}

} // namespace capi

// -----------------------------------------------------------------------------
Expand Down
21 changes: 21 additions & 0 deletions c_api/tests/c_api_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ TEST_F(CapiTest, AotTestVulkanKernel) {
}
}

TEST_F(CapiTest, AotTestMetalKernel) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
TiArch arch = TiArch::TI_ARCH_METAL;
kernel_aot_test(arch);
}
}

TEST_F(CapiTest, AotTestOpenglKernel) {
if (ti::is_arch_available(TI_ARCH_OPENGL)) {
TiArch arch = TiArch::TI_ARCH_OPENGL;
Expand All @@ -202,6 +209,13 @@ TEST_F(CapiTest, GraphTestVulkanTextureKernel) {
}
}

TEST_F(CapiTest, GraphTestMetalTextureKernel) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
TiArch arch = TiArch::TI_ARCH_METAL;
texture_aot_kernel_test(arch);
}
}

TEST_F(CapiTest, AotTestCudaSharedArray) {
if (ti::is_arch_available(TI_ARCH_CUDA)) {
TiArch arch = TiArch::TI_ARCH_CUDA;
Expand All @@ -215,3 +229,10 @@ TEST_F(CapiTest, AotTestVulkanSharedArray) {
shared_array_aot_test(arch);
}
}

TEST_F(CapiTest, AotTestMetalSharedArray) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
TiArch arch = TiArch::TI_ARCH_METAL;
shared_array_aot_test(arch);
}
}
15 changes: 15 additions & 0 deletions c_api/tests/c_api_cgraph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,28 @@ TEST_F(CapiTest, GraphTestVulkanGraph) {
graph_aot_test(arch);
}
}

TEST_F(CapiTest, GraphTestMetalGraph) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
TiArch arch = TiArch::TI_ARCH_METAL;
graph_aot_test(arch);
}
}

TEST_F(CapiTest, GraphTestVulkanTextureGraph) {
if (ti::is_arch_available(TI_ARCH_VULKAN)) {
TiArch arch = TiArch::TI_ARCH_VULKAN;
texture_aot_test(arch);
}
}

TEST_F(CapiTest, GraphTestMetalTextureGraph) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
TiArch arch = TiArch::TI_ARCH_METAL;
texture_aot_test(arch);
}
}

TEST_F(CapiTest, GraphTestOpenglGraph) {
if (ti::is_arch_available(TI_ARCH_OPENGL)) {
TiArch arch = TiArch::TI_ARCH_OPENGL;
Expand Down
137 changes: 134 additions & 3 deletions c_api/tests/c_api_interface_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ TEST_F(CapiTest, DryRunRuntime) {
runtime.destroy();
}

if (ti::is_arch_available(TI_ARCH_METAL)) {
// Metal Runtime
TiArch arch = TiArch::TI_ARCH_METAL;
ti::Runtime runtime(arch);
runtime.destroy();
}

if (ti::is_arch_available(TI_ARCH_CUDA)) {
// CUDA Runtime
TiArch arch = TiArch::TI_ARCH_CUDA;
Expand Down Expand Up @@ -129,6 +136,14 @@ TEST_F(CapiTest, DryRunMemoryAllocation) {
ti::NdArray<uint8_t> ndarray = runtime.allocate_ndarray<uint8_t>({100}, {});
}

if (ti::is_arch_available(TI_ARCH_METAL)) {
// Vulkan Runtime
TiArch arch = TiArch::TI_ARCH_METAL;
ti::Runtime runtime(arch);
ti::Memory memory = runtime.allocate_memory(100);
ti::NdArray<uint8_t> ndarray = runtime.allocate_ndarray<uint8_t>({100}, {});
}

if (ti::is_arch_available(TI_ARCH_OPENGL)) {
// Opengl Runtime
TiArch arch = TiArch::TI_ARCH_OPENGL;
Expand Down Expand Up @@ -158,6 +173,18 @@ TEST_F(CapiTest, FailMapDeviceOnlyMemory) {
"Mapping Memory without VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT set",
/*reset_error=*/true);
}

if (ti::is_arch_available(TI_ARCH_METAL)) {
ti::Runtime runtime(TI_ARCH_METAL);

ti::Memory mem = runtime.allocate_memory(100);
mem.map();

EXPECT_TAICHI_ERROR(
TI_ERROR_INVALID_STATE,
"Mapping Memory without VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT set",
/*reset_error=*/true);
}
}

TEST_F(CapiTest, FailOutOfRangeReadWrite) {
Expand All @@ -174,6 +201,20 @@ TEST_F(CapiTest, FailOutOfRangeReadWrite) {

EXPECT_TAICHI_ERROR(TI_ERROR_ARGUMENT_OUT_OF_RANGE);
}

if (ti::is_arch_available(TI_ARCH_METAL)) {
ti::Runtime runtime(TI_ARCH_METAL);

std::vector<float> data(101);
ti::NdArray<float> arr = runtime.allocate_ndarray<float>({50}, {2});

TI_ASSERT(arr.elem_count() == 50);
TI_ASSERT(arr.scalar_count() == 100);

arr.write(data);

EXPECT_TAICHI_ERROR(TI_ERROR_ARGUMENT_OUT_OF_RANGE);
}
}

TEST_F(CapiTest, DryRunImageAllocation) {
Expand All @@ -186,6 +227,16 @@ TEST_F(CapiTest, DryRunImageAllocation) {
runtime.allocate_texture2d(4, 4, TI_FORMAT_RGBA8, TI_NULL_HANDLE);
}
}

if (ti::is_arch_available(TI_ARCH_METAL)) {
{
// Vulkan Runtime
TiArch arch = TiArch::TI_ARCH_METAL;
ti::Runtime runtime(arch);
ti::Texture texture =
runtime.allocate_texture2d(4, 4, TI_FORMAT_RGBA8, TI_NULL_HANDLE);
}
}
}

TEST_F(CapiTest, DryRunVulkanAotModule) {
Expand All @@ -204,6 +255,22 @@ TEST_F(CapiTest, DryRunVulkanAotModule) {
}
}

TEST_F(CapiTest, DryRunMetalAotModule) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;

{
// Vulkan Runtime
TiArch arch = TiArch::TI_ARCH_METAL;
ti::Runtime runtime(arch);
ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str());
}
}
}

TEST_F(CapiTest, DryRunOpenglAotModule) {
if (ti::is_arch_available(TI_ARCH_OPENGL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
Expand All @@ -212,7 +279,7 @@ TEST_F(CapiTest, DryRunOpenglAotModule) {
aot_mod_ss << folder_dir;

{
// Vulkan Runtime
// OpenGL Runtime
TiArch arch = TiArch::TI_ARCH_OPENGL;
ti::Runtime runtime(arch);

Expand All @@ -221,7 +288,7 @@ TEST_F(CapiTest, DryRunOpenglAotModule) {
}
}

TEST_F(CapiTest, TestLoadTcmAotModule) {
TEST_F(CapiTest, TestLoadTcmAotModuleVulkan) {
if (ti::is_arch_available(TI_ARCH_VULKAN)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

Expand All @@ -248,7 +315,34 @@ TEST_F(CapiTest, TestLoadTcmAotModule) {
}
}

TEST_F(CapiTest, TestCreateTcmAotModule) {
TEST_F(CapiTest, TestLoadTcmAotModuleMetal) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir << "/module.tcm";

{
// Metal Runtime
TiArch arch = TiArch::TI_ARCH_METAL;
ti::Runtime runtime(arch);
ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str());
ti::Kernel run = aot_mod.get_kernel("run");
ti::NdArray<int32_t> arr =
runtime.allocate_ndarray<int32_t>({16}, {}, true);
run[0] = arr;
run.launch();
runtime.wait();
std::vector<int32_t> data(16);
arr.read(data);
for (int32_t i = 0; i < 16; ++i) {
TI_ASSERT(data.at(i) == i);
}
}
}
}

TEST_F(CapiTest, TestCreateTcmAotModuleVulkan) {
if (ti::is_arch_available(TI_ARCH_VULKAN)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

Expand Down Expand Up @@ -284,3 +378,40 @@ TEST_F(CapiTest, TestCreateTcmAotModule) {
}
}
}

TEST_F(CapiTest, TestCreateTcmAotModuleMetal) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir << "/module.tcm";

std::vector<uint8_t> tcm;
{
std::fstream f(aot_mod_ss.str(),
std::ios::in | std::ios::binary | std::ios::ate);
TI_ASSERT(f.is_open());
tcm.resize(f.tellg());
f.seekg(std::ios::beg);
f.read((char *)tcm.data(), tcm.size());
}

{
// Vulkan Runtime
TiArch arch = TiArch::TI_ARCH_METAL;
ti::Runtime runtime(arch);
ti::AotModule aot_mod = runtime.create_aot_module(tcm);
ti::Kernel run = aot_mod.get_kernel("run");
ti::NdArray<int32_t> arr =
runtime.allocate_ndarray<int32_t>({16}, {}, true);
run[0] = arr;
run.launch();
runtime.wait();
std::vector<int32_t> data(16);
arr.read(data);
for (int32_t i = 0; i < 16; ++i) {
TI_ASSERT(data.at(i) == i);
}
}
}
}
13 changes: 13 additions & 0 deletions c_api/tests/mpm88_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ TEST_F(CapiTest, Mpm88TestVulkan) {
}
}

TEST_F(CapiTest, Mpm88TestMetal) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;

auto impl = std::make_unique<demo::MPM88DemoImpl>(aot_mod_ss.str().c_str(),
TiArch::TI_ARCH_METAL);
impl->Step();
}
}

TEST_F(CapiTest, Mpm88TestOpengl) {
if (ti::is_arch_available(TI_ARCH_OPENGL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
Expand Down
11 changes: 11 additions & 0 deletions c_api/tests/sph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ TEST_F(CapiTest, SphTestVulkan) {
}
}

TEST_F(CapiTest, SphTestMetal) {
if (ti::is_arch_available(TI_ARCH_METAL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;

run(TiArch::TI_ARCH_METAL, aot_mod_ss.str());
}
}

TEST_F(CapiTest, SphTestOpengl) {
if (ti::is_arch_available(TI_ARCH_OPENGL)) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
Expand Down
3 changes: 2 additions & 1 deletion taichi/rhi/metal/metal_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MetalDevice;
struct MetalMemory {
public:
// `mtl_buffer` should be already retained.
explicit MetalMemory(MTLBuffer_id mtl_buffer);
explicit MetalMemory(MTLBuffer_id mtl_buffer, bool host_access);
~MetalMemory();

void dont_destroy();
Expand All @@ -54,6 +54,7 @@ struct MetalMemory {

private:
MTLBuffer_id mtl_buffer_;
bool can_map_{false};
bool dont_destroy_{false};
};

Expand Down
11 changes: 8 additions & 3 deletions taichi/rhi/metal/metal_device.mm
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
namespace taichi::lang {
namespace metal {

MetalMemory::MetalMemory(MTLBuffer_id mtl_buffer) : mtl_buffer_(mtl_buffer) {}
MetalMemory::MetalMemory(MTLBuffer_id mtl_buffer, bool can_map)
: mtl_buffer_(mtl_buffer), can_map_(can_map) {}
MetalMemory::~MetalMemory() {
if (!dont_destroy_) {
[mtl_buffer_ release];
Expand All @@ -19,6 +20,9 @@
MTLBuffer_id MetalMemory::mtl_buffer() const { return mtl_buffer_; }
size_t MetalMemory::size() const { return (size_t)[mtl_buffer_ length]; }
RhiResult MetalMemory::mapped_ptr(void **mapped_ptr) const {
if (!can_map_) {
return RhiResult::invalid_usage;
}
void *ptr = [mtl_buffer_ contents];
if (ptr == nullptr) {
return RhiResult::invalid_usage;
Expand Down Expand Up @@ -541,15 +545,16 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) {
MTLBuffer_id buffer = [mtl_device_ newBufferWithLength:params.size
options:resource_options];

MetalMemory &alloc = memory_allocs_.acquire(buffer);
MetalMemory &alloc = memory_allocs_.acquire(buffer, can_map);

DeviceAllocation out{};
out.device = this;
out.alloc_id = reinterpret_cast<uint64_t>(&alloc);
return out;
}
DeviceAllocation MetalDevice::import_mtl_buffer(MTLBuffer_id buffer) {
MetalMemory &alloc = memory_allocs_.acquire(buffer);
bool can_map = [buffer contents] != nullptr;
MetalMemory &alloc = memory_allocs_.acquire(buffer, can_map);
alloc.dont_destroy();

DeviceAllocation out{};
Expand Down
Loading

0 comments on commit e8babe3

Please sign in to comment.