Skip to content

Commit 5a9e7a4

Browse files
authoredDec 3, 2024··
Add a simple multi-threaded test
Differential Revision: D65162642 Pull Request resolved: #7143
1 parent b4eda5f commit 5a9e7a4

File tree

6 files changed

+294
-19
lines changed

6 files changed

+294
-19
lines changed
 

‎backends/test/README.md

Whitespace-only changes.

‎backends/test/TARGETS

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets(is_fbcode = True)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <iostream>
4+
#include <string>
5+
#include <thread>
6+
#include <vector>
7+
8+
#include <executorch/runtime/executor/program.h>
9+
#include <executorch/runtime/platform/runtime.h>
10+
11+
#include <executorch/extension/data_loader/file_data_loader.h>
12+
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
13+
#include <executorch/extension/runner_util/inputs.h>
14+
15+
using executorch::runtime::Error;
16+
using executorch::runtime::EValue;
17+
using executorch::runtime::HierarchicalAllocator;
18+
using executorch::runtime::MemoryManager;
19+
using executorch::runtime::Method;
20+
using executorch::runtime::MethodMeta;
21+
using executorch::runtime::Program;
22+
using executorch::runtime::Result;
23+
using executorch::runtime::Span;
24+
25+
using executorch::extension::FileDataLoader;
26+
using executorch::extension::MallocMemoryAllocator;
27+
using executorch::extension::prepare_input_tensors;
28+
29+
/*
30+
* Backend agnostic base class.
31+
*/
32+
class ETPTEMethodRunBaseTest : public ::testing::Test {
33+
protected:
34+
void SetUp() override {
35+
executorch::runtime::runtime_init();
36+
}
37+
38+
// Runs the PTE e2e without using outside resources.
39+
// This will run in a single thread.
40+
// TODO(T208989128) - Add Synchronizer based run method.
41+
void run(
42+
const int id,
43+
const std::string& kTestPTEPath,
44+
const std::string& kMethodName,
45+
std::atomic<size_t>& count) const {
46+
Result<FileDataLoader> loader = FileDataLoader::from(kTestPTEPath.c_str());
47+
ASSERT_EQ(loader.error(), Error::Ok);
48+
49+
Result<Program> program = Program::load(
50+
&loader.get(), Program::Verification::InternalConsistency);
51+
ASSERT_EQ(program.error(), Error::Ok);
52+
53+
Result<MethodMeta> method_meta = program->method_meta(kMethodName.c_str());
54+
ASSERT_EQ(method_meta.error(), Error::Ok);
55+
56+
const size_t num_memory_planned_buffers =
57+
method_meta->num_memory_planned_buffers();
58+
59+
std::vector<std::unique_ptr<uint8_t[]>> planned_buffers;
60+
std::vector<Span<uint8_t>> planned_spans;
61+
for (size_t i = 0; i < num_memory_planned_buffers; ++i) {
62+
const size_t buffer_size =
63+
static_cast<size_t>(method_meta->memory_planned_buffer_size(i).get());
64+
planned_buffers.push_back(std::make_unique<uint8_t[]>(buffer_size));
65+
planned_spans.push_back({planned_buffers.back().get(), buffer_size});
66+
}
67+
68+
auto method_allocator = std::make_unique<MallocMemoryAllocator>();
69+
auto memory_planned_allocator = std::make_unique<HierarchicalAllocator>(
70+
Span(planned_spans.data(), planned_spans.size()));
71+
auto temp_allocator = std::make_unique<MallocMemoryAllocator>();
72+
73+
auto memory_manager = std::make_unique<MemoryManager>(
74+
method_allocator.get(),
75+
memory_planned_allocator.get(),
76+
temp_allocator.get());
77+
78+
Result<Method> method =
79+
program->load_method(kMethodName.c_str(), memory_manager.get());
80+
ASSERT_EQ(method.error(), Error::Ok);
81+
82+
auto inputs = prepare_input_tensors(*method);
83+
ASSERT_EQ(inputs.error(), Error::Ok);
84+
85+
Error err = method->execute();
86+
for (int i = 0; i < id % 7; i++) {
87+
err = method->execute();
88+
ASSERT_EQ(err, Error::Ok);
89+
}
90+
91+
std::vector<EValue> outputs(method->outputs_size());
92+
err = method->get_outputs(outputs.data(), outputs.size());
93+
ET_CHECK(err == Error::Ok);
94+
// TODO(T208989129) - Add validation of outputs using bundled
95+
// inputs/outputs.
96+
count++;
97+
}
98+
};
99+
100+
class XNNPACKMultiDelegateTest : public ETPTEMethodRunBaseTest {
101+
protected:
102+
std::string kTestPTE1Path, kTestPTE2Path;
103+
std::string kMethodName;
104+
int num_threads;
105+
106+
void SetUp() override {
107+
ETPTEMethodRunBaseTest::SetUp();
108+
const char* pte1_path =
109+
std::getenv("ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH");
110+
if (pte1_path == nullptr) {
111+
std::cerr << "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH is not set"
112+
<< std::endl;
113+
FAIL();
114+
}
115+
kTestPTE1Path = std::string(pte1_path);
116+
117+
const char* pte2_path =
118+
std::getenv("ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH");
119+
if (pte1_path == nullptr) {
120+
std::cerr << "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH is not set"
121+
<< std::endl;
122+
FAIL();
123+
}
124+
kTestPTE2Path = std::string(pte2_path);
125+
126+
num_threads = 40;
127+
kMethodName = "forward";
128+
}
129+
};
130+
131+
// This test is to validate the assumption that the delegate is thread safe.
132+
// That includes the following:
133+
// 1. The delegate can be initialized by multiple threads in parallel.
134+
// 2. The delegate can be executed by multiple threads in parallel.
135+
// 3. The delegate can be destroyed by multiple threads in parallel.
136+
// Regardless of the underlying implementation of the delegate.
137+
// This is particularly important when we have shared resources across
138+
// delegate instances through a singleton backend instance.
139+
// Starts `num_threads` threads; each one loads and executes one of the two
// PTE files through run(). The test passes only if every thread reached the
// end of run(), i.e. the shared counter equals the number of threads.
TEST_F(XNNPACKMultiDelegateTest, MultipleThreads) {
  ASSERT_FALSE(kTestPTE1Path.empty());
  ASSERT_FALSE(kTestPTE2Path.empty());
  // Strictly positive: a negative value would otherwise be converted to a
  // huge unsigned size when sizing the thread container.
  ASSERT_GT(num_threads, 0);
  ASSERT_FALSE(kMethodName.empty());

  std::atomic<size_t> count{0};
  std::vector<std::thread> threads;
  threads.reserve(static_cast<size_t>(num_threads));

  for (int i = 0; i < num_threads; i++) {
    // `i` is captured by value; the fixture members and `count` are captured
    // by reference and outlive the threads because they are joined below.
    // Indices divisible by 7 use the second PTE, all others use the first.
    threads.emplace_back([&, i]() {
      run(i, i % 7 ? kTestPTE1Path : kTestPTE2Path, kMethodName, count);
    });
  }
  for (auto& worker : threads) {
    worker.join();
  }
  // run() increments the counter only on full success, so any failed thread
  // shows up as a shortfall here.
  ASSERT_EQ(count.load(), static_cast<size_t>(num_threads));
}
158+
159+
// TODO(T208989291): Add more tests here. For example,
160+
// - PTEs with multiple methods
161+
// - PTEs with producer and consumer relationships in different threads
162+
// - PTEs with more than 1 delegate instances
163+
// - PTEs with different type of delegate instances
164+
// - Add more patterns of delegate initialization and execution

