-
Notifications
You must be signed in to change notification settings - Fork 509
/
Copy pathmethod.h
410 lines (359 loc) · 13.3 KB
/
method.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#if defined(__GNUC__)
// This header declares ET_DEPRECATED members, and some builds compile with
// -Werror, so -Wdeprecated-declarations is silenced from here on. The matching
// `#pragma GCC diagnostic pop` sits at the very end of this header.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/event_tracer.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/memory_manager.h>
#include <executorch/runtime/executor/method_meta.h>
#include <executorch/runtime/platform/compiler.h>
// Serialized (flatbuffer) types are only forward-declared here: this header is
// public, so it must never pull in the generated flatbuffer schema header.
namespace executorch_flatbuffer {
struct Chain;
struct EValue;
struct ExecutionPlan;
} // namespace executorch_flatbuffer
namespace executorch {
namespace runtime {

// Forward declare NamedData. This is a public header and must not include
// internal data types.
namespace deserialization {
struct NamedData;
} // namespace deserialization

// Forward declare Program to avoid a circular reference.
class Program;

// Forward declare internal types.
class BackendDelegate;
struct Chain;
class KernelRuntimeContext;

/// Pointer to an operator kernel. A kernel receives the runtime context and an
/// array of EValue pointers (presumably the instruction's arguments; see
/// InstructionArgs).
using OpFunction = void (*)(KernelRuntimeContext&, EValue**);

/// A list of pointers into the master values table that together compose the
/// argument list for a single instruction.
using InstructionArgs = Span<EValue*>;

using deserialization::NamedData;

/**
 * An executable method of an executorch program. Maps to a python method like
 * `forward()` on the original nn.Module.
 *
 * Method is move-only: copy construction, copy assignment and move assignment
 * are deleted. Construction is restricted to the friends Program and Executor.
 */
class Method final {
 public:
  /**
   * Move ctor. Takes ownership of resources previously owned by `rhs`,
   * and leaves `rhs` in an uninitialized state.
   */
  Method(Method&& rhs) noexcept
      : step_state_(rhs.step_state_),
        program_(rhs.program_),
        memory_manager_(rhs.memory_manager_),
        temp_allocator_(rhs.temp_allocator_),
        serialization_plan_(rhs.serialization_plan_),
        event_tracer_(rhs.event_tracer_),
        n_value_(rhs.n_value_),
        values_(rhs.values_),
        n_delegate_(rhs.n_delegate_),
        delegates_(rhs.delegates_),
        n_chains_(rhs.n_chains_),
        chains_(rhs.chains_),
        external_constants_(rhs.external_constants_),
        n_external_constants_(rhs.n_external_constants_),
        init_state_(rhs.init_state_) {
    // Required: clear out fields that the dtor looks at, so that we don't free
    // anything twice.
    rhs.n_value_ = 0;
    rhs.values_ = nullptr;
    rhs.n_delegate_ = 0;
    rhs.delegates_ = nullptr;
    rhs.n_external_constants_ = 0;
    rhs.external_constants_ = nullptr;

    // Helpful: Try to ensure that any other interactions with the old object
    // result in failures. temp_allocator_ is reset too, so the moved-from
    // object no longer aliases the allocator this object now uses.
    rhs.init_state_ = InitializationState::Uninitialized;
    rhs.step_state_ = {};
    rhs.program_ = nullptr;
    rhs.memory_manager_ = nullptr;
    rhs.temp_allocator_ = nullptr;
    rhs.serialization_plan_ = nullptr;
    rhs.event_tracer_ = nullptr;
    rhs.n_chains_ = 0;
    rhs.chains_ = nullptr;
  }

