/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @file
*
* Creates multiple Executor instances at the same time, demonstrating that the
* same process can handle multiple runtimes at once.
*
* Usage:
* multi_runner --models=<model.pte>[,<m2.pte>[,...]] [--num_instances=<num>]
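 *
 * Example invocation (model paths are placeholders):
 *   multi_runner --models=mv2.pte,mv3.pte --num_instances=4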
*/
#include <gflags/gflags.h>
#include <sys/stat.h>
#include <cassert>
#include <cinttypes>
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>
#include <tuple>
#include <vector>
#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/util/read_file.h>
#include <executorch/util/util.h>
DEFINE_string(
models,
"",
"Comma-separated list of paths to serialized ExecuTorch model files");
DEFINE_int32(
num_instances,
10,
"Number of Executor instances to create in parallel, for each model");
static bool validate_path_list(
const char* flagname,
const std::string& path_list);
DEFINE_validator(models, &validate_path_list);
static bool validate_positive_int32(const char* flagname, int32_t val);
DEFINE_validator(num_instances, &validate_positive_int32);
namespace {
using torch::executor::DataLoader;
using torch::executor::Error;
using torch::executor::FreeableBuffer;
using torch::executor::MemoryAllocator;
using torch::executor::MemoryManager;
using torch::executor::Method;
using torch::executor::Program;
using torch::executor::Result;
using torch::executor::testing::ManagedMemoryManager;
using torch::executor::util::BufferDataLoader;
/**
* A model that has been loaded and has had its execution plan and inputs
* prepared. Can be run once.
*
 * Creates and owns the underlying state, simplifying lifetime management.
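 *
 * Example of intended use (a sketch; `buffer` and `buffer_size` are assumed
 * to hold a serialized .pte program):
 *
 *   PreparedModel model("demo", buffer, buffer_size,
 *       40 * 1024U * 1024U,  // non_const_mem_bytes
 *       2 * 1024U * 1024U);  // runtime_mem_bytes
 *   model.run();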
*/
class PreparedModel final {
public:
PreparedModel(
const std::string& name,
const void* model_data,
size_t model_data_size,
size_t non_const_mem_bytes,
size_t runtime_mem_bytes)
: name_(name),
loader_(model_data, model_data_size),
program_(load_program_or_die(loader_)),
memory_manager_(non_const_mem_bytes, runtime_mem_bytes),
method_(load_method_or_die(program_, &memory_manager_.get())),
has_run_(false) {
inputs_ = torch::executor::util::PrepareInputTensors(method_);
}
void run() {
ET_CHECK_MSG(!has_run_, "A PreparedModel may only be run once");
has_run_ = true;
Error status = method_.execute();
ET_CHECK_MSG(
status == Error::Ok,
"plan.execute() failed with status 0x%" PRIx32,
status);
// TODO(T131578656): Do something with the outputs.
}
const std::string& name() const {
return name_;
}
~PreparedModel() {
torch::executor::util::FreeInputs(inputs_);
}
private:
static Program load_program_or_die(DataLoader& loader) {
Result<Program> program = Program::load(&loader);
ET_CHECK(program.ok());
return std::move(program.get());
}
static Method load_method_or_die(
const Program& program,
MemoryManager* memory_manager) {
Result<Method> method = program.load_method("forward", memory_manager);
ET_CHECK(method.ok());
return std::move(method.get());
}
const std::string name_;
BufferDataLoader loader_; // Needs to outlive program_
Program program_; // Needs to outlive executor_
ManagedMemoryManager memory_manager_; // Needs to outlive executor_
Method method_;
exec_aten::ArrayRef<void*> inputs_;
bool has_run_;
};
/**
* Creates PreparedModels based on the provided serialized data and memory
* parameters.
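 *
 * Example (a sketch; `data` and `size` are assumptions standing in for the
 * loaded file contents):
 *
 *   ModelFactory factory("mv2.pte", data, size);
 *   auto model = factory.prepare("0");  // "0" becomes the name affix
 *   model->run();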
*/
class ModelFactory {
public:
ModelFactory(
const std::string& name, // For debugging
std::shared_ptr<const char> model_data,
size_t model_data_size,
size_t non_const_mem_bytes = 40 * 1024U * 1024U, // 40 MB
size_t runtime_mem_bytes = 2 * 1024U * 1024U) // 2 MB
: name_(name),
model_data_(model_data),
model_data_size_(model_data_size),
non_const_mem_bytes_(non_const_mem_bytes),
runtime_mem_bytes_(runtime_mem_bytes) {}
std::unique_ptr<PreparedModel> prepare(
std::string_view name_affix = "") const {
return std::make_unique<PreparedModel>(
name_affix.empty() ? name_ : std::string(name_affix) + ":" + name_,
model_data_.get(),
model_data_size_,
non_const_mem_bytes_,
runtime_mem_bytes_);
}
const std::string& name() const {
return name_;
}
private:
const std::string name_;
std::shared_ptr<const char> model_data_;
const size_t model_data_size_;
const size_t non_const_mem_bytes_;
const size_t runtime_mem_bytes_;
};
/// Synchronizes a set of model threads as they walk through prepare/run states.
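///
/// Acts like a reusable barrier: every thread calls advance_to() with the
/// same sequence of states, e.g.
///
///   sync.advance_to(Synchronizer::State::PREPARE_MODEL);
///   // ... prepare the model ...
///   sync.advance_to(Synchronizer::State::RUN_MODEL);
///   // ... run the model ...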
class Synchronizer {
public:
explicit Synchronizer(size_t total_threads)
: total_threads_(total_threads), state_(State::INIT_THREAD) {}
/// The states for threads to move through. Must advance in order.
enum class State {
/// Initial state.
INIT_THREAD,
/// Thread is ready to prepare its model instance.
PREPARE_MODEL,
/// Thread is ready to run its model instance.
RUN_MODEL,
};
/// Wait until all threads have requested to advance to this state, then
/// advance all of them.
void advance_to(State new_state) {
std::unique_lock<std::mutex> lock(lock_);
// Enforce valid state machine transitions.
assert(
(new_state == State::PREPARE_MODEL && state_ == State::INIT_THREAD) ||
(new_state == State::RUN_MODEL && state_ == State::PREPARE_MODEL));
// Indicate that this thread is ready to move to the new state.
num_ready_++;
if (num_ready_ == total_threads_) {
// We were the last thread to become ready. Tell all threads to
// move to the next state.
state_ = new_state;
num_ready_ = 0;
cv_.notify_all();
} else {
// Wait until all other threads are ready.
      cv_.wait(lock, [this, new_state] { return state_ == new_state; });
}
}
private:
/// The total number of threads to wait for.
const size_t total_threads_;
/// Locks all mutable fields in this class.
std::mutex lock_;
/// The number of threads that are ready to move to the next state.
size_t num_ready_ = 0;
/// The state that all threads should be in.
State state_;
/// Signals threads to check for state updates.
std::condition_variable cv_;
};
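// Design note: Synchronizer is effectively a reusable two-phase barrier. On
// C++20 toolchains a std::barrier could serve a similar purpose, but the
// explicit state machine here also asserts that threads advance through the
// states in order.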
/**
* Waits for all threads to begin running; prepares a model and waits for all
* threads to finish preparation; runs the model and exits.
*/
void model_thread(ModelFactory& factory, Synchronizer& sync, size_t thread_id) {
ET_LOG(
Info,
"[%zu] Thread has started for %s.",
thread_id,
factory.name().c_str());
sync.advance_to(Synchronizer::State::PREPARE_MODEL);
// Create and prepare our model instance.
ET_LOG(Info, "[%zu] Preparing %s...", thread_id, factory.name().c_str());
std::unique_ptr<PreparedModel> model =
factory.prepare(/*name_affix=*/std::to_string(thread_id));
ET_LOG(Info, "[%zu] Prepared %s.", thread_id, model->name().c_str());
sync.advance_to(Synchronizer::State::RUN_MODEL);
// Run our model.
ET_LOG(Info, "[%zu] Running %s...", thread_id, model->name().c_str());
model->run();
  ET_LOG(
      Info, "[%zu] Finished running %s.", thread_id, model->name().c_str());
// TODO(T131578656): Check the model output.
}
/**
* Splits the provided string on `,` and returns a vector of the non-empty
 * elements. Does not strip whitespace.
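 *
 * For example, "a,,b," yields {"a", "b"}.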
*/
std::vector<std::string> split_string_list(const std::string& list) {
std::vector<std::string> items;
std::stringstream sstream(list);
  std::string item;
  while (std::getline(sstream, item, ',')) {
    if (!item.empty()) {
      items.push_back(item);
    }
  }
return items;
}
} // namespace
int main(int argc, char** argv) {
torch::executor::runtime_init();
// Parse and extract flags.
gflags::SetUsageMessage(
"Creates multiple Executor instances at the same time, demonstrating "
"that the same process can handle multiple runtimes at once.");
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::vector<std::string> model_paths = split_string_list(FLAGS_models);
size_t num_instances = FLAGS_num_instances;
// Create a factory for each model provided on the commandline.
std::vector<std::unique_ptr<ModelFactory>> factories;
for (const auto& model_path : model_paths) {
std::shared_ptr<char> file_data;
size_t file_size;
Error err = torch::executor::util::read_file_content(
model_path.c_str(), &file_data, &file_size);
ET_CHECK(err == Error::Ok);
factories.push_back(std::make_unique<ModelFactory>(
/*name=*/model_path, file_data, file_size));
}
// Spawn threads to prepare and run separate instances of the models in
// parallel.
const size_t num_threads = factories.size() * num_instances;
Synchronizer state(num_threads);
std::vector<std::thread> threads;
size_t thread_id = 0; // Unique ID for every thread.
ET_LOG(Info, "Creating %zu threads...", num_threads);
for (const auto& factory : factories) {
for (size_t i = 0; i < num_instances; ++i) {
threads.push_back(std::thread(
model_thread, std::ref(*factory), std::ref(state), thread_id++));
}
}
// Wait for all threads to finish.
ET_LOG(Info, "Waiting for %zu threads to exit...", threads.size());
for (auto& thread : threads) {
thread.join();
}
ET_LOG(Info, "All %zu threads exited.", threads.size());
  return 0;
}
//
// Flag validation
//
/// Returns true if the specified path exists in the filesystem.
static bool path_exists(const std::string& path) {
struct stat st;
return stat(path.c_str(), &st) == 0;
}
/// Returns true if `path_list` contains a comma-separated list of at least one
/// path that exists in the filesystem.
static bool validate_path_list(
const char* flagname,
const std::string& path_list) {
const std::vector<std::string> paths = split_string_list(path_list);
if (paths.empty()) {
fprintf(
stderr, "Must specify at least one valid path with --%s\n", flagname);
return false;
}
  for (const auto& path : paths) {
if (!path_exists(path)) {
fprintf(
stderr,
"Path '%s' does not exist in --%s='%s'\n",
path.c_str(),
flagname,
path_list.c_str());
return false;
}
}
return true;
}
/// Returns true if `val` is positive.
static bool validate_positive_int32(const char* flagname, int32_t val) {
if (val <= 0) {
fprintf(
stderr, "Value must be positive for --%s=%" PRId32 "\n", flagname, val);
return false;
}
return true;
}