Skip to content

Commit

Permalink
Add ComputeSettings and num_threads in the BaseOptions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 366159856
  • Loading branch information
lu-wang-g authored and tflite-support-robot committed Apr 1, 2021
1 parent e1212e0 commit 0d034ba
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 102 deletions.
13 changes: 10 additions & 3 deletions tensorflow_lite_support/cc/task/audio/audio_classifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::CreateFromOptions(
// Copy options to ensure the ExternalFile outlives the constructed object.
auto options_copy = absl::make_unique<AudioClassifierOptions>(options);

ASSIGN_OR_RETURN(auto audio_classifier,
TaskAPIFactory::CreateFromBaseOptions<AudioClassifier>(
&options_copy->base_options(), std::move(resolver)));
ASSIGN_OR_RETURN(
auto audio_classifier,
TaskAPIFactory::CreateFromExternalFileProto<AudioClassifier>(
&options_copy->base_options().model_file(), std::move(resolver)));

// TODO(b/182625132): Retrieve the required audio format from the model
// metadata. Return an error status if the audio format metadata are missed in
Expand All @@ -83,6 +84,12 @@ absl::Status AudioClassifier::SanityCheckOptions(
"Missing mandatory `base_options` field",
TfLiteSupportStatus::kInvalidArgumentError);
}
if (!options.base_options().has_model_file()) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Missing mandatory `model_file` field in `base_options`",
TfLiteSupportStatus::kInvalidArgumentError);
}
if (options.max_results() == 0) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
Expand Down
1 change: 0 additions & 1 deletion tensorflow_lite_support/cc/task/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ cc_library_with_tflite(
"//tensorflow_lite_support/cc/port:configuration_proto_inc",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
"//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite/c:common",
Expand Down
2 changes: 0 additions & 2 deletions tensorflow_lite_support/cc/task/core/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ proto_library(
srcs = ["base_options.proto"],
deps = [
":external_file_proto",
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto",
],
)

Expand All @@ -45,7 +44,6 @@ cc_library(
deps = [
":base_options_cc_proto",
":external_file_proto_inc",
"//tensorflow_lite_support/cc/port:configuration_proto_inc",
],
)

Expand Down
23 changes: 2 additions & 21 deletions tensorflow_lite_support/cc/task/core/proto/base_options.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ syntax = "proto2";

package tflite.task.core;

import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto";

import "tensorflow_lite_support/cc/task/core/proto/external_file.proto";

// Base options for task libraries.
// Next Id: 4
// Next Id: 2
message BaseOptions {
// The external model file, as a single standalone TFLite file. It could be
// packed with TFLite Model Metadata[1] and associated files if they exist. Fail to
Expand All @@ -31,21 +29,4 @@ message BaseOptions {
// [1]: https://www.tensorflow.org/lite/convert/metadata

optional core.ExternalFile model_file = 1;

// Advanced settings specifying how to accelerate the model inference using
// dedicated delegates. Supported delegate type includes:
// NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal),
// and EDGETPU_CORAL.
//
// IMPORTANT: in order to use a delegate, the appropriate delegate plugin
// needs to be linked at build time.
//
// For example, `gpu_plugin` for GPU from:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD
// To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from:
// https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11
//
// See settings definition at:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto
optional tflite.proto.ComputeSettings compute_settings = 2;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_

#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"

#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_
43 changes: 1 addition & 42 deletions tensorflow_lite_support/cc/task/core/task_api_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"

Expand All @@ -42,8 +41,6 @@ class TaskAPIFactory {
public:
TaskAPIFactory() = delete;

// Deprecated: use CreateFromBaseOptions and configure model input from
// tensorflow_lite_support/cc/task/core/proto/base_options.proto
template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer(
const char* buffer_data, size_t buffer_size,
Expand All @@ -59,8 +56,6 @@ class TaskAPIFactory {
compute_settings);
}

// Deprecated: use CreateFromBaseOptions and configure model input from
// tensorflow_lite_support/cc/task/core/proto/base_options.proto
template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFile(
const string& file_name,
Expand All @@ -75,8 +70,6 @@ class TaskAPIFactory {
compute_settings);
}

// Deprecated: use CreateFromBaseOptions and configure model input from
// tensorflow_lite_support/cc/task/core/proto/base_options.proto
template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFileDescriptor(
int file_descriptor,
Expand All @@ -92,8 +85,6 @@ class TaskAPIFactory {
compute_settings);
}

// Deprecated: use CreateFromBaseOptions and configure model input from
// tensorflow_lite_support/cc/task/core/proto/base_options.proto
template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>>
CreateFromExternalFileProto(
Expand All @@ -110,45 +101,13 @@ class TaskAPIFactory {
compute_settings);
}

template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBaseOptions(
const BaseOptions* base_options,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()) {
if (!base_options->has_model_file()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Missing mandatory `model_file` field in `base_options`",
tflite::support::TfLiteSupportStatus::kInvalidArgumentError);
}

auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver));
RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto(
&base_options->model_file(), base_options->compute_settings()));
return CreateFromTfLiteEngine<T>(std::move(engine),
base_options->compute_settings());
}