  /**
   * Sets the internal input value to be equivalent to the provided value.
   *
   * @param[in] input_evalue The evalue to copy into the method input. If the
   *     evalue is a tensor, the data is copied in most cases, so the tensor
   *     passed in here does not always need to outlive this call. But there is
   *     a case where the Method will keep a pointer to the tensor's data.
   *     Based on the memory plan of the method, the inputs may not have
   *     buffer space pre-allocated for them. In this case the executor will
   *     alias the memory of the tensors provided as inputs here rather than
   *     deep-copy the input into the memory-planned arena.
   *
   * @param[in] input_idx Zero-based index of the input to set. Must be less
   *     than the value returned by inputs_size().
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error set_input(const EValue& input_evalue, size_t input_idx);

  /**
   * Sets the values of all method inputs.
   *
   * See set_input() for a more detailed description of the behavior.
   *
   * @param[in] input_evalues The new values for all of the method inputs. The
   *     type of each element must match the type of corresponding input. If the
   *     value of an element is a tensor, attempts to allow dynamic shape, but
   *     the dtype must always agree.
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error
  set_inputs(const executorch::aten::ArrayRef<EValue>& input_evalues);

  /**
   * Sets the data buffer of the specified method output to the provided value.
   *
   * NOTE: Based on the memory plan of the method, the output tensors may not
   * have buffer space pre-allocated for them, in this case the executor will
   * point those tensors to the buffer provided here, so the user should take
   * care that the life span of this memory outlasts the executor forward.
   *
   * @param[in] buffer The block of memory to point the specified tensor at.
   *
   * @param[in] size The length of buffer in bytes, must be >= the nbytes of the
   *     specified tensor.
   *
   * @param[in] output_idx The index of the output to set the data_ptr for. Must
   *     correspond to a tensor, and that tensor must not have had a buffer
   *     allocated by the memory plan.
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error
  set_output_data_ptr(void* buffer, size_t size, size_t output_idx);

  /**
   * Copies the method's outputs into the provided array.
   *
   * WARNING: The output contains shallow copies of internal tensor outputs.
   * Please do not mutate returned Tensor elements.
   *
   * TODO(T139259264): Add checks to detect output mutation, or deep-copy
   * outputs.
   *
   * @param[in] output_evalues The array to copy the outputs into. The first
   *     `outputs_size()` elements will be set to the corresponding output
   *     values. The rest of the array will be set to the EValue value None.
   * @param[in] length The size of the `output_evalues` array in elements. Must
   *     be greater than or equal to `outputs_size()`.
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error get_outputs(EValue* output_evalues, size_t length);

  /**
   * Copies the method's inputs into the provided array.
   *
   * WARNING: The input contains shallow copies of internal tensor inputs.
   * Please do not mutate returned Tensor elements.
   *
   * @param[in] input_evalues The array to copy the inputs into. The first
   *     `inputs_size()` elements will be set to the corresponding input
   *     values. The rest of the array will be set to the EValue value None.
   * @param[in] length The size of the `input_evalues` array in elements. Must
   *     be greater than or equal to `inputs_size()`.
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error get_inputs(EValue* input_evalues, size_t length);

  /**
   * Execute the method.
   *
   * NOTE: Will fail if the method has been partially executed using the
   * `step()` api.
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error execute();

  /**
   * EXPERIMENTAL: Advances/executes a single instruction in the method.
   *
   * @retval Error::Ok step succeeded
   * @retval non-Ok step failed
   * @retval Error::EndOfMethod method finished executing successfully
   */
  ET_EXPERIMENTAL ET_NODISCARD Error step();

  /// DEPRECATED: Use `step()` instead.
  ET_DEPRECATED ET_NODISCARD Error experimental_step();

  /**
   * EXPERIMENTAL: Resets execution state to the start of the Method. For use
   * with the `step()` API.
   *
   * @retval Error::Ok on success
   * @retval Error::InvalidState if called before step-based execution reached
   *     the end of the Method. This means it is not possible to recover a
   *     Method that failed mid-execution.
   */
  ET_EXPERIMENTAL ET_NODISCARD Error reset_execution();

  /// DEPRECATED: Use `reset_execution()` instead.
  ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution();

  /**
   * Returns the MethodMeta that corresponds to the calling Method.
   */
  MethodMeta method_meta() const;

  /**
   * Returns the number of inputs the Method expects.
   */
  size_t inputs_size() const;

  /**
   * Returns the number of outputs the Method returns.
   */
  size_t outputs_size() const;

  /**
   * Retrieves the output at the specified index.
   */
  const EValue& get_output(size_t i) const;

  /**
   * Returns the EventTracer associated with this Method.
   *
   * NOTE(review): presumably the tracer passed to load(); it may be null when
   * none was provided — confirm against the implementation.
   */
  EventTracer* get_event_tracer();

  /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to
  /// update Method inputs.
  ET_DEPRECATED const EValue& get_input(size_t i) const;

  /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to
  /// update Method inputs.
  ET_DEPRECATED EValue& mutable_input(size_t i);

  /// DEPRECATED: Use MethodMeta instead to access metadata, and get_output to
  /// retrieve Method outputs.
  ET_DEPRECATED EValue& mutable_output(size_t i);

  ~Method();

