forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.h
329 lines (278 loc) · 11.7 KB
/
utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
#ifndef TF_NODEJS_UTILS_H_
#define TF_NODEJS_UTILS_H_
#include <node_api.h>
#include <stdarg.h>
#include <stdio.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tf_auto_status.h"
// NOTE(review): presumably the maximum number of tensor dimensions supported
// by the binding; it is not referenced in this header - confirm with callers.
#define MAX_TENSOR_SHAPE 4

// Number of elements in a statically sized array. The argument is
// parenthesized so that expression arguments index as intended.
#define ARRAY_SIZE(array) (sizeof(array) / sizeof((array)[0]))

// Debug logging is compiled in only when DEBUG is defined to a non-zero value.
#ifndef DEBUG
#define DEBUG 0
#endif

// Writes `message` together with its source location to stderr when DEBUG is
// enabled. `line_number` is converted to size_t so that it always matches the
// `%zu` conversion, even when the caller passes `__LINE__` (an int) directly.
#define DEBUG_LOG(message, file, line_number)                  \
  do {                                                         \
    if (DEBUG)                                                 \
      fprintf(stderr, "** -%s:%zu\n-- %s\n", (file),           \
              static_cast<size_t>(line_number), (message));    \
  } while (0)
namespace tfnodejs {
// Throws a JavaScript error built from a printf-style format, tagging it with
// the call site. Note that the expansion already ends with a semicolon.
#define NAPI_THROW_ERROR(env, message, ...) \
  NapiThrowError(env, __FILE__, __LINE__, message, ##__VA_ARGS__);

// Formats `message` with the trailing printf-style arguments, logs the result
// through DEBUG_LOG together with `file`/`line_number`, and throws it as a
// JavaScript error on `env`.
//
// The formatted text is written into a fixed 500-byte stack buffer, so longer
// messages are truncated by vsnprintf.
inline void NapiThrowError(napi_env env, const char *file,
                           const size_t line_number, const char *message,
                           ...) {
  char formatted[500];

  va_list varargs;
  va_start(varargs, message);
  std::vsnprintf(formatted, sizeof(formatted), message, varargs);
  va_end(varargs);

  DEBUG_LOG(formatted, file, line_number);
  napi_throw_error(env, nullptr, formatted);
}
// Returns from the enclosing void function when `status` is not napi_ok.
#define ENSURE_NAPI_OK(env, status) \
  if (!EnsureNapiOK(env, status, __FILE__, __LINE__)) return;
// Returns `retval` from the enclosing function when `status` is not napi_ok.
#define ENSURE_NAPI_OK_RETVAL(env, status, retval) \
  if (!EnsureNapiOK(env, status, __FILE__, __LINE__)) return retval;

// Checks a napi_status value.
//
// Returns true when `status` is napi_ok. Otherwise throws a JavaScript error
// that carries the last N-API error message (or "unknown" when none is
// available) and returns false.
inline bool EnsureNapiOK(napi_env env, napi_status status, const char *file,
                         const size_t line_number) {
  if (status == napi_ok) {
    return true;
  }

  const napi_extended_error_info *error_info = nullptr;
  const napi_status info_status = napi_get_last_error_info(env, &error_info);

  // The lookup itself can fail and leave `error_info` null; never dereference
  // it without checking, otherwise reporting an error would crash the process.
  const char *detail = "unknown";
  if (info_status == napi_ok && error_info != nullptr &&
      error_info->error_message != nullptr) {
    detail = error_info->error_message;
  }

  NapiThrowError(env, file, line_number, "Invalid napi_status: %s\n", detail);
  return false;
}
// Returns from the enclosing void function when `status` is not TF_OK.
#define ENSURE_TF_OK(env, status) \
  if (!EnsureTFOK(env, status, __FILE__, __LINE__)) return;
// Returns `retval` from the enclosing function when `status` is not TF_OK.
#define ENSURE_TF_OK_RETVAL(env, status, retval) \
  if (!EnsureTFOK(env, status, __FILE__, __LINE__)) return retval;

// Checks the TF_Status wrapped by `status`.
//
// Returns true when its code is TF_OK. Otherwise throws a JavaScript error
// containing the numeric code and the TensorFlow message, and returns false.
inline bool EnsureTFOK(napi_env env, TF_AutoStatus &status, const char *file,
                       const size_t line_number) {
  const TF_Code code = TF_GetCode(status.status);
  if (code == TF_OK) {
    return true;
  }

  NapiThrowError(env, file, line_number, "Invalid TF_Status: %u\nMessage: %s",
                 code, TF_Message(status.status));
  return false;
}
// Returns from the enclosing void function unless the callback was invoked
// with a `new` target.
#define ENSURE_CONSTRUCTOR_CALL(env, info) \
  if (!EnsureConstructorCall(env, info, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_CONSTRUCTOR_CALL_RETVAL(env, info, retval) \
  if (!EnsureConstructorCall(env, info, __FILE__, __LINE__)) return retval;

// Verifies that the callback described by `info` has a non-null new target.
//
// Returns true in that case. Returns false after throwing a JavaScript error
// when the target lookup fails or when the target is null.
inline bool EnsureConstructorCall(napi_env env, napi_callback_info info,
                                  const char *file, const size_t line_number) {
  napi_value new_target;
  const napi_status lookup_status =
      napi_get_new_target(env, info, &new_target);
  ENSURE_NAPI_OK_RETVAL(env, lookup_status, false);

  if (new_target != nullptr) {
    return true;
  }

  NapiThrowError(env, file, line_number,
                 "Function not used as a constructor!");
  return false;
}
// Returns from the enclosing void function unless `value` is a JS object.
#define ENSURE_VALUE_IS_OBJECT(env, value) \
  if (!EnsureValueIsObject(env, value, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_OBJECT_RETVAL(env, value, retval) \
  if (!EnsureValueIsObject(env, value, __FILE__, __LINE__)) return retval;

// Verifies that napi_typeof reports `value` as napi_object.
//
// Returns true on a match. Returns false after throwing a JavaScript error
// when the type query fails or the type differs.
inline bool EnsureValueIsObject(napi_env env, napi_value value,
                                const char *file, const size_t line_number) {
  napi_valuetype value_type;
  const napi_status type_status = napi_typeof(env, value, &value_type);
  ENSURE_NAPI_OK_RETVAL(env, type_status, false);

  if (value_type == napi_object) {
    return true;
  }

  NapiThrowError(env, file, line_number, "Argument is not an object!");
  return false;
}
// Returns from the enclosing void function unless `value` is a JS string.
#define ENSURE_VALUE_IS_STRING(env, value) \
  if (!EnsureValueIsString(env, value, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_STRING_RETVAL(env, value, retval) \
  if (!EnsureValueIsString(env, value, __FILE__, __LINE__)) return retval;

// Verifies that napi_typeof reports `value` as napi_string.
//
// Returns true on a match. Returns false after throwing a JavaScript error
// when the type query fails or the type differs.
inline bool EnsureValueIsString(napi_env env, napi_value value,
                                const char *file, const size_t line_number) {
  napi_valuetype value_type;
  const napi_status type_status = napi_typeof(env, value, &value_type);
  ENSURE_NAPI_OK_RETVAL(env, type_status, false);

  if (value_type == napi_string) {
    return true;
  }

  NapiThrowError(env, file, line_number, "Argument is not a string!");
  return false;
}
// Returns from the enclosing void function unless `value` is a JS number.
#define ENSURE_VALUE_IS_NUMBER(env, value) \
  if (!EnsureValueIsNumber(env, value, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_NUMBER_RETVAL(env, value, retval) \
  if (!EnsureValueIsNumber(env, value, __FILE__, __LINE__)) return retval;

// Verifies that napi_typeof reports `value` as napi_number.
//
// Returns true on a match. Returns false after throwing a JavaScript error
// when the type query fails or the type differs.
inline bool EnsureValueIsNumber(napi_env env, napi_value value,
                                const char *file, const size_t line_number) {
  napi_valuetype type;
  ENSURE_NAPI_OK_RETVAL(env, napi_typeof(env, value, &type), false);
  bool is_number = type == napi_number;
  if (!is_number) {
    // The message previously said "not a string!", copied from the string
    // check, which misreported the actual problem to JavaScript callers.
    NapiThrowError(env, file, line_number, "Argument is not a number!");
  }
  return is_number;
}
// Returns from the enclosing void function unless `value` is a JS array.
#define ENSURE_VALUE_IS_ARRAY(env, value) \
  if (!EnsureValueIsArray(env, value, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_ARRAY_RETVAL(env, value, retval) \
  if (!EnsureValueIsArray(env, value, __FILE__, __LINE__)) return retval;

// Verifies that napi_is_array reports `value` as an array.
//
// Returns true in that case. Returns false after throwing a JavaScript error
// when the query fails or the value is not an array.
inline bool EnsureValueIsArray(napi_env env, napi_value value, const char *file,
                               const size_t line_number) {
  bool array_flag;
  const napi_status query_status = napi_is_array(env, value, &array_flag);
  ENSURE_NAPI_OK_RETVAL(env, query_status, false);

  if (array_flag) {
    return array_flag;
  }

  NapiThrowError(env, file, line_number, "Argument is not an array!");
  return array_flag;
}
// Returns from the enclosing void function unless `value` is a typed array.
#define ENSURE_VALUE_IS_TYPED_ARRAY(env, value) \
  if (!EnsureValueIsTypedArray(env, value, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_TYPED_ARRAY_RETVAL(env, value, retval) \
  if (!EnsureValueIsTypedArray(env, value, __FILE__, __LINE__)) return retval;

// Verifies that napi_is_typedarray reports `value` as a typed array.
//
// Returns true in that case. Returns false after throwing a JavaScript error
// when the query fails or the value is not a typed array.
inline bool EnsureValueIsTypedArray(napi_env env, napi_value value,
                                    const char *file,
                                    const size_t line_number) {
  bool typed_array_flag;
  const napi_status query_status =
      napi_is_typedarray(env, value, &typed_array_flag);
  ENSURE_NAPI_OK_RETVAL(env, query_status, false);

  if (typed_array_flag) {
    return typed_array_flag;
  }

  NapiThrowError(env, file, line_number, "Argument is not a typed-array!");
  return typed_array_flag;
}
// Returns from the enclosing void function when `value` exceeds `max`.
#define ENSURE_VALUE_IS_LESS_THAN(env, value, max) \
  if (!EnsureValueIsLessThan(env, value, max, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_LESS_THAN_RETVAL(env, value, max, retval)  \
  if (!EnsureValueIsLessThan(env, value, max, __FILE__, __LINE__)) \
    return retval;

// Verifies that `value` does not exceed `max`.
//
// Returns true when `value <= max`; otherwise throws a JavaScript error and
// returns false.
// NOTE(review): despite the name, the bound is inclusive (`value == max`
// passes) - confirm this is what callers expect.
inline bool EnsureValueIsLessThan(napi_env env, uint32_t value, uint32_t max,
                                  const char *file, const size_t line_number) {
  const bool within_bound = !(value > max);
  if (!within_bound) {
    NapiThrowError(env, file, line_number,
                   "Argument is greater than max: %u > %u", value, max);
  }
  return within_bound;
}
// Throws a JavaScript error for a TF_DataType this binding does not handle,
// tagged with the call site.
#define REPORT_UNKNOWN_TF_DATA_TYPE(env, type) \
  ReportUnknownTFDataType(env, type, __FILE__, __LINE__)

// Throws a JavaScript error whose message contains the numeric value of the
// unhandled `type`.
inline void ReportUnknownTFDataType(napi_env env, TF_DataType type,
                                    const char *file,
                                    const size_t line_number) {
  const char *const format = "Unhandled TF_DataType: %u\n";
  NapiThrowError(env, file, line_number, format, type);
}
// Throws a JavaScript error for a TF_AttrType this binding does not handle,
// tagged with the call site.
#define REPORT_UNKNOWN_TF_ATTR_TYPE(env, type) \
  ReportUnknownTFAttrType(env, type, __FILE__, __LINE__)

// Throws a JavaScript error whose message contains the numeric value of the
// unhandled `type`.
inline void ReportUnknownTFAttrType(napi_env env, TF_AttrType type,
                                    const char *file,
                                    const size_t line_number) {
  const char *const format = "Unhandled TF_AttrType: %u\n";
  NapiThrowError(env, file, line_number, format, type);
}
// Throws a JavaScript error for a napi_typedarray_type this binding does not
// handle, tagged with the call site.
#define REPORT_UNKNOWN_TYPED_ARRAY_TYPE(env, type) \
  ReportUnknownTypedArrayType(env, type, __FILE__, __LINE__)

// Throws a JavaScript error whose message contains the numeric value of the
// unhandled `type`.
inline void ReportUnknownTypedArrayType(napi_env env,
                                        napi_typedarray_type type,
                                        const char *file,
                                        const size_t line_number) {
  const char *const format = "Unhandled napi typed_array_type: %u";
  NapiThrowError(env, file, line_number, format, type);
}
// Appends every element of the JS array `array_value`, read as an int64, to
// `*result` (existing contents of `*result` are kept).
//
// Nothing is returned: if any N-API call fails, a JavaScript error is thrown
// and the function returns early, leaving `*result` with only the elements
// read so far. Callers that need to detect this should check for a pending
// exception.
inline void ExtractArrayShape(napi_env env, napi_value array_value,
                              std::vector<int64_t> *result) {
  napi_status nstatus;

  uint32_t array_length;
  nstatus = napi_get_array_length(env, array_value, &array_length);
  ENSURE_NAPI_OK(env, nstatus);

  // The final size is known up front, so grow the vector once.
  result->reserve(result->size() + array_length);

  for (uint32_t i = 0; i < array_length; i++) {
    napi_value dimension_value;
    nstatus = napi_get_element(env, array_value, i, &dimension_value);
    ENSURE_NAPI_OK(env, nstatus);

    int64_t dimension;
    nstatus = napi_get_value_int64(env, dimension_value, &dimension);
    ENSURE_NAPI_OK(env, nstatus);

    result->push_back(dimension);
  }
}
// Reports whether a JavaScript exception is currently pending on `env`.
//
// If the query itself fails, an error is thrown through EnsureNapiOK and the
// flag's current value (initialized to false) is returned.
inline bool IsExceptionPending(napi_env env) {
  bool pending = false;
  const napi_status query_status = napi_is_exception_pending(env, &pending);
  ENSURE_NAPI_OK_RETVAL(env, query_status, pending);
  return pending;
}
// Returns from the enclosing void function when `value` is a null pointer.
#define ENSURE_VALUE_IS_NOT_NULL(env, value) \
  if (!EnsureValueIsNotNull(env, value, __FILE__, __LINE__)) return;
// Same as above, but returns `retval` from the enclosing function.
#define ENSURE_VALUE_IS_NOT_NULL_RETVAL(env, value, retval) \
  if (!EnsureValueIsNotNull(env, value, __FILE__, __LINE__)) return retval;

// Verifies that the pointer `value` is not null.
//
// Returns true for a non-null pointer; otherwise throws a JavaScript error
// and returns false.
inline bool EnsureValueIsNotNull(napi_env env, void *value, const char *file,
                                 const size_t line_number) {
  if (value != nullptr) {
    return true;
  }

  NapiThrowError(env, file, line_number, "Argument is null!");
  return false;
}
// Copies the UTF-8 contents of the JS string `string_value` into `string`.
//
// Returns napi_ok on success. On failure a JavaScript error is thrown and a
// non-ok status is returned: napi_invalid_arg when `string_value` is not a
// string, napi_generic_failure when the temporary buffer cannot be allocated,
// or the status reported by the failing N-API call. `string` is only modified
// on success.
inline napi_status GetStringParam(napi_env env, napi_value string_value,
                                  std::string &string) {
  ENSURE_VALUE_IS_STRING_RETVAL(env, string_value, napi_invalid_arg);

  napi_status nstatus;

  // First call with a null buffer only queries the string length.
  size_t str_length;
  nstatus =
      napi_get_value_string_utf8(env, string_value, nullptr, 0, &str_length);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

  // One extra byte for the terminating NUL written by N-API.
  char *buffer = static_cast<char *>(malloc(sizeof(char) * (str_length + 1)));
  ENSURE_VALUE_IS_NOT_NULL_RETVAL(env, buffer, napi_generic_failure);

  nstatus = napi_get_value_string_utf8(env, string_value, buffer,
                                       str_length + 1, &str_length);
  if (nstatus == napi_ok) {
    string.assign(buffer, str_length);
  }
  // Release the buffer before the status check: the check returns early on
  // failure, which previously leaked the allocation.
  free(buffer);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

  return napi_ok;
}
// Computes the element count of `tensor` as the product of all of its
// dimension sizes; a tensor with zero dimensions yields 1.
inline size_t GetTensorNumElements(TF_Tensor *tensor) {
  size_t element_count = 1;
  int axis = 0;
  while (axis < TF_NumDims(tensor)) {
    element_count *= TF_Dim(tensor, axis);
    ++axis;
  }
  return element_count;
}
// Splits `str` on `,` and returns the non-empty tokens as NUL-terminated C
// strings, in order of appearance. Empty tokens (leading, trailing or
// consecutive commas) are skipped.
//
// Ownership: every returned pointer refers to an array allocated with
// `new char[]`; the caller is responsible for releasing each one with
// `delete[]`.
inline std::vector<const char *> splitStringByComma(const std::string &str) {
  std::vector<const char *> tokens;
  size_t prev = 0, pos = 0;
  do {
    pos = str.find(',', prev);
    if (pos == std::string::npos) pos = str.length();
    const std::string token = str.substr(prev, pos - prev);
    if (!token.empty()) {
      // Size the copy to the token itself (plus the terminator) rather than
      // to the whole input string, which over-allocated every token.
      char *cstr = new char[token.length() + 1];
      std::memcpy(cstr, token.c_str(), token.length() + 1);
      tokens.push_back(cstr);
    }
    prev = pos + 1;
  } while (pos < str.length() && prev < str.length());
  return tokens;
}
} // namespace tfnodejs
#endif // TF_NODEJS_UTILS_H_