private:
template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine(
std::unique_ptr<TfLiteEngine> engine, int num_threads,
const tflite::proto::ComputeSettings& compute_settings =
tflite::proto::ComputeSettings()) {
tflite::proto::ComputeSettings settings_copy =
tflite::proto::ComputeSettings(compute_settings);
settings_copy.mutable_tflite_settings()
->mutable_cpu_settings()
->set_num_threads(num_threads);
return CreateFromTfLiteEngine<T>(std::move(engine), settings_copy);
}

template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine(
std::unique_ptr<TfLiteEngine> engine,
const tflite::proto::ComputeSettings& compute_settings =
tflite::proto::ComputeSettings()) {
RETURN_IF_ERROR(engine->InitInterpreter(compute_settings));
RETURN_IF_ERROR(engine->InitInterpreter(compute_settings, num_threads));
return absl::make_unique<T>(std::move(engine));
}
};
Expand Down
27 changes: 5 additions & 22 deletions tensorflow_lite_support/cc/task/core/tflite_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ static std::ios_base::Init s_iostream_initializer;
#endif

using ::absl::StatusCode;
using ::tflite::proto::ComputeSettings;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::InterpreterCreationResources;
using ::tflite::support::TfLiteSupportStatus;

using ::tflite::support::InterpreterCreationResources;

bool TfLiteEngine::Verifier::Verify(const char* data, int length,
tflite::ErrorReporter* reporter) {
return tflite::Verify(data, length, reporter);
Expand Down Expand Up @@ -185,42 +185,25 @@ absl::Status TfLiteEngine::BuildModelFromExternalFileProto(

absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
tflite::proto::ComputeSettings compute_settings;
compute_settings.mutable_tflite_settings()
->mutable_cpu_settings()
->set_num_threads(num_threads);
return InitInterpreter(compute_settings);
return InitInterpreter(compute_settings, num_threads);
}

// TODO(b/183798104): deprecate num_threads in VK task protos.
// Deprecated. Use the following method, and configure `num_threads` through
// `compute_settings`, i.e. in `CPUSettings`:
// absl::Status TfLiteEngine::InitInterpreter(
// const tflite::proto::ComputeSettings& compute_settings)
absl::Status TfLiteEngine::InitInterpreter(
const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
ComputeSettings settings_copy = ComputeSettings(compute_settings);
settings_copy.mutable_tflite_settings()
->mutable_cpu_settings()
->set_num_threads(num_threads);
return InitInterpreter(settings_copy);
}

absl::Status TfLiteEngine::InitInterpreter(
const tflite::proto::ComputeSettings& compute_settings) {
if (model_ == nullptr) {
return CreateStatusWithPayload(
StatusCode::kInternal,
"TF Lite FlatBufferModel is null. Please make sure to call one of the "
"BuildModelFrom methods before calling InitInterpreter.");
}
auto initializer =
[this](
[this, num_threads](
const InterpreterCreationResources& resources,
std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
-> absl::Status {
tflite_shims::InterpreterBuilder interpreter_builder(*model_, *resolver_);
resources.ApplyTo(&interpreter_builder);
if (interpreter_builder(interpreter_out) != kTfLiteOk) {
if (interpreter_builder(interpreter_out, num_threads) != kTfLiteOk) {
return CreateStatusWithPayload(
StatusCode::kUnknown,
absl::StrCat("Could not build the TF Lite interpreter: ",
Expand Down
11 changes: 2 additions & 9 deletions tensorflow_lite_support/cc/task/core/tflite_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,10 @@ class TfLiteEngine {
// the value.
absl::Status InitInterpreter(int num_threads = 1);

// Initializes interpreter with acceleration configurations.
absl::Status InitInterpreter(
const tflite::proto::ComputeSettings& compute_settings);

// Deprecated. Use the following method, and configure `num_threads` through
// `compute_settings`, i.e. in `CPUSettings`:
// absl::Status TfLiteEngine::InitInterpreter(
// const tflite::proto::ComputeSettings& compute_settings)
// Same as above, but allows specifying `compute_settings` for acceleration.
absl::Status InitInterpreter(
const tflite::proto::ComputeSettings& compute_settings,
int num_threads);
int num_threads = 1);

// Cancels the on-going `Invoke()` call if any and if possible. This method
// can be called from a different thread than the one where `Invoke()` is
Expand Down

0 comments on commit 0d034ba

Please sign in to comment.