diff --git a/tensorflow_lite_support/cc/task/audio/audio_classifier.cc b/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
index 8e23a02ba..e0ba5febd 100644
--- a/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
+++ b/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
@@ -63,9 +63,10 @@ StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::CreateFromOptions(
   // Copy options to ensure the ExternalFile outlives the constructed object.
   auto options_copy = absl::make_unique<AudioClassifierOptions>(options);
 
-  ASSIGN_OR_RETURN(auto audio_classifier,
-                   TaskAPIFactory::CreateFromBaseOptions<AudioClassifier>(
-                       &options_copy->base_options(), std::move(resolver)));
+  ASSIGN_OR_RETURN(
+      auto audio_classifier,
+      TaskAPIFactory::CreateFromExternalFileProto<AudioClassifier>(
+          &options_copy->base_options().model_file(), std::move(resolver)));
 
   // TODO(b/182625132): Retrieve the required audio format from the model
   // metadata. Return an error status if the audio format metadata are missed in
@@ -83,6 +84,12 @@ absl::Status AudioClassifier::SanityCheckOptions(
         "Missing mandatory `base_options` field",
         TfLiteSupportStatus::kInvalidArgumentError);
   }
+  if (!options.base_options().has_model_file()) {
+    return CreateStatusWithPayload(
+        StatusCode::kInvalidArgument,
+        "Missing mandatory `model_file` field in `base_options`",
+        TfLiteSupportStatus::kInvalidArgumentError);
+  }
   if (options.max_results() == 0) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
diff --git a/tensorflow_lite_support/cc/task/core/BUILD b/tensorflow_lite_support/cc/task/core/BUILD
index 65312f76d..0ec7a271d 100644
--- a/tensorflow_lite_support/cc/task/core/BUILD
+++ b/tensorflow_lite_support/cc/task/core/BUILD
@@ -67,7 +67,6 @@ cc_library_with_tflite(
         "//tensorflow_lite_support/cc/port:configuration_proto_inc",
         "//tensorflow_lite_support/cc/port:status_macros",
         "//tensorflow_lite_support/cc/port:statusor",
-        "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
         "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
         "@com_google_absl//absl/status",
         "@org_tensorflow//tensorflow/lite/c:common",
diff --git a/tensorflow_lite_support/cc/task/core/proto/BUILD b/tensorflow_lite_support/cc/task/core/proto/BUILD
index dba1f76bf..dd2e4be22 100644
--- a/tensorflow_lite_support/cc/task/core/proto/BUILD
+++ b/tensorflow_lite_support/cc/task/core/proto/BUILD
@@ -28,7 +28,6 @@ proto_library(
     srcs = ["base_options.proto"],
     deps = [
         ":external_file_proto",
-        "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto",
     ],
 )
 
@@ -45,7 +44,6 @@ cc_library(
     deps = [
         ":base_options_cc_proto",
         ":external_file_proto_inc",
-        "//tensorflow_lite_support/cc/port:configuration_proto_inc",
     ],
 )
 
diff --git a/tensorflow_lite_support/cc/task/core/proto/base_options.proto b/tensorflow_lite_support/cc/task/core/proto/base_options.proto
index 9d529d64b..de3e5b92a 100644
--- a/tensorflow_lite_support/cc/task/core/proto/base_options.proto
+++ b/tensorflow_lite_support/cc/task/core/proto/base_options.proto
@@ -17,12 +17,10 @@ syntax = "proto2";
 
 package tflite.task.core;
 
-import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto";
-
 import "tensorflow_lite_support/cc/task/core/proto/external_file.proto";
 
 // Base options for task libraries.
-// Next Id: 4
+// Next Id: 2
 message BaseOptions {
   // The external model file, as a single standalone TFLite file. It could be
   // packed with TFLite Model Metadata[1] and associated files if exist. Fail to
@@ -31,21 +29,4 @@ message BaseOptions {
 
   // [1]: https://www.tensorflow.org/lite/convert/metadata
   optional core.ExternalFile model_file = 1;
-
-  // Advanced settings specifying how to accelerate the model inference using
-  // dedicated delegates. Supported delegate type includes:
-  // NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal),
-  // and EDGETPU_CORAL.
-  //
-  // IMPORTANT: in order to use a delegate, the appropriate delegate plugin
-  // needs to be linked at build time.
-  //
-  // For example, `gpu_plugin` for GPU from:
-  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD
-  // To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from:
-  // https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11
-  //
-  // See settings definition at:
-  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto
-  optional tflite.proto.ComputeSettings compute_settings = 2;
-}
+}
\ No newline at end of file
diff --git a/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h b/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h
index 4c53a2f5c..254a3e070 100644
--- a/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h
+++ b/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h
@@ -16,8 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_
 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_
 
-#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
 #include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h"
-#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
 
 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/core/task_api_factory.h b/tensorflow_lite_support/cc/task/core/task_api_factory.h
index 8aef6bb03..f7bbcdc59 100644
--- a/tensorflow_lite_support/cc/task/core/task_api_factory.h
+++ b/tensorflow_lite_support/cc/task/core/task_api_factory.h
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/base_task_api.h" -#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" @@ -42,8 +41,6 @@ class TaskAPIFactory { public: TaskAPIFactory() = delete; - // Deprecated: use CreateFromBaseOptions and configure model input from - // tensorflow_lite_support/cc/task/core/proto/base_options.proto template = nullptr> static tflite::support::StatusOr> CreateFromBuffer( const char* buffer_data, size_t buffer_size, @@ -59,8 +56,6 @@ class TaskAPIFactory { compute_settings); } - // Deprecated: use CreateFromBaseOptions and configure model input from - // tensorflow_lite_support/cc/task/core/proto/base_options.proto template = nullptr> static tflite::support::StatusOr> CreateFromFile( const string& file_name, @@ -75,8 +70,6 @@ class TaskAPIFactory { compute_settings); } - // Deprecated: use CreateFromBaseOptions and configure model input from - // tensorflow_lite_support/cc/task/core/proto/base_options.proto template = nullptr> static tflite::support::StatusOr> CreateFromFileDescriptor( int file_descriptor, @@ -92,8 +85,6 @@ class TaskAPIFactory { compute_settings); } - // Deprecated: use CreateFromBaseOptions and configure model input from - // tensorflow_lite_support/cc/task/core/proto/base_options.proto template = nullptr> static tflite::support::StatusOr> CreateFromExternalFileProto( @@ -110,45 +101,13 @@ class TaskAPIFactory { compute_settings); } - template = nullptr> - static tflite::support::StatusOr> CreateFromBaseOptions( - const BaseOptions* base_options, - std::unique_ptr resolver = - absl::make_unique()) { - if (!base_options->has_model_file()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Missing mandatory `model_file` field in `base_options`", - tflite::support::TfLiteSupportStatus::kInvalidArgumentError); - } - - auto engine = absl::make_unique(std::move(resolver)); - RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto( - &base_options->model_file(), base_options->compute_settings())); - return CreateFromTfLiteEngine(std::move(engine), - base_options->compute_settings()); - } - private: template = nullptr> static tflite::support::StatusOr> CreateFromTfLiteEngine( std::unique_ptr engine, int num_threads, const tflite::proto::ComputeSettings& compute_settings = tflite::proto::ComputeSettings()) { - tflite::proto::ComputeSettings settings_copy = - tflite::proto::ComputeSettings(compute_settings); - settings_copy.mutable_tflite_settings() - ->mutable_cpu_settings() - ->set_num_threads(num_threads); - return CreateFromTfLiteEngine(std::move(engine), settings_copy); - } - - template = nullptr> - static tflite::support::StatusOr> CreateFromTfLiteEngine( - std::unique_ptr engine, - const tflite::proto::ComputeSettings& compute_settings = - tflite::proto::ComputeSettings()) { - RETURN_IF_ERROR(engine->InitInterpreter(compute_settings)); + RETURN_IF_ERROR(engine->InitInterpreter(compute_settings, num_threads)); return absl::make_unique(std::move(engine)); } }; diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/tensorflow_lite_support/cc/task/core/tflite_engine.cc index 10eb3a589..0c41d70e3 100644 --- a/tensorflow_lite_support/cc/task/core/tflite_engine.cc +++ b/tensorflow_lite_support/cc/task/core/tflite_engine.cc @@ -47,11 
 #endif
 
 using ::absl::StatusCode;
-using ::tflite::proto::ComputeSettings;
 using ::tflite::support::CreateStatusWithPayload;
-using ::tflite::support::InterpreterCreationResources;
 using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::support::InterpreterCreationResources;
+
 
 bool TfLiteEngine::Verifier::Verify(const char* data, int length,
                                     tflite::ErrorReporter* reporter) {
   return tflite::Verify(data, length, reporter);
@@ -185,28 +185,11 @@ absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
 
 absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
   tflite::proto::ComputeSettings compute_settings;
-  compute_settings.mutable_tflite_settings()
-      ->mutable_cpu_settings()
-      ->set_num_threads(num_threads);
-  return InitInterpreter(compute_settings);
+  return InitInterpreter(compute_settings, num_threads);
 }
 
-// TODO(b/183798104): deprecate num_threads in VK task protos.
-// Deprecated. Use the following method, and configure `num_threads` through
-// `compute_settings`, i.e. in `CPUSettings`:
-// absl::Status TfLiteEngine::InitInterpreter(
-//     const tflite::proto::ComputeSettings& compute_settings)
 absl::Status TfLiteEngine::InitInterpreter(
     const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
-  ComputeSettings settings_copy = ComputeSettings(compute_settings);
-  settings_copy.mutable_tflite_settings()
-      ->mutable_cpu_settings()
-      ->set_num_threads(num_threads);
-  return InitInterpreter(settings_copy);
-}
-
-absl::Status TfLiteEngine::InitInterpreter(
-    const tflite::proto::ComputeSettings& compute_settings) {
   if (model_ == nullptr) {
     return CreateStatusWithPayload(
         StatusCode::kInternal,
@@ -214,13 +197,13 @@ absl::Status TfLiteEngine::InitInterpreter(
         "BuildModelFrom methods before calling InitInterpreter.");
   }
   auto initializer =
-      [this](
+      [this, num_threads](
          const InterpreterCreationResources& resources,
          std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
       -> absl::Status {
    tflite_shims::InterpreterBuilder interpreter_builder(*model_, *resolver_);
    resources.ApplyTo(&interpreter_builder);
-    if (interpreter_builder(interpreter_out) != kTfLiteOk) {
+    if (interpreter_builder(interpreter_out, num_threads) != kTfLiteOk) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat("Could not build the TF Lite interpreter: ",
diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.h b/tensorflow_lite_support/cc/task/core/tflite_engine.h
index 7c2835873..43f933fc4 100644
--- a/tensorflow_lite_support/cc/task/core/tflite_engine.h
+++ b/tensorflow_lite_support/cc/task/core/tflite_engine.h
@@ -126,17 +126,10 @@ class TfLiteEngine {
   // the value.
   absl::Status InitInterpreter(int num_threads = 1);
 
-  // Initializes interpreter with acceleration configurations.
-  absl::Status InitInterpreter(
-      const tflite::proto::ComputeSettings& compute_settings);
-
-  // Deprecated. Use the following method, and configure `num_threads` through
-  // `compute_settings`, i.e. in `CPUSettings`:
-  // absl::Status TfLiteEngine::InitInterpreter(
-  //     const tflite::proto::ComputeSettings& compute_settings)
+  // Same as above, but allows specifying `compute_settings` for acceleration.
   absl::Status InitInterpreter(
       const tflite::proto::ComputeSettings& compute_settings,
-      int num_threads);
+      int num_threads = 1);
 
   // Cancels the on-going `Invoke()` call if any and if possible. This method
   // can be called from a different thread than the one where `Invoke()` is
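For reference, a minimal caller-side sketch of the surface this patch settles on: `base_options.model_file` is the only field the task consumes, and the thread count travels as a plain `num_threads` argument rather than inside `ComputeSettings`. The include set, the model path, and the assumption that `CreateFromOptions` supplies a default `OpResolver` are illustrative guesses, not part of the patch.

// Illustrative sketch only; not part of the patch above.
#include <memory>

#include "absl/status/status.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"  // ASSIGN_OR_RETURN
#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h"

absl::Status CreateClassifierExample() {
  tflite::task::audio::AudioClassifierOptions options;
  // After this change, `model_file` is the only mandatory `base_options`
  // field; there is no `compute_settings` to populate.
  options.mutable_base_options()->mutable_model_file()->set_file_name(
      "/tmp/model.tflite");  // hypothetical path
  ASSIGN_OR_RETURN(
      std::unique_ptr<tflite::task::audio::AudioClassifier> classifier,
      tflite::task::audio::AudioClassifier::CreateFromOptions(options));
  return absl::OkStatus();
}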