forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathiree-run-mlir-main.cc
535 lines (481 loc) · 22.2 KB
/
iree-run-mlir-main.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// IREE source.mlir/mlirbc -> execution output runner.
// This is meant to be called from LIT for FileCheck tests or as a developer
// tool to emulate what an online compiler does. It tries to match the interface
// of iree-compile so it's easy to run tests or approximate an
// `iree-compile | iree-run-module` sequence. If you want a more generalized
// runner for standalone precompiled IREE modules use iree-run-module instead.
//
// If there's a single exported function, that function will be executed; if
// there are multiple functions, --function= can be used to specify which one.
// Function inputs can be provided with --input=. Results from the executed
// function will be printed to stdout for checking or can be written to files
// with --output=.
//
// Similar to iree-run-module the --device= flag can be used to specify which
// drivers and devices should be used to execute the function. The tool will
// try to infer which iree-compile flags are required for the devices used but
// this can be overridden by passing the --iree-hal-target-backends= and related
// flags explicitly. Likewise if only the target backend is specified the
// devices to use will be inferred unless explicitly specified.
//
// Example usage to compile and run with CUDA:
// $ iree-run-mlir --device=cuda://0 file.mlir
// or to compile with the LLVM CPU backend and run with the local-task driver:
// $ iree-run-mlir file.mlir \
// --Xcompiler,iree-hal-target-backends=llvm-cpu --device=local-task
//
// Example usage in a lit test:
// // RUN: iree-run-mlir --device= %s --function=foo --input=2xf32=2,3 | \
// // RUN: FileCheck %s
// // CHECK-LABEL: @foo
// // CHECK: 2xf32=[2 3]
// func.func @foo(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// return %arg0 : tensor<2xf32>
// }
//
// Command line arguments are handled by LLVM's parser by default but -- can be
// used to separate the compiler flags from the runtime flags, such as:
// $ iree-run-mlir source.mlir --device=local-task -- \
// --iree-hal-target-backends=llvm-cpu
//
// In addition compiler/runtime flags can be passed in any order by prefixing
// them with --Xcompiler or --Xruntime like `--Xruntime,device=local-task` or
// `--Xruntime --device=local-task`.
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "iree/base/api.h"
#include "iree/base/internal/flags.h"
#include "iree/compiler/embedding_api.h"
#include "iree/hal/api.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
#include "iree/tooling/run_module.h"
#include "iree/vm/api.h"
namespace iree {
namespace {
// Reports whether |in_str| begins with |prefix|.
// Stand-in for std::string_view::starts_with, which only arrives in C++20.
// Note the argument order: the prefix comes first, the searched string second.
// https://en.cppreference.com/w/cpp/string/basic_string_view/starts_with
bool starts_with(std::string_view prefix, std::string_view in_str) {
  // A prefix longer than the string can never match.
  if (prefix.size() > in_str.size()) return false;
  return in_str.substr(0, prefix.size()) == prefix;
}
// Picks a plausible default runtime device name for |target_backend|.
// This is only a default: users can still override it with --device= flags.
// Returns an empty string when the backend does not need a HAL device.
std::string InferDefaultDeviceFromTargetBackend(
    std::string_view target_backend) {
  // Plain VM programs and vmvx-inline targets run without any HAL device.
  if (target_backend.empty() || target_backend == "vmvx-inline") {
    return std::string();
  }

  // Locally-executable targets default to the multithreaded task system
  // driver; --device=local-sync can be passed to override that choice.
  if (target_backend == "llvm-cpu" || target_backend == "vmvx") {
    return std::string("local-task");
  }

  // Many other backends follow the `driver-pipeline` naming scheme (such as
  // `vulkan-spirv`), so keep everything before the first dash. When there is
  // no dash find() yields npos and substr() returns the whole name. Device
  // creation will fail later if this turns out to be a bad guess.
  const std::string_view driver_name =
      target_backend.substr(0, target_backend.find('-'));
  return std::string(driver_name);
}
// Guesses the compiler target backend that matches |device_uri|.
// Returns an empty string if no backend is required or none could be inferred.
std::string InferTargetBackendFromDevice(iree_string_view_t device_uri) {
  // Device URIs take the `driver://...` form; the driver name is whatever
  // precedes the first ':'.
  iree_string_view_t driver_name = iree_string_view_empty();
  iree_string_view_split(device_uri, ':', &driver_name, nullptr);

  // No driver: plain VM or vmvx-inline targets do not need a HAL device.
  if (iree_string_view_is_empty(driver_name)) {
    return std::string();
  }

  // Locally-executable devices default to the llvm-cpu target as that's
  // usually what people want for CPU execution; users can override by
  // specifying --iree-hal-target-backends=vmvx instead.
  const bool is_local_driver =
      iree_string_view_starts_with(driver_name, IREE_SV("local-"));
  if (is_local_driver) {
    return std::string("llvm-cpu");
  }

  // Many other backends have aliases that allow using the driver name
  // directly. If several pipelines are available, whichever one the compiler
  // defaults to is chosen.
  return std::string(driver_name.data, driver_name.size);
}
// Guesses the set of target backends needed to cover every device in
// |device_uris|. Several target backends can serve the same device (llvm-cpu
// or vmvx for local-sync and local-task, for example), so this is guesswork.
// Returns a comma-delimited list of target backends with duplicates removed;
// the list is empty when no devices are given or no backend could be inferred.
StatusOr<std::string> InferTargetBackendsFromDevices(
    iree_string_view_list_t device_uris) {
  // Nothing to infer when no devices were specified (probably no HAL).
  if (device_uris.count == 0) return "";

  // When several devices are provided all of them must be targeted. The set
  // collapses devices that map to the same backend.
  std::set<std::string> unique_backends;
  for (iree_host_size_t index = 0; index < device_uris.count; ++index) {
    std::string backend =
        InferTargetBackendFromDevice(device_uris.values[index]);
    if (backend.empty()) continue;
    unique_backends.insert(std::move(backend));
  }

  // Join the collected backends with commas.
  std::string result;
  for (const std::string& backend : unique_backends) {
    if (!result.empty()) result += ",";
    result += backend;
  }
  return result;
}
// Reconciles the compiler --iree-hal-target-backends= flag with the runtime
// --device= flags so that the user only needs to specify one of the two.
//
// - Neither set: nothing to do (the program may not use the HAL at all).
// - Both set: nothing to do (the user chose an explicit configuration).
// - Only target backends set: a default device is inferred and written to
//   |out_default_device_uri|. This requires a single target backend; with
//   multiple backends an INVALID_ARGUMENT status is returned.
// - Only devices set: target backends are inferred from the devices and set
//   on |session|.
//
// |out_default_device_uri| is only written in the "only target backends set"
// case and is left untouched otherwise.
//
// Online compilers would want to do some more intelligent device selection on
// their own.
Status ConfigureTargetBackends(iree_compiler_session_t* session,
                               std::string* out_default_device_uri) {
  // Query the session for the currently set --iree-hal-target-backends= flag.
  // It may be empty string.
  std::string target_backends_flag;
  ireeCompilerSessionGetFlags(
      session, /*nonDefaultOnly=*/true,
      [](const char* flag_str, size_t length, void* user_data) {
        // NOTE: flag_str has the full `--flag=value` string.
        std::string_view prefix = "--iree-hal-target-backends=";
        std::string_view flag = std::string_view(flag_str, length);
        if (starts_with(prefix, flag)) {
          flag.remove_prefix(prefix.size());
          if (flag.empty()) return;  // ignore empty
          auto* result = static_cast<std::string*>(user_data);
          *result = std::string(flag);
        }
      },
      static_cast<void*>(&target_backends_flag));

  // Query the tooling utils for the --device= flag values. Note that zero or
  // more devices may be specified.
  iree_string_view_list_t device_uris = iree_hal_device_flag_list();

  // No-op if no target backends or devices are specified - this can be an
  // intentional decision as the user may be running a program that doesn't use
  // the HAL.
  if (target_backends_flag.empty() && device_uris.count == 0) {
    return OkStatus();
  }

  // No-op if both target backends and devices are set as the user has
  // explicitly specified a configuration.
  if (!target_backends_flag.empty() && device_uris.count > 0) {
    return OkStatus();
  }

  // If target backends are specified then we can infer the runtime devices
  // from the compiler configuration. This only works if there's a single
  // backend specified; if the user wants multiple target backends then they
  // must specify the device(s) to use.
  if (device_uris.count == 0) {
    if (target_backends_flag.find(',') != std::string::npos) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "if multiple target backends are specified the device to use must "
          "also be specified with --device= (have "
          "`--iree-hal-target-backends=%.*s`)",
          static_cast<int>(target_backends_flag.size()),
          target_backends_flag.data());
    }
    *out_default_device_uri =
        InferDefaultDeviceFromTargetBackend(target_backends_flag);
    return OkStatus();
  }

