Skip to content

Commit

Permalink
Add JIT NPU compilation to LiteRtCompileModel
Browse files Browse the repository at this point in the history
By calling compiler plugins

PiperOrigin-RevId: 707770943
  • Loading branch information
ai-edge-bot authored and copybara-github committed Dec 19, 2024
1 parent 2750f45 commit f1abee8
Show file tree
Hide file tree
Showing 23 changed files with 365 additions and 191 deletions.
1 change: 1 addition & 0 deletions tflite/experimental/litert/c/litert_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ typedef enum {
} LiteRtStatus;

typedef enum : int {
kLiteRtHwAccelatorNone = 0,
kLiteRtHwAccelatorCpu = 1 << 0,
kLiteRtHwAccelatorGpu = 1 << 1,
kLiteRtHwAccelatorNpu = 1 << 2,
Expand Down
2 changes: 1 addition & 1 deletion tflite/experimental/litert/c/litert_compiled_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST(CompiledModelTest, Basic) {

LiteRtCompiledModel compiled_model;
ASSERT_EQ(
LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorCpu, &compiled_model),
LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorNone, &compiled_model),
kLiteRtStatusOk);

LiteRtSubgraph subgraph;
Expand Down
1 change: 1 addition & 0 deletions tflite/experimental/litert/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cc_library(
":litert_expected",
"//tflite/experimental/litert/c:litert_any",
"//tflite/experimental/litert/c:litert_common",
"@com_google_absl//absl/strings:string_view",
],
)

Expand Down
9 changes: 8 additions & 1 deletion tflite/experimental/litert/cc/litert_any.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <any>
#include <cstdint>

#include "absl/strings/string_view.h"
#include "tflite/experimental/litert/c/litert_any.h"
#include "tflite/experimental/litert/c/litert_common.h"
#include "tflite/experimental/litert/cc/litert_expected.h"
Expand Down Expand Up @@ -94,13 +95,19 @@ inline Expected<LiteRtAny> ToLiteRtAny(const std::any& any) {
result.str_value = std::any_cast<decltype(LiteRtAny::str_value)>(any);
return result;

} else if (any.type() == typeid(absl::string_view)) {
result.type = kLiteRtAnyTypeString;
result.str_value = std::any_cast<absl::string_view>(any).data();
return result;

} else if (any.type() == typeid(LiteRtAny::ptr_value)) {
result.type = kLiteRtAnyTypeVoidPtr;
result.ptr_value = std::any_cast<decltype(LiteRtAny::ptr_value)>(any);
return result;

} else {
return Error(kLiteRtStatusErrorInvalidArgument);
return Error(kLiteRtStatusErrorInvalidArgument,
"Invalid argument for ToLiteRtAny");
}
}

Expand Down
43 changes: 36 additions & 7 deletions tflite/experimental/litert/cc/litert_compiled_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,35 @@ Expected<std::vector<TensorBuffer>> CompiledModel::CreateInputBuffers(
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
input_buffer_requirements.Error().Message());
}

auto supported_types = input_buffer_requirements->SupportedTypes();
if (!supported_types) {
return supported_types.Error();
}
if (supported_types->empty()) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
"Input doesn't support any tensor buffer types");
}
// For simplicity we just pick the first supported tensor buffer type.
LiteRtTensorBufferType tensor_buffer_type = (*supported_types)[0];

auto tensor_type = input_tensors[i].RankedTensorType();
if (!tensor_type) {
return tensor_type.Error();
}
LiteRtTensorBufferType tensor_buffer_type =
(*(*input_buffer_requirements).SupportedTypes())[0];

auto input_buffer = TensorBuffer::CreateManaged(
tensor_buffer_type, *tensor_type,
(*input_buffer_requirements).BufferSize().Value());
if (!input_buffer) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
input_buffer.Error().Message());
}

input_buffers.push_back(std::move(*input_buffer));
}
return std::move(input_buffers);

return input_buffers;
}