‎backends/test/targets.bzl

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets(is_fbcode = False):
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
if not runtime.is_oss and is_fbcode:
10+
modules_env = {
11+
"ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleAddLarge.pte])",
12+
"ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleSubLarge.pte])",
13+
}
14+
15+
runtime.cxx_test(
16+
name = "multi_method_delegate_test",
17+
srcs = [
18+
"multi_method_delegate_test.cpp",
19+
],
20+
deps = [
21+
"//executorch/runtime/executor:program",
22+
"//executorch/extension/data_loader:file_data_loader",
23+
"//executorch/extension/memory_allocator:malloc_memory_allocator",
24+
"//executorch/kernels/portable:generated_lib",
25+
"//executorch/backends/xnnpack:xnnpack_backend",
26+
"//executorch/extension/runner_util:inputs",
27+
],
28+
env = modules_env,
29+
)

‎test/models/export_delegated_program.py

+69-19
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import executorch.exir as exir
1414

1515
import torch
16-
from executorch.exir import to_edge
16+
from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower
1717
from executorch.exir.backend.backend_api import to_backend
1818
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
1919
from executorch.exir.backend.test.backend_with_compiler_demo import (
@@ -52,6 +52,41 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
5252
return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
5353

5454

55+
class ModuleAddLarge(nn.Module):
    """Chain of three element-wise additions over three input tensors.

    ``forward(a, b, c)`` computes ``x = a + b``, ``y = x + c`` and returns
    ``x + y``.
    """

    def forward(
        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
    ) -> torch.Tensor:
        x: torch.Tensor = torch.add(a, b)
        y: torch.Tensor = torch.add(x, c)
        z: torch.Tensor = torch.add(x, y)
        return z

    def get_random_inputs(self) -> Sequence[torch.Tensor]:
        """Return three ``(n, n, n)`` float tensors filled with 1, 2 and 3.

        Despite the name, the values are constants, not random.
        """
        n = 10  # to create a large tensor
        return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n))
