-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathbackend_integration_test.cpp
725 lines (629 loc) · 24.6 KB
/
backend_integration_test.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdlib>
#include <cstring>
#include <functional>
#include <optional>
#include <vector>
#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>
#include <executorch/test/utils/alignment.h>
#include <gtest/gtest.h>
using namespace ::testing;
using executorch::aten::ArrayRef;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendInterface;
using executorch::runtime::CompileSpec;
using executorch::runtime::DataLoader;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Method;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::testing::ManagedMemoryManager;
using torch::executor::util::FileDataLoader;
/**
* A backend class whose methods can be overridden individually.
*/
class StubBackend final : public BackendInterface {
public:
// Function signature types that match the BackendInterface methods.
using IsAvailableFn = std::function<bool()>;
using InitFn = std::function<Result<DelegateHandle*>(
FreeableBuffer*,
ArrayRef<CompileSpec>,
BackendInitContext&)>;
using ExecuteFn =
std::function<Error(BackendExecutionContext&, DelegateHandle*, EValue**)>;
using DestroyFn = std::function<void(DelegateHandle*)>;
// Default name that this backend is registered as.
static constexpr char kName[] = "StubBackend";
void install_is_available(IsAvailableFn fn) {
is_available_fn_ = fn;
}
bool is_available() const override {
if (is_available_fn_) {
return is_available_fn_.value()();
}
// Return a benign value otherwise.
return true;
}
void install_init(InitFn fn) {
init_fn_ = fn;
}
Result<DelegateHandle*> init(
BackendInitContext& context,
FreeableBuffer* processed,
ArrayRef<CompileSpec> compile_specs) const override {
if (init_fn_) {
return init_fn_.value()(processed, compile_specs, context);
}
// Return a benign value otherwise.
return nullptr;
}
void install_execute(ExecuteFn fn) {
execute_fn_ = fn;
}
Error execute(
BackendExecutionContext& context,
DelegateHandle* handle,
EValue** args) const override {
if (execute_fn_) {
return execute_fn_.value()(context, handle, args);
}
// Return a benign value otherwise.
return Error::Ok;
}
void install_destroy(DestroyFn fn) {
destroy_fn_ = fn;
}
void destroy(DelegateHandle* handle) const override {
if (destroy_fn_) {
destroy_fn_.value()(handle);
}
}
/**
* Resets to the original constructed state.
*/
void reset() {
is_available_fn_.reset();
init_fn_.reset();
execute_fn_.reset();
destroy_fn_.reset();
}
/**
* Registers the singleton instance if not already registered.
*
* Note that this can be used to install the stub as the implementation for
* any export-time backend by passing in the right name, as long as no other
* backend with that name has been registered yet.
*/
static Error register_singleton(const char* name = kName) {
if (!registered_) {
registered_ = true;
return executorch::runtime::register_backend({name, &singleton_});
}
return Error::Ok;
}
/**
* Returns the instance that was added to the backend registry.
*/
static StubBackend& singleton() {
return singleton_;
}
private:
static bool registered_;
static StubBackend singleton_;
std::optional<IsAvailableFn> is_available_fn_;
std::optional<InitFn> init_fn_;
std::optional<ExecuteFn> execute_fn_;
std::optional<DestroyFn> destroy_fn_;
};
bool StubBackend::registered_ = false;
StubBackend StubBackend::singleton_;
/**
* A DataLoader that wraps a real DataLoader and records the operations
* performed on it and the FreeableBuffers it loads.
*/
class DataLoaderSpy final : public DataLoader {
public:
/// A record of an operation performed on this DataLoader.
struct Operation {
enum { Load, Free } op;
size_t offset; // Set for Load; zero for Free.
void* data; // Set for Free; nullptr for Load.
size_t size; // Set for Load and Free.
std::unique_ptr<const DataLoader::SegmentInfo>
segment_info; // Set for Load; nullptr for Free.
};
explicit DataLoaderSpy(DataLoader* delegate) : delegate_(delegate) {}
Result<FreeableBuffer> load(
size_t offset,
size_t size,
const SegmentInfo& segment_info) const override {
Result<FreeableBuffer> buf = delegate_->load(offset, size, segment_info);
if (!buf.ok()) {
return buf.error();
}
auto segment_info_cpy =
std::make_unique<const DataLoader::SegmentInfo>(segment_info);
operations_.push_back(
{Operation::Load,
offset,
/*data=*/nullptr,
size,
/*segment_info=*/std::move(segment_info_cpy)});
auto* context = new SpyContext(&operations_, std::move(buf.get()));
// Use context->buffer since buf has been moved.
return FreeableBuffer(
context->buffer.data(), context->buffer.size(), FreeBuffer, context);
}
Result<size_t> size() const override {
return delegate_->size();
}
/**
* Returns records of the operations performed on this DataLoader and the
* FreeableBuffers it returned, in order they were performed.
*/
const std::vector<Operation>& operations() const {
return operations_;
}
/**
* Returns true if the DataLoader::load() method was called with the correct
* segment info.
*/
bool UsedLoad(
DataLoader::SegmentInfo::Type segment_type,
const char* descriptor = nullptr) const {
for (const auto& op : operations_) {
if (op.op != Operation::Load) {
continue;
}
// We have a load op.
if (op.segment_info->segment_type == segment_type) {
if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
// For non-backend segments, the descriptor is irrelevant / a nullptr.
return true;
} else {
if (strcmp(op.segment_info->descriptor, descriptor) == 0) {
return true;
}
}
}
}
return false;
}
/**
* Returns true if the operations list shows that the provided data pointer
* was freed.
*/
bool WasFreed(const void* data) const {
for (const auto& op : operations_) {
if (op.op == Operation::Free && op.data == data) {
return true;
}
}
return false;
}
private:
struct SpyContext {
SpyContext(std::vector<Operation>* operations, FreeableBuffer&& buffer)
: operations(operations), buffer(std::move(buffer)) {}
std::vector<Operation>* operations;
FreeableBuffer buffer;
};
static void FreeBuffer(void* context, void* data, size_t size) {
auto* sc = reinterpret_cast<SpyContext*>(context);
sc->operations->push_back(
{Operation::Free, /*offset=*/0, data, size, /*segment_info=*/nullptr});
delete sc;
}
/// The real loader to delegate to.
DataLoader* delegate_;
mutable std::vector<Operation> operations_;
};
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024;
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024;
class BackendIntegrationTest : public ::testing::TestWithParam<bool> {
protected:
void SetUp() override {
// Since these tests cause ET_LOG to be called, the PAL must be initialized
// first.
executorch::runtime::runtime_init();
// Make sure that the backend has been registered. Safe to call multiple
// times. Doing this at runtime ensures that it's only registered if these
// tests are run.
ASSERT_EQ(StubBackend::register_singleton(), Error::Ok);
// Paths to the test program files.
program_path_ = std::getenv("ET_MODULE_ADD_MUL_PATH");
ASSERT_FALSE(program_path_.empty());
program_nosegments_path_ = std::getenv("ET_MODULE_ADD_MUL_NOSEGMENTS_PATH");
ASSERT_FALSE(program_nosegments_path_.empty());
}
void TearDown() override {
// Clean up any modifications to the singleton.
StubBackend::singleton().reset();
}
/**
* Returns true if program_path() returns a file with extracted segments.
*/
bool using_segments() const {
return GetParam();
}
/**
* Returns tha path to the program to load. May or may not have extracted
* segments, depending on the return value of using_segments().
*/
const char* program_path() const {
if (using_segments()) {
return program_path_.c_str();
} else {
return program_nosegments_path_.c_str();
}
}
private:
std::string program_path_;
std::string program_nosegments_path_;
};
TEST_P(BackendIntegrationTest, BackendIsPresent) {
BackendInterface* backend =
executorch::runtime::get_backend_class(StubBackend::kName);
ASSERT_EQ(backend, &StubBackend::singleton());
}
// Demonstrate that installed StubBackend initializes successfully by default.
TEST_P(BackendIntegrationTest, BasicInitSucceeds) {
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
Result<Program> program = Program::load(&loader.get());
ASSERT_EQ(program.error(), Error::Ok);
auto method_meta = program->method_meta("forward");
EXPECT_EQ(method_meta->uses_backend(StubBackend::kName), true);
EXPECT_EQ(method_meta->uses_backend("INVALID_BACKEND_NAME"), false);
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method_res = program->load_method("forward", &mmm.get());
EXPECT_EQ(method_res.error(), Error::Ok);
}
TEST_P(BackendIntegrationTest, GetBackendNamesSuccess) {
// Load the program from file.
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
Result<Program> program = Program::load(&loader.get());
ASSERT_EQ(program.error(), Error::Ok);
// Get method metadata for the "forward" method.
auto method_meta = program->method_meta("forward");
// Ensure the StubBackend is used.
EXPECT_TRUE(method_meta->uses_backend(StubBackend::kName));
// Retrieve the number of backends.
const size_t num_backends = method_meta->num_backends();
EXPECT_GT(num_backends, 0u);
// Iterate through each backend and verify its name.
for (size_t i = 0; i < num_backends; ++i) {
auto backend_name_result = method_meta->get_backend_name(i);
ASSERT_TRUE(backend_name_result.ok());
const char* name = backend_name_result.get();
// For this test, we expect that the only backend is StubBackend.
EXPECT_STREQ(name, StubBackend::kName);
}
// Check that an out-of-range index returns an error.
auto out_of_range_result = method_meta->get_backend_name(num_backends);
EXPECT_FALSE(out_of_range_result.ok());
}
TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
// Install an init() implementation that frees its processed buffer, and lets
// us know that it was actually called by setting init_called.
bool init_called = false;
const void* processed_data = nullptr;
StubBackend::singleton().install_init(
[&](FreeableBuffer* processed,
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
ET_UNUSED BackendInitContext& backend_init_context)
-> Result<DelegateHandle*> {
init_called = true;
processed_data = processed->data();
processed->Free();
return nullptr;
});
// Wrap the real loader in a spy so we can see which operations were
// performed.
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
DataLoaderSpy spy_loader(&loader.get());
// Load the program.
Result<Program> program = Program::load(&spy_loader);
ASSERT_EQ(program.error(), Error::Ok);
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method_res = program->load_method("forward", &mmm.get());
EXPECT_EQ(method_res.error(), Error::Ok);
// Demonstrate that our installed init was called.
EXPECT_TRUE(init_called);
// See if the processed data was freed.
bool processed_was_freed = spy_loader.WasFreed(processed_data);
if (using_segments()) {
// Used the loader to create the FreeableBuffer that was passed to the
// backend, so we can see its Free() call.
EXPECT_TRUE(processed_was_freed);
} else {
// Didn't use the loader to create the FreeableBuffer that was passed to the
// backend, so we can't see its Free() call.
EXPECT_FALSE(processed_was_freed);
}
}
TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
// Install an init() implementation that does not free its processed buffer,
// and returns the FreeableBuffer as the delegate handle.
FreeableBuffer* init_processed = nullptr;
StubBackend::singleton().install_init(
[&](FreeableBuffer* processed,
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
ET_UNUSED BackendInitContext& backend_init_context)
-> Result<DelegateHandle*> {
init_processed = processed;
return processed;
});
// Install an execute() that expects the handle to be the processed
// FreeableBuffer.
DelegateHandle* execute_handle = nullptr;
StubBackend::singleton().install_execute(
[&](ET_UNUSED BackendExecutionContext& backend_execution_context,
DelegateHandle* handle,
ET_UNUSED EValue** args) -> Error {
execute_handle = handle;
auto* processed = reinterpret_cast<FreeableBuffer*>(handle);
// Read the data, which will tend to cause an ASAN error if it's not
// valid.
auto copy = std::make_unique<char[]>(processed->size());
std::memcpy(copy.get(), processed->data(), processed->size());
return Error::Ok;
});
// Install a destroy() that expects the handle to be the processed
// FreeableBuffer.
DelegateHandle* destroy_handle = nullptr;
StubBackend::singleton().install_destroy(
[&](DelegateHandle* handle) -> void { destroy_handle = handle; });
// Wrap the real loader in a spy so we can see which operations were
// performed.
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
DataLoaderSpy spy_loader(&loader.get());
// Load the program.
Result<Program> program = Program::load(&spy_loader);
ASSERT_EQ(program.error(), Error::Ok);
// Hold onto the address of the processed buffer so we can compare against
// it after the FreeableBuffer has been destroyed.
const void* processed_data;
// Add a scope so we can watch executor be destroyed.
{
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method_res = program->load_method("forward", &mmm.get());
EXPECT_TRUE(method_res.ok());
// Demonstrate that our installed init was called.
EXPECT_NE(init_processed, nullptr);
// Not freed yet.
EXPECT_GT(init_processed->size(), 0);
EXPECT_NE(init_processed->data(), nullptr);
processed_data = init_processed->data();
// The processed data should not have been freed during init.
EXPECT_FALSE(spy_loader.WasFreed(init_processed->data()));
auto method(std::move(method_res.get()));
// Execute the model.
auto input_cleanup = executorch::extension::prepare_input_tensors(method);
ASSERT_EQ(input_cleanup.error(), Error::Ok);
auto err = method.execute();
EXPECT_EQ(err, Error::Ok);
// Check that the processed buffer was passed to execute() as the handle.
EXPECT_EQ(init_processed, execute_handle);
// The processed data should not have been freed during execution.
EXPECT_FALSE(spy_loader.WasFreed(init_processed->data()));
}
// `executor` has now been destroyed, which should have freed the processed
// data.
bool processed_was_freed = spy_loader.WasFreed(processed_data);
if (using_segments()) {
// Used the loader to create the FreeableBuffer that was passed to the
// backend, so we can see its Free() call.
EXPECT_TRUE(processed_was_freed);
} else {
// Didn't use the loader to create the FreeableBuffer that was passed to the
// backend, so we can't see its Free() call.
EXPECT_FALSE(processed_was_freed);
}
// And it should have destroyed the backend handle.
EXPECT_EQ(execute_handle, destroy_handle);
}
/**
* Tests that the DataLoader's load is receiving the correct segment info for
* different types of segments.
*/
TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
const void* processed_data = nullptr;
StubBackend::singleton().install_init(
[&](FreeableBuffer* processed,
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
ET_UNUSED BackendInitContext& backend_init_context)
-> Result<DelegateHandle*> {
processed_data = processed->data();
processed->Free();
return nullptr;
});
// Wrap the real loader in a spy so we can see which operations were
// performed.
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
DataLoaderSpy spy_loader(&loader.get());
// Load the program.
Result<Program> program = Program::load(&spy_loader);
ASSERT_EQ(program.error(), Error::Ok);
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
// Expect that load was called correctly on program segments.
bool program_load_was_called =
spy_loader.UsedLoad(DataLoader::SegmentInfo::Type::Program, nullptr);
// Load a method.
Result<Method> method_res = program->load_method("forward", &mmm.get());
EXPECT_EQ(method_res.error(), Error::Ok);
// Expect that load was called correctly on a backend segment.
bool backend_load_was_called = spy_loader.UsedLoad(
DataLoader::SegmentInfo::Type::Backend,
"StubBackend"); // This backend id is taken from the StubBackend defined
// in export_delegated_program.py.
EXPECT_TRUE(program_load_was_called);
EXPECT_EQ(backend_load_was_called, using_segments());
}
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
const void* processed_data = nullptr;
StubBackend::singleton().install_init(
[&](FreeableBuffer* processed,
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
ET_UNUSED BackendInitContext& backend_init_context)
-> Result<DelegateHandle*> {
auto method_name = backend_init_context.get_method_name();
// Ensure that we can get the method name during init via context
EXPECT_STREQ(method_name, "forward");
processed_data = processed->data();
return nullptr;
});
Result<Program> program = Program::load(&loader.get());
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = program->load_method("forward", &mmm.get());
EXPECT_TRUE(method.ok());
ASSERT_EQ(program.error(), Error::Ok);
}
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
ASSERT_EQ(loader.error(), Error::Ok);
StubBackend::singleton().install_execute(
[&](BackendExecutionContext& backend_execution_context,
ET_UNUSED DelegateHandle* handle,
ET_UNUSED EValue** args) -> Error {
// Ensure that we can get the method name during execution via context
auto method_name = backend_execution_context.get_method_name();
EXPECT_STREQ(method_name, "forward");
return Error::Ok;
});
Result<Program> program = Program::load(&loader.get());
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = program->load_method("forward", &mmm.get());
EXPECT_TRUE(method.ok());
Error err = method->execute();
ASSERT_EQ(err, Error::Ok);
}
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
// - Errors during init() or execute() result in runtime init/execution failures
// - Correct values are passed to init()/execute()
// - Demonstrate use of the runtime allocator
// - ...
// Run all BackendIntegrationTests multiple times, varying the return value of
// `GetParam()` based on the `testing::Values` list. The tests will interpret
// the boolean as "using segments".
INSTANTIATE_TEST_SUITE_P(
VariedSegments,
BackendIntegrationTest,
testing::Values(false, true));
class DelegateDataAlignmentTest : public ::testing::TestWithParam<bool> {
protected:
void SetUp() override {
// Since these tests cause ET_LOG to be called, the PAL must be initialized
// first.
executorch::runtime::runtime_init();
// Make sure that the backend has been registered. Safe to call multiple
// times. Doing this at runtime ensures that it's only registered if these
// tests are run.
ASSERT_EQ(StubBackend::register_singleton(), Error::Ok);
// Paths to the test program files.
default_alignment_program_path_ =
std::getenv("ET_MODULE_ADD_MUL_NOSEGMENTS_PATH");
ASSERT_FALSE(default_alignment_program_path_.empty());
override_alignment_program_path_ =
std::getenv("ET_MODULE_ADD_MUL_NOSEGMENTS_DA1024_PATH");
ASSERT_FALSE(override_alignment_program_path_.empty());
}
void TearDown() override {
// Clean up any modifications to the singleton.
StubBackend::singleton().reset();
}
/**
* Returns the expected minimum alignment of inline tensor data, given
* the testing parameter.
*/
size_t expected_alignment() const {
if (GetParam()) {
// The delegate data inline alignment used by the -da1024 file.
return 1024;
} else {
// A small alignment that's compatible with any realistic alignment.
return 4;
}
}
/**
* Returns tha path to the program to load. May or may not have an alignment
* override, depending on the return value of expected_alignment().
*/
const char* program_path() const {
if (GetParam()) {
return override_alignment_program_path_.c_str();
} else {
return default_alignment_program_path_.c_str();
}
}
private:
std::string default_alignment_program_path_;
std::string override_alignment_program_path_;
};
TEST_P(DelegateDataAlignmentTest, ExpectedDataAlignment) {
// Install an init() implementation that records the pointer to the delegate
// data blob so we can check its alignment.
const void* processed_data = nullptr;
StubBackend::singleton().install_init(
[&](FreeableBuffer* processed,
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
ET_UNUSED BackendInitContext& backend_init_context)
-> Result<DelegateHandle*> {
processed_data = processed->data();
return nullptr;
});
// Create a loader that can satisfy the alignment required by this program.
Result<FileDataLoader> loader =
FileDataLoader::from(program_path(), /*alignment=*/expected_alignment());
ASSERT_EQ(loader.error(), Error::Ok);
// Wrap the real loader in a spy so we can see which operations were
// performed.
DataLoaderSpy spy_loader(&loader.get());
// Load the program.
Result<Program> program = Program::load(&spy_loader);
ASSERT_EQ(program.error(), Error::Ok);
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = program->load_method("forward", &mmm.get());
EXPECT_TRUE(method.ok());
// Demonstrate that our installed init was called.
EXPECT_NE(processed_data, nullptr);
// Check that it had the required alignment. The alignment of 1024 is larger
// than the test file with default alignment, so the default alignment cannot
// accidentally satisfy it.
EXPECT_ALIGNED(processed_data, expected_alignment());
}
// Run all DelegateDataAlignmentTests multiple times, varying the return value
// of `GetParam()` based on the `testing::Values` list. The tests will interpret
// the boolean as "was inline delegate data alignment overridden to 1024".
INSTANTIATE_TEST_SUITE_P(
VariedAlignment,
DelegateDataAlignmentTest,
testing::Values(false, true));