Expected<std::vector<TensorBuffer>> CompiledModel::CreateOutputBuffers(
Expand All @@ -79,22 +92,37 @@ Expected<std::vector<TensorBuffer>> CompiledModel::CreateOutputBuffers(
if (!subgraph) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph");
}
std::vector<TensorBuffer> output_buffers;

auto output_tensors = subgraph->Outputs();

std::vector<TensorBuffer> output_buffers;
output_buffers.reserve(output_tensors.size());

for (int i = 0; i < output_tensors.size(); ++i) {
auto output_buffer_requirements =
GetOutputBufferRequirements(signature_index, i);
if (!output_buffer_requirements.HasValue()) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
output_buffer_requirements.Error().Message());
}

auto supported_types = output_buffer_requirements->SupportedTypes();
if (!supported_types) {
return supported_types.Error();
}
if (supported_types->empty()) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
"Output doesn't support any tensor buffer types");
}

// For simplicity we just pick the first supported tensor buffer type.
LiteRtTensorBufferType tensor_buffer_type = (*supported_types)[0];

auto tensor_type = output_tensors[i].RankedTensorType();
if (!tensor_type) {
return tensor_type.Error();
}
LiteRtTensorBufferType tensor_buffer_type =
(*(*output_buffer_requirements).SupportedTypes())[0];

auto output_buffer = TensorBuffer::CreateManaged(
tensor_buffer_type, *tensor_type,
(*output_buffer_requirements).BufferSize().Value());
Expand All @@ -104,7 +132,8 @@ Expected<std::vector<TensorBuffer>> CompiledModel::CreateOutputBuffers(
}
output_buffers.push_back(std::move(*output_buffer));
}
return std::move(output_buffers);

return output_buffers;
}

Expected<void> CompiledModel::Run(
Expand Down
2 changes: 1 addition & 1 deletion tflite/experimental/litert/cc/litert_compiled_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class CompiledModel
// returned object.
static Expected<CompiledModel> Create(
litert::Model& model,
LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorCpu) {
LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorNone) {
LiteRtCompiledModel compiled_model;
if (auto status = LiteRtCreateCompiledModel(
model.Get(), compilation_options, &compiled_model);
Expand Down
6 changes: 6 additions & 0 deletions tflite/experimental/litert/cc/litert_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ class Model : public internal::Handle<LiteRtModel, LiteRtDestroyModel> {
return litert::Subgraph(signature->Subgraph());
}

size_t GetNumSignatures() const {
LiteRtParamIndex num_signatures;
internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures);
return num_signatures;
}

// Returns the list of signatures defined in the model.
Expected<std::vector<class Signature>> GetSignatures() const {
LiteRtParamIndex num_signatures;
Expand Down
5 changes: 5 additions & 0 deletions tflite/experimental/litert/compiler/plugin/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cc_library(
hdrs = ["compiler_plugin.h"],
deps = [
":algo",
"//tflite/experimental/litert/c:litert_any",
"//tflite/experimental/litert/c:litert_common",
"//tflite/experimental/litert/c:litert_logging",
"//tflite/experimental/litert/cc:litert_buffer_ref",
Expand All @@ -32,9 +33,11 @@ cc_library(
"//tflite/experimental/litert/cc:litert_model",
"//tflite/experimental/litert/core:byte_code_util",
"//tflite/experimental/litert/core:dynamic_loading",
"//tflite/experimental/litert/core:environment",
"//tflite/experimental/litert/core:filesystem",
"//tflite/experimental/litert/core/model",
"//tflite/experimental/litert/core/model:ir_allocator",
"//tflite/experimental/litert/core/model:model_serialize",
"//tflite/experimental/litert/vendors/c:litert_compiler_plugin",
"//tflite/experimental/litert/vendors/c:litert_compiler_plugin_api",
"@com_google_absl//absl/log:absl_check",
Expand Down Expand Up @@ -62,7 +65,9 @@ cc_library(
# "@com_google_googletest//:gtest_main",
# "//testing/base/public:unique-test-directory",
# "@com_google_absl//absl/strings:string_view",
# "//tflite/experimental/litert/c:litert_common",
# "//tflite/experimental/litert/c:litert_op_code",
# "//tflite/experimental/litert/cc:litert_environment",
# "//tflite/experimental/litert/core:byte_code_util",
# "//tflite/experimental/litert/core:filesystem",
# "//tflite/experimental/litert/test:common",
Expand Down
Loading

0 comments on commit f1abee8

Please sign in to comment.