 private:
  // Delete other rule-of-five methods.
  Method(const Method&) = delete;
  Method& operator=(const Method&) = delete;
  Method& operator=(Method&&) = delete;

  // Let Program call load().
  friend class Program;
  // Let Executor call the ctor and init().
  friend class Executor;

  enum class InitializationState : uint8_t {
    Uninitialized,
    Initialized,
    InitializationFailed,
  };

  /// Tracks what step in program execution we are on
  struct StepState {
    size_t chain_idx;
    size_t instr_idx;
  };

  Method(
      const Program* program,
      MemoryManager* memory_manager,
      EventTracer* event_tracer,
      MemoryAllocator* temp_allocator)
      : step_state_(),
        program_(program),
        memory_manager_(memory_manager),
        temp_allocator_(temp_allocator),
        serialization_plan_(nullptr),
        event_tracer_(event_tracer),
        n_value_(0),
        values_(nullptr),
        n_delegate_(0),
        delegates_(nullptr),
        n_chains_(0),
        chains_(nullptr),
        external_constants_(nullptr),
        n_external_constants_(0),
        init_state_(InitializationState::Uninitialized) {}

  /// Static factory used by Program.
  ET_NODISCARD static Result<Method> load(
      executorch_flatbuffer::ExecutionPlan* s_plan,
      const Program* program,
      MemoryManager* memory_manager,
      EventTracer* event_tracer,
      const NamedDataMap* named_data_map);

  /**
   * Initialize the method from its serialized representation.
   *
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error init(
      executorch_flatbuffer::ExecutionPlan* s_plan,
      const NamedDataMap* named_data_map);

  /// Returns true if the Method was successfully initialized.
  inline bool initialized() const {
    return init_state_ == InitializationState::Initialized;
  }

  const EValue& get_value(size_t i) const;
  EValue& mutable_value(size_t i);
  size_t get_input_index(size_t i) const;
  size_t get_output_index(size_t i) const;

  // Executes a single instruction using the state in step_state_
  ET_NODISCARD Error execute_instruction();

  // Position of the next instruction for step-based execution.
  StepState step_state_;
  const Program* program_;
  MemoryManager* memory_manager_;
  MemoryAllocator* temp_allocator_;
  executorch_flatbuffer::ExecutionPlan* serialization_plan_;
  EventTracer* event_tracer_;

  // Values table. After a failed parse_values(), n_value_ counts only the
  // successfully-initialized entries so ~Method skips the rest.
  size_t n_value_;
  EValue* values_;

  size_t n_delegate_;
  BackendDelegate* delegates_;

  size_t n_chains_;
  Chain* chains_;

  // External constants retrieved from a NamedDataMap by
  // parse_external_constants(); n_external_constants_ counts the
  // successfully-initialized entries.
  NamedData* external_constants_;
  size_t n_external_constants_ = 0;

  InitializationState init_state_;

  /**
   * Counts the number of tensors marked as EXTERNAL in the flatbuffer
   * for this method.
   */
  ET_NODISCARD Result<size_t> get_num_external_constants();

  /**
   * Parses the flatbuffer for constant tensors tagged as EXTERNAL.
   * Retrieves the external constants using the named_data_map and places them
   * into `external_constants_`. Updates `n_external_constants_` to count the
   * number of successfully-initialized external constants.
   * FreeableBuffers returned by the named_data_map are owned by the
   * method and are freed on method destruction.
   *
   * @param[in] named_data_map The map to retrieve external constants from.
   * @returns Error::Ok on success, non-Ok on failure.
   */
  ET_NODISCARD Error
  parse_external_constants(const NamedDataMap* named_data_map);

  /**
   * Parses the elements of the values_ array. On error, n_value_ will be set to
   * the number of successfully-initialized entries so that ~Method doesn't try
   * to clean up uninitialized entries.
   */
  ET_NODISCARD Error parse_values(const NamedDataMap* named_data_map);

  ET_NODISCARD Error resolve_operator(
      int32_t op_index,
      OpFunction* kernels,
      size_t kernel_index,
      InstructionArgs args,
      size_t n_args);

  void log_outputs();
};

} // namespace runtime
} // namespace executorch
namespace torch {
namespace executor {

// Legacy spelling of Method, kept only for source compatibility.
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::Method;

} // namespace executor
} // namespace torch
#if defined(__GNUC__)
// Restore the diagnostic state saved by the push at the top of this header,
// re-enabling -Wdeprecated-declarations for code that includes it.
#pragma GCC diagnostic pop
#endif