forked from google-coral/edgetpu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path basic_engine.h
51 lines (43 loc) · 1.89 KB
/
basic_engine.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#ifndef EDGETPU_CPP_BASIC_BASIC_ENGINE_H_
#define EDGETPU_CPP_BASIC_BASIC_ENGINE_H_
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "src/cpp/basic/basic_engine_native.h"
namespace coral {

// BasicEngine wraps a given TFLite model, creates an interpreter for it and
// initializes the Edge TPU.
//
// Its only state is a BasicEngineNative owned through std::unique_ptr, so a
// BasicEngine is not copyable.
// NOTE(review): the accessors and RunInference() presumably forward to that
// BasicEngineNative; confirm against the implementation.
class BasicEngine {
 public:
  // Loads the TFLite model stored at `model_path` and initializes an
  // interpreter for it.
  // - 'model_path' : the file path of the model.
  explicit BasicEngine(const std::string& model_path);

  // Same as BasicEngine(model_path), but uses the Edge TPU specified by
  // `device_path`.
  explicit BasicEngine(const std::string& model_path,
                       const std::string& device_path);

  // Initializes BasicEngine from an existing FlatBufferModel object. The
  // engine takes ownership of `model`.
  explicit BasicEngine(std::unique_ptr<tflite::FlatBufferModel> model);

  // Initializes BasicEngine from an existing FlatBufferModel object and a
  // customized op resolver. The engine takes ownership of both `model` and
  // `resolver`.
  explicit BasicEngine(
      std::unique_ptr<tflite::FlatBufferModel> model,
      std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver);

  // Runs one inference.
  //
  // Assumes the model has exactly one input tensor, of type uint8_t; `input`
  // is that tensor flattened into a one-dimensional array.
  //
  // NOTE(review): the result presumably holds one std::vector<float> per
  // output tensor, ordered like get_all_output_tensors_sizes(); confirm
  // against the implementation.
  std::vector<std::vector<float>> RunInference(
      const std::vector<uint8_t>& input);

  // ---- Functions to get/check attributes. ----

  // Returns the device path associated with the Edge TPU.
  std::string device_path() const;

  // Returns the path of the model.
  std::string model_path() const;

  // Returns the shape of the input tensor.
  std::vector<int> get_input_tensor_shape() const;

  // Returns the size of every output tensor. All output tensors are assumed
  // to be one-dimensional, so a single length describes each of them.
  std::vector<int> get_all_output_tensors_sizes() const;

  // Returns the time consumed by the last inference, in milliseconds.
  float get_inference_time() const;

 private:
  std::unique_ptr<BasicEngineNative> engine_;
};

}  // namespace coral
#endif // EDGETPU_CPP_BASIC_BASIC_ENGINE_H_