  // Infer target backends from the runtime device configuration.
  // This can get arbitrarily complex but for now this simple runner just
  // guesses. In the future we'll have more ways of configuring the compiler
  // from available runtime devices (not just the target backend but also
  // target-specific settings).
  IREE_ASSIGN_OR_RETURN(auto target_backends,
                        InferTargetBackendsFromDevices(device_uris));
  if (!target_backends.empty()) {
    // Named distinctly from |target_backends_flag| above: this holds the full
    // `--flag=value` string, not just the value.
    const std::string inferred_flag =
        std::string("--iree-hal-target-backends=") + target_backends;
    const char* compiler_argv[1] = {
        inferred_flag.c_str(),
    };
    iree_compiler_error_t* error = ireeCompilerSessionSetFlags(
        session, IREE_ARRAYSIZE(compiler_argv), compiler_argv);
    if (error) {
      // The returned error object must be destroyed by us; previously it was
      // dropped on this path and leaked.
      ireeCompilerErrorDestroy(error);
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "unable to set inferred target backend flag to `%.*s`",
          static_cast<int>(inferred_flag.size()), inferred_flag.data());
    }
  }
  return OkStatus();
}
// Compiles the given .mlir file with |session| and runs the result based on
// the current flags.
//
// On success returns the process exit code produced by the run, which may be
// non-zero even though the status is OK (see the note near the end of the
// function). Returns an error status if configuring the target backends,
// compiling the source, or setting up/running the module fails.
StatusOr<int> CompileAndRunFile(iree_compiler_session_t* session,
                                const char* mlir_filename) {
  IREE_TRACE_SCOPE_NAMED("CompileAndRunFile");

  // Configure the --iree-hal-target-backends= flag and/or get the default
  // device to use at runtime if either are not explicitly specified.
  // Note that target backends and the runtime devices aren't 1:1 and this is
  // an imperfect guess. In this simple online compiler we assume homogenous
  // device sets and only a single global target backend but library/hosting
  // layers can configure heterogenous and per-invocation target configurations.
  std::string default_device_uri;
  IREE_RETURN_IF_ERROR(ConfigureTargetBackends(session, &default_device_uri));

  // RAII container for the compiler invocation: the destructor releases the
  // source, the output, and the invocation on every return path below.
  struct InvocationState {
    iree_compiler_invocation_t* invocation = nullptr;
    iree_compiler_source_t* source = nullptr;
    iree_compiler_output_t* output = nullptr;

    explicit InvocationState(iree_compiler_session_t* session) {
      invocation = ireeCompilerInvocationCreate(session);
    }
    ~InvocationState() {
      if (source) ireeCompilerSourceDestroy(source);
      if (output) ireeCompilerOutputDestroy(output);
      ireeCompilerInvocationDestroy(invocation);
    }

    // Owns raw handles; copying would destroy them twice.
    InvocationState(const InvocationState&) = delete;
    InvocationState& operator=(const InvocationState&) = delete;

    // Prints |error| to stderr, converts it to a status with |status_code|
    // (annotated with |while_performing| when non-empty), and destroys it.
    Status emitError(iree_compiler_error_t* error,
                     iree_status_code_t status_code,
                     std::string_view while_performing = "") {
      const char* msg = ireeCompilerErrorGetMessage(error);
      fprintf(stderr, "error compiling input file: %s\n", msg);
      // The message is passed as a `%s` argument and never as the format
      // string itself: compiler diagnostics frequently contain `%` (for
      // example SSA value names such as `%0`), which would otherwise be
      // interpreted as printf conversion specifiers.
      iree_status_t status = iree_make_status(status_code, "%s", msg);
      if (!while_performing.empty()) {
        status = iree_status_annotate(
            status, iree_make_string_view(while_performing.data(),
                                          while_performing.size()));
      }
      ireeCompilerErrorDestroy(error);
      return status;
    }
  } state(session);

  // Open the source file on disk or stdin if `-`.
  if (auto error =
          ireeCompilerSourceOpenFile(session, mlir_filename, &state.source)) {
    return state.emitError(error, IREE_STATUS_NOT_FOUND, "opening source file");
  }

  // Open a writeable memory buffer that we can stream the compilation outputs
  // into. This may be backed by a memory-mapped file to allow for very large
  // results.
  if (auto error = ireeCompilerOutputOpenMembuffer(&state.output)) {
    return state.emitError(error, IREE_STATUS_INTERNAL,
                           "open output memory buffer");
  }

  // TODO: make parsing/pipeline execution return an error object.
  // We could capture diagnostics, stash them on the state, and emit with
  // ireeCompilerInvocationEnableCallbackDiagnostics.
  // For now we route all errors to stderr.
  ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation);

  // Parse the source MLIR input and log verbose errors. Syntax errors or
  // version mismatches will hit here.
  if (!ireeCompilerInvocationParseSource(state.invocation, state.source)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "failed to parse input file");
  }

  // Invoke the standard compilation pipeline to produce the compiled module.
  if (!ireeCompilerInvocationPipeline(state.invocation,
                                      IREE_COMPILER_PIPELINE_STD)) {
    return iree_make_status(IREE_STATUS_INTERNAL,
                            "failed to invoke main compiler pipeline");
  }

  // Flush the output to the memory buffer.
  if (auto error = ireeCompilerInvocationOutputVMBytecode(state.invocation,
                                                          state.output)) {
    return state.emitError(error, IREE_STATUS_INTERNAL,
                           "emitting output VM module binary");
  }

  // Get a raw host pointer to the output that we can pass to the runtime.
  // The mapping belongs to |state.output|, which outlives the run below.
  void* binary_data = nullptr;
  uint64_t binary_size = 0;
  if (auto error = ireeCompilerOutputMapMemory(state.output, &binary_data,
                                               &binary_size)) {
    return state.emitError(error, IREE_STATUS_INTERNAL,
                           "mapping output buffer");
  }

  // Hosting libraries can route all runtime allocations to their own allocator
  // for statistics, isolation, or efficiency. Here we use the system
  // malloc/free.
  iree_allocator_t host_allocator = iree_allocator_system();

  // The same VM instance should be shared across many contexts. Here we only
  // use this once but a library would want to retain this and the devices it
  // creates for as long as practical.
  vm::ref<iree_vm_instance_t> instance;
  IREE_RETURN_IF_ERROR(iree_tooling_create_instance(host_allocator, &instance),
                       "creating instance");

  // Run the compiled module using the global flags for I/O (if any).
  // This loads the module, creates a VM context with it and any dependencies,
  // parses inputs from flags, and routes/verifies outputs as specified. Hosting
  // libraries should always reuse contexts if possible to amortize loading
  // costs and carry state (variables/etc) across invocations.
  //
  // This returns a process exit code based on the run mode (verifying expected
  // outputs, etc) that may be non-zero even if the status is success
  // ("execution completed successfully but values did not match").
  int exit_code = EXIT_SUCCESS;
  IREE_RETURN_IF_ERROR(
      iree_tooling_run_module_with_data(
          instance.get(),
          iree_make_string_view(default_device_uri.data(),
                                default_device_uri.size()),
          iree_make_const_byte_span(binary_data,
                                    static_cast<iree_host_size_t>(binary_size)),
          host_allocator, &exit_code),
      "running compiled module");
  return exit_code;
}
// Parses a combined list of compiler and runtime flags.
// Each argument list is stored in canonical argc/argv format with a trailing
// NULL string in the storage (excluded from the count).
//
// The accessors are only meaningful after a successful Parse(). The returned
// argv pointers refer to storage owned by this object and by the original
// argv passed to Parse().
class ArgParser {
 public:
  // Number of compiler arguments, including the program name.
  int compiler_argc() { return static_cast<int>(compiler_args_.size()) - 1; }
  // NULL-terminated compiler argument list.
  const char** compiler_argv() {
    return const_cast<const char**>(compiler_args_.data());
  }
  // Number of runtime arguments, including the program name.
  int runtime_argc() { return static_cast<int>(runtime_args_.size()) - 1; }
  // NULL-terminated runtime argument list.
  char** runtime_argv() { return runtime_args_.data(); }

