-
Notifications
You must be signed in to change notification settings - Fork 511
/
Copy path: tensor_parser_aten.cpp
130 lines (113 loc) · 3.99 KB
/
tensor_parser_aten.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/runtime/executor/tensor_parser.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/executor/memory_manager.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/schema/program_generated.h>
#include <ATen/ATen.h> // @donotremove @manual=//caffe2/aten:ATen-core
namespace executorch {
namespace runtime {
namespace deserialization {
namespace {
// No-op deleter handed to at::from_blob below: the tensor never owns the
// buffer it points at (it is either planned memory, program constant data,
// or reallocated later), so releasing the blob must free nothing.
void deleteNothing(void* /*unused*/) {}
} // namespace
/**
 * Deserializes a flatbuffer-serialized tensor description into an at::Tensor.
 *
 * Validates the serialized metadata (zero storage offset, known scalar type,
 * presence and consistency of sizes/dim_order), derives strides from the
 * dim_order, then constructs the tensor. For fully dynamic (DYNAMIC_UNBOUND)
 * tensors the data pointer is left null and the storage is made resizable;
 * otherwise the data pointer is resolved via getTensorDataPtr() from planned
 * memory, the program, or external/named data.
 *
 * @param program The program owning any constant data the tensor may reference.
 * @param memory_manager Supplies the planned-memory arena for non-constant tensors.
 * @param s_tensor The serialized flatbuffer tensor to parse.
 * @param named_data_map Optional source for tensors stored in external data files.
 * @param external_constants Preloaded external constant buffers.
 * @returns The constructed tensor, or an error:
 *          NotSupported (non-zero storage offset), InvalidProgram (bad
 *          metadata), Internal (stride derivation failed), or the error
 *          propagated from getTensorDataPtr().
 */
Result<at::Tensor> parseTensor(
const Program* program,
MemoryManager* memory_manager,
const executorch_flatbuffer::Tensor* s_tensor,
const NamedDataMap* named_data_map,
Span<NamedData> external_constants) {
EXECUTORCH_SCOPE_PROF("TensorParser::parseTensor");
// Only tensors that start at the beginning of their storage are supported.
ET_CHECK_OR_RETURN_ERROR(
s_tensor->storage_offset() == 0,
NotSupported,
"Non-zero storage offset %" PRId32 " not supported",
s_tensor->storage_offset());
// get metadata
at::ScalarType type = static_cast<at::ScalarType>(s_tensor->scalar_type());
ET_CHECK_OR_RETURN_ERROR(
isValid(type),
InvalidProgram,
"Invalid ScalarType %" PRId8,
static_cast<int8_t>(type));
auto options = at::CPU(type).options();
// Flatbuffer fields are nullable; a missing sizes vector means a malformed
// program, not an empty (0-dim) tensor.
ET_CHECK_OR_RETURN_ERROR(
s_tensor->sizes() != nullptr, InvalidProgram, "Missing sizes field");
size_t ndim = s_tensor->sizes()->size();
ET_CHECK_OR_RETURN_ERROR(
s_tensor->dim_order() != nullptr,
InvalidProgram,
"Missing dim_order field");
// dim_order must describe exactly one entry per dimension.
ET_CHECK_OR_RETURN_ERROR(
s_tensor->dim_order()->size() == ndim,
InvalidProgram,
"dim_order size %" PRIu32 " != ndim %zu",
s_tensor->dim_order()->size(),
ndim);
// convert int32 in serialization to int64 for aten
std::vector<int64_t> sizes(
s_tensor->sizes()->begin(), s_tensor->sizes()->end());
// Strides are not serialized; derive them from the dim_order.
std::vector<int64_t> strides(ndim);
auto status = dim_order_to_stride(
s_tensor->sizes()->data(),
s_tensor->dim_order()->data(),
ndim,
strides.data());
ET_CHECK_OR_RETURN_ERROR(
status == Error::Ok,
Internal,
"dim_order_to_stride returned invalid status");
// Create a tensor without data first so we can find its expected size before
// getting its memory.
at::Tensor tensor = at::from_blob(
/*data=*/nullptr,
sizes,
strides,
/*storage_offset=*/0,
deleteNothing, // no-op: the tensor never owns this (null) buffer
options);
if (s_tensor->shape_dynamism() ==
executorch_flatbuffer::TensorShapeDynamism::DYNAMIC_UNBOUND) {
// Provide fully dynamic tensors with an allocator so they can be resized
// within aten kernels.
auto impl = tensor.unsafeGetTensorImpl();
at::StorageImpl* storage = impl->unsafe_storage().unsafeGetStorageImpl();
storage->set_allocator(at::getCPUAllocator());
storage->set_resizable(true);
// Start with empty storage and zero sizes; the first resize allocates.
storage->set_nbytes(0);
impl->set_sizes_contiguous(0);
// Leave the data as nullptr since it will be reallocated.
} else {
// Now that we know how big the tensor is, find and assign its memory.
Result<void*> data_ptr = getTensorDataPtr(
s_tensor,
program,
tensor.nbytes(),
memory_manager->planned_memory(),
named_data_map,
external_constants);
if (!data_ptr.ok()) {
ET_LOG(
Error,
"getTensorDataPtr() failed: 0x%" PRIx32,
static_cast<uint32_t>(data_ptr.error()));
return data_ptr.error();
}
// Point the existing storage at the resolved buffer. The DataPtr carries
// no deleter context, so the storage does not take ownership of it.
tensor.unsafeGetTensorImpl()->unsafe_storage().set_data_ptr(
at::DataPtr(data_ptr.get(), c10::DeviceType::CPU));
}
return tensor;
}
} // namespace deserialization
} // namespace runtime
} // namespace executorch