70+
71+
72+
class ModuleSubLarge(nn.Module):
    """Chain of four element-wise subtractions over three input tensors.

    ``forward(a, b, c)`` computes ``x = a - b``, ``y = x - c``, ``z = x - y``
    and returns ``z - c``.
    """

    def forward(
        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
    ) -> torch.Tensor:
        x: torch.Tensor = torch.sub(a, b)
        y: torch.Tensor = torch.sub(x, c)
        z: torch.Tensor = torch.sub(x, y)
        w: torch.Tensor = torch.sub(z, c)
        return w

    def get_random_inputs(self) -> Sequence[torch.Tensor]:
        """Return three ``(n, n, n)`` float tensors filled with 1, 2 and 3.

        Despite the name, the values are constants, not random.
        """
        n = 10  # to create a large tensor
        return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n))
88+
89+
5590
#
5691
# Backends
5792
#
@@ -95,30 +130,45 @@ def __init__(self, fn):
95130
def forward(self, *args, **kwargs):
96131
return self.fn(*args, **kwargs)
97132

98-
edge: exir.EdgeProgramManager = to_edge(
99-
export(WrapperModule(getattr(eager_module, method)), args=inputs)
133+
exported_program = export(WrapperModule(getattr(eager_module, method)), args=inputs)
134+
135+
edge_config = EdgeCompileConfig(_check_ir_validity=False)
136+
et_config = exir.ExecutorchBackendConfig(
137+
extract_delegate_segments=extract_delegate_segments,
138+
constant_tensor_alignment=constant_tensor_alignemnt,
139+
delegate_alignment=delegate_alignment,
100140
)
101141

102-
lowered_module = to_backend(backend_id, edge.exported_program(), compile_specs=[])
142+
if backend_id == "XnnpackBackend":
143+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
144+
XnnpackPartitioner,
145+
)
103146

104-
class CompositeModule(nn.Module):
105-
def __init__(self):
106-
super().__init__()
107-
self.lowered_module = lowered_module
147+
executorch_program = to_edge_transform_and_lower(
148+
exported_program,
149+
compile_config=edge_config,
150+
partitioner=[XnnpackPartitioner()],
151+
).to_executorch(config=et_config)
152+
else:
153+
edge: exir.EdgeProgramManager = to_edge(exported_program)
154+
lowered_module = to_backend(
155+
backend_id, edge.exported_program(), compile_specs=[]
156+
)
108157

109-
def forward(self, *args, **kwargs):
110-
return self.lowered_module(*args, **kwargs)
158+
class CompositeModule(nn.Module):
159+
def __init__(self):
160+
super().__init__()
161+
self.lowered_module = lowered_module
111162

112-
composite_module = CompositeModule()
113-
composite_module(*inputs)
163+
def forward(self, *args, **kwargs):
164+
return self.lowered_module(*args, **kwargs)
114165

115-
executorch_program = to_edge(export(composite_module, args=inputs)).to_executorch(
116-
config=exir.ExecutorchBackendConfig(
117-
extract_delegate_segments=extract_delegate_segments,
118-
constant_tensor_alignment=constant_tensor_alignemnt,
119-
delegate_alignment=delegate_alignment,
120-
)
121-
)
166+
composite_module = CompositeModule()
167+
composite_module(*inputs)
168+
169+
executorch_program = to_edge(
170+
export(composite_module, args=inputs)
171+
).to_executorch(config=et_config)
122172

123173
return executorch_program.buffer
124174

‎test/models/targets.bzl

+24
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,17 @@ def define_common_targets():
117117
par_style = "xar",
118118
deps = [
119119
":export_delegated_program_lib",
120+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
121+
120122
],
121123
visibility = [], # Private
122124
)
123125

124126
# Class names of nn.Modules for :exported_delegated_programs to export.
125127
DELEGATED_MODULES_TO_EXPORT = [
126128
"ModuleAddMul",
129+
"ModuleAddLarge",
130+
"ModuleSubLarge",
127131
]
128132

129133
# Name of the backend to use when exporting delegated programs.
@@ -153,3 +157,23 @@ def define_common_targets():
153157
"//executorch/test/...",
154158
],
155159
)
160+
161+
runtime.genrule(
162+
name = "exported_xnnp_delegated_programs",
163+
cmd = "$(exe :export_delegated_program)" +
164+
" --modules " + ",".join(DELEGATED_MODULES_TO_EXPORT) +
165+
" --backend_id " + "XnnpackBackend" +
166+
" --outdir $OUT",
167+
outs = {
168+
fname + ".pte": [fname + ".pte"]
169+
for fname in DELEGATED_MODULES_TO_EXPORT
170+
},
171+
default_outs = ["."],
172+
visibility = [
173+
"//executorch/runtime/executor/test/...",
174+
"//executorch/backends/test/...",
175+
"//executorch/test/...",
176+
"@EXECUTORCH_CLIENTS",
177+
],
178+
env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",},
179+
)

0 commit comments

Comments
 (0)
Please sign in to comment.