  // Parses arguments from a raw command line argc/argv set.
  // Returns true if parsing was successful. Returns false (after printing a
  // message to stderr) if a -Xcompiler/-Xruntime flag is the last argument and
  // so has nothing to forward.
  bool Parse(int argc_raw, char** argv_raw) {
    // Pre-process the arguments with the compiler's argument parser since it
    // has super-powers on Windows and must work on the default main arguments.
    ireeCompilerGetProcessCLArgs(&argc_raw,
                                 const_cast<const char***>(&argv_raw));

    // Always add the progname to both flag sets.
    compiler_args_.push_back(argv_raw[0]);
    runtime_args_.push_back(argv_raw[0]);

    // Everything before -- goes to the runtime.
    // Everything after -- goes to the compiler.
    // To make it easier to form command lines in scripts we also allow
    // prefixing flags with -Xcompiler/-Xruntime on either side of the --.
    bool parsing_runtime_args = true;
    for (int i = 1; i < argc_raw; ++i) {
      char* current_arg_cstr = argv_raw[i];
      auto current_arg = std::string_view(current_arg_cstr);
      const bool forward_next_to_compiler =
          current_arg == "-Xcompiler" || current_arg == "--Xcompiler";
      const bool forward_next_to_runtime =
          current_arg == "-Xruntime" || current_arg == "--Xruntime";
      if (current_arg == "--") {
        // Switch default parsing to compiler flags.
        parsing_runtime_args = false;
      } else if (forward_next_to_compiler || forward_next_to_runtime) {
        // The next arg is routed to the requested side. It must exist:
        // forwarding a missing argument would put a NULL in the middle of the
        // argv list.
        if (i + 1 >= argc_raw) {
          fprintf(stderr, "ERROR: %s must be followed by the flag to forward.\n",
                  current_arg_cstr);
          fflush(stderr);
          return false;
        }
        // Advance past the forwarded arg so that it is consumed here and not
        // routed a second time by the next loop iteration.
        char* forwarded_arg_cstr = argv_raw[++i];
        if (forward_next_to_compiler) {
          compiler_args_.push_back(forwarded_arg_cstr);
        } else {
          runtime_args_.push_back(forwarded_arg_cstr);
        }
      } else if (starts_with("-Xcompiler,", current_arg) ||
                 starts_with("--Xcompiler,", current_arg)) {
        // Split and send the rest of the flag to the compiler.
        AppendPrefixedArg(current_arg, &compiler_args_);
      } else if (starts_with("-Xruntime,", current_arg) ||
                 starts_with("--Xruntime,", current_arg)) {
        // Split and send the rest of the flag to the runtime.
        AppendPrefixedArg(current_arg, &runtime_args_);
      } else {
        // Route to either runtime or compiler arg sets based on which side of
        // the -- we are on.
        if (parsing_runtime_args) {
          runtime_args_.push_back(current_arg_cstr);
        } else {
          compiler_args_.push_back(current_arg_cstr);
        }
      }
    }

    // Add nullptrs to end to match real argv behavior.
    compiler_args_.push_back(nullptr);
    runtime_args_.push_back(nullptr);
    return true;
  }

 private:
  // Drops the prefix from |prefixed_arg| and appends the arg to |out_args|.
  // Example: --Xcompiler,ab=cd,ef -> --ab=cd,ef
  // The rewritten string is kept alive in |temp_strings_| so that the pointer
  // stored in |out_args| stays valid for the lifetime of the parser.
  void AppendPrefixedArg(std::string_view prefixed_arg,
                         std::vector<char*>* out_args) {
    std::string_view sub_arg = prefixed_arg.substr(prefixed_arg.find(',') + 1);
    auto stable_arg = std::make_unique<std::string>("--");
    stable_arg->append(sub_arg);
    temp_strings_.push_back(std::move(stable_arg));
    out_args->push_back(temp_strings_.back()->data());
  }

  // Backing storage for arguments synthesized by AppendPrefixedArg.
  std::vector<std::unique_ptr<std::string>> temp_strings_;
  std::vector<char*> runtime_args_;
  std::vector<char*> compiler_args_;
};
} // namespace
// Tool entry point: splits the command line into compiler and runtime flags,
// compiles the single source file named on the command line, runs it, and
// returns the resulting exit code.
//
// Exit codes: EXIT_FAILURE for argument errors, 2 when compilation or
// execution returns an error status, otherwise the code reported by the run.
extern "C" int main(int argc, char** argv) {
  IREE_TRACE_APP_ENTER();
  IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree-run-mlir");

  // Initialize the compiler once on startup before using any other APIs.
  ireeCompilerGlobalInitialize();

  // Parse full argument list and split into compiler/runtime flag sets.
  ArgParser arg_parser;
  if (!arg_parser.Parse(argc, argv)) {
    ireeCompilerGlobalShutdown();
    IREE_TRACE_ZONE_END(z0);
    IREE_TRACE_APP_EXIT(EXIT_FAILURE);
    return EXIT_FAILURE;
  }

  // Pass along compiler flags.
  // Since this is a command line tool we initialize the global compiler
  // command line environment prior to processing the sources.
  // In-process/library uses would usually not do this and would set session
  // specific arguments as needed from whatever configuration mechanisms they
  // use (kwargs passed to python functions, etc).
  ireeCompilerSetupGlobalCL(arg_parser.compiler_argc(),
                            arg_parser.compiler_argv(), "iree-run-mlir",
                            /*installSignalHandlers=*/true);

  // Pass along runtime flags.
  // Note that positional args are left in runtime_argv (after progname).
  // Runtime flags are generally only useful in command line tools where there's
  // a fixed set of devices, a short lifetime, a single thread, and a single
  // context/set of modules/etc. Hosting applications can programmatically
  // do most of what the flags do in a way that avoids the downsides of such
  // global one-shot configuration.
  int runtime_argc = arg_parser.runtime_argc();
  char** runtime_argv = arg_parser.runtime_argv();
  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &runtime_argc,
                           &runtime_argv);

  // Ensure a source file was found: after flag parsing exactly the progname
  // and one positional argument must remain.
  if (runtime_argc != 2) {
    fprintf(stderr,
            "ERROR: one source MLIR file must be specified.\n"
            "Pass either the path to a .mlir/mlirbc file or `-` to read from "
            "stdin.\n");
    fflush(stderr);
    // Shut the compiler down on this exit path too, matching the
    // parse-failure path above and the normal exit below.
    ireeCompilerGlobalShutdown();
    IREE_TRACE_ZONE_END(z0);
    IREE_TRACE_APP_EXIT(EXIT_FAILURE);
    return EXIT_FAILURE;
  }
  const char* source_filename = runtime_argv[1];

  // Sessions can be reused for many compiler invocations.
  iree_compiler_session_t* session = ireeCompilerSessionCreate();

  // The process return code is 0 for success and non-zero otherwise.
  // We don't differentiate between compiler or runtime error codes here but
  // could if someone found it useful.
  int exit_code = EXIT_SUCCESS;

  // Compile and run the provided source file and get the exit code determined
  // based on the run mode.
  auto status_or = CompileAndRunFile(session, source_filename);
  if (status_or.ok()) {
    exit_code = status_or.value();
  } else {
    exit_code = 2;
    iree_status_fprint(stderr, status_or.status().get());
    fflush(stderr);
  }

  ireeCompilerSessionDestroy(session);

  // No more compiler APIs can be called after this point.
  ireeCompilerGlobalShutdown();

  IREE_TRACE_ZONE_END(z0);
  IREE_TRACE_APP_EXIT(exit_code);
  return exit_code;
}
} // namespace iree