Allow preserving intermediate tensors for debugging.
e.g.

interpreter = tf.lite.Interpreter(
    model_path="test.tflite",
    experimental_preserve_all_tensors=True)
# Run evaluation
interpreter.invoke()
# Look at all tensors including intermediates.
print({
    t['name']: interpreter.get_tensor(t['index'])
    for t in interpreter.get_tensor_details()
})
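
The same switch is exposed to C++ callers through `InterpreterBuilder::PreserveAllTensorsExperimental()` (see the `interpreter_builder.h` diff below). A minimal sketch of that path, using the built-in op resolver; `BuildDebugInterpreter` is a hypothetical helper name, and this is an illustration rather than code from this commit:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

// Build an interpreter whose intermediate tensors survive Invoke() so they
// can be inspected afterwards. Preservation must be requested on the
// builder, before any tensors are allocated.
std::unique_ptr<tflite::Interpreter> BuildDebugInterpreter(const char* path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(path);
  if (!model) return nullptr;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder builder(*model, resolver);
  builder.PreserveAllTensorsExperimental();  // keep intermediates alive
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (builder(&interpreter) != kTfLiteOk) return nullptr;
  if (interpreter->AllocateTensors() != kTfLiteOk) return nullptr;
  return interpreter;  // after Invoke(), interpreter->tensor(i) is readable
}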
PiperOrigin-RevId: 364618628
Change-Id: Ia2fda83fd67f61cf360d7e1c561b536a86a1c397
aselle authored and tensorflower-gardener committed Mar 23, 2021
1 parent e61f5d3 commit 6480615
Showing 15 changed files with 182 additions and 64 deletions.
5 changes: 3 additions & 2 deletions RELEASE.md
@@ -135,17 +135,18 @@
ML authoring is generally discouraged.
* Add support for static hash tables through
`TFLiteConverter.from_saved_model`.
* The Python TF Lite Interpreter bindings now have an option
`experimental_preserve_all_tensors` to aid in debugging conversion.
* Quantized x86 execution defaults to Ruy GEMM library for platforms with
AVX support.
* Deprecate `tf.compat.v1.lite.experimental.get_potentially_supported_ops`.
Use `tf.lite.TFLiteConverter` directly to check whether a model is
convertible.
* Add support to select one of three different built-in op resolvers to be
used in Python Interpreter API.
* Enabled post training with calibrations for models that require user
provided TensorFlow Lite custom op libraries via
`converter.target_spec._experimental_custom_op_registerers`.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
12 changes: 11 additions & 1 deletion tensorflow/lite/core/subgraph.cc
@@ -924,7 +924,7 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() {
if (!memory_planner_) {
memory_planner_.reset(new ArenaPlanner(
&context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
/*preserve_inputs=*/true, /*preserve_intermediates*/ false,
/*preserve_inputs=*/true, preserve_all_tensors_,
kDefaultTensorAlignment));
memory_planner_->PlanAllocations();
}
@@ -1628,4 +1628,14 @@ void Subgraph::SetName(const char* name) {

const std::string& Subgraph::GetName() const { return name_; }

TfLiteStatus Subgraph::PreserveAllTensorsExperimental() {
if (memory_planner_) {
ReportError(
"PreserveAllTensorsExperimental called after memory was planned. ");
return kTfLiteError;
}
preserve_all_tensors_ = true;
return kTfLiteOk;
}

} // namespace tflite
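
The guard above matters because `memory_planner_` is created lazily, inside `PrepareOpsAndTensors()` during the first allocation. A hedged illustration of the failure mode, assuming an already-built interpreter; this snippet is not part of the commit:

// Illustration only: once AllocateTensors() has run, the arena plan is
// fixed and the preservation request is rejected.
interpreter->AllocateTensors();  // instantiates the memory planner
TfLiteStatus status = interpreter->PreserveAllTensorsExperimental();
// status == kTfLiteError; enable preservation via InterpreterBuilder
// instead, before the interpreter is built.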
8 changes: 8 additions & 0 deletions tensorflow/lite/core/subgraph.h
@@ -612,6 +612,9 @@ class Subgraph {
// Returns true if cancellation function returns true.
bool IsCancelled();

// Enables preserving intermediates for debugging.
TfLiteStatus PreserveAllTensorsExperimental();

// The state of the Interpreter.
enum State {
// The interpreter isn't ready to be invoked.
@@ -744,7 +747,12 @@ class Subgraph {
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
resource::ResourceMap* resources_ = nullptr;

// Name of the subgraph (analogous to function name).
std::string name_;

// Whether memory planner should be instantiated to retain intermediates for
// debugging.
bool preserve_all_tensors_ = false;
};

} // namespace tflite
9 changes: 9 additions & 0 deletions tensorflow/lite/interpreter.cc
@@ -469,4 +469,13 @@ Profiler* Interpreter::GetProfiler() {
return primary_subgraph().GetProfiler();
}

TfLiteStatus Interpreter::PreserveAllTensorsExperimental() {
for (int subgraph_index = 0; subgraph_index < subgraphs_.size();
++subgraph_index) {
TF_LITE_ENSURE_STATUS(
subgraphs_[subgraph_index]->PreserveAllTensorsExperimental());
}
return kTfLiteOk;
}

} // namespace tflite
4 changes: 4 additions & 0 deletions tensorflow/lite/interpreter.h
@@ -721,6 +721,10 @@ class Interpreter {
signature_defs_ = std::move(signature_defs);
}

// Enables preserving intermediates for debugging. Should only be set by
// InterpreterBuilder before allocating any tensors.
TfLiteStatus PreserveAllTensorsExperimental();

// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.
10 changes: 10 additions & 0 deletions tensorflow/lite/interpreter_builder.cc
@@ -726,6 +726,10 @@ TfLiteStatus InterpreterBuilder::operator()(
(*interpreter)->AddSubgraphs(subgraphs->size() - 1);
}

if (preserve_all_tensors_) {
(*interpreter)->PreserveAllTensorsExperimental();
}

(*interpreter)->SetProfiler(tflite::profiling::MaybeCreatePlatformProfiler());

for (int subgraph_index = 0; subgraph_index < subgraphs->size();
@@ -795,4 +799,10 @@ void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
}
}

// Enables preserving intermediates for debugging.
InterpreterBuilder& InterpreterBuilder::PreserveAllTensorsExperimental() {
preserve_all_tensors_ = true;
return *this;
}

} // namespace tflite
5 changes: 5 additions & 0 deletions tensorflow/lite/interpreter_builder.h
@@ -73,6 +73,10 @@ class InterpreterBuilder {
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
int num_threads);

/// Enables preserving intermediates for debugging. Otherwise, by default
/// intermediates are undefined due to memory planning and reuse.
InterpreterBuilder& PreserveAllTensorsExperimental();

/// Any delegates added with AddDelegate will be applied to the Interpreter
/// generated by operator(), in the order that they were added. (The delegate
/// parameter passed to AddDelegate should be non-null, otherwise an error
@@ -114,6 +118,7 @@ class InterpreterBuilder {

bool has_flex_op_ = false;
int num_fp32_tensors_ = 0;
bool preserve_all_tensors_ = false;
};

} // namespace tflite
16 changes: 11 additions & 5 deletions tensorflow/lite/python/interpreter.py
@@ -19,11 +19,11 @@
from __future__ import print_function

import ctypes
import enum
import os
import platform
import sys
import os

import enum
import numpy as np

# pylint: disable=g-import-not-at-top
@@ -304,7 +304,8 @@ def __init__(self,
model_content=None,
experimental_delegates=None,
num_threads=None,
experimental_op_resolver_type=OpResolverType.AUTO):
experimental_op_resolver_type=OpResolverType.AUTO,
experimental_preserve_all_tensors=False):
"""Constructor.
Args:
@@ -321,6 +322,9 @@ def __init__(self,
must be an instance of OpResolverType. By default, we use the built-in
op resolver which corresponds to tflite::ops::builtin::BuiltinOpResolver
in C++.
experimental_preserve_all_tensors: If true, then intermediate tensors
used during computation are preserved for inspection. Otherwise, reading
intermediate tensors provides undefined values.
Raises:
ValueError: If the interpreter was unable to create.
@@ -343,7 +347,8 @@ def __init__(self,
self._interpreter = (
_interpreter_wrapper.CreateWrapperFromFile(
model_path, op_resolver_id, custom_op_registerers_by_name,
custom_op_registerers_by_func))
custom_op_registerers_by_func,
experimental_preserve_all_tensors))
if not self._interpreter:
raise ValueError('Failed to open {}'.format(model_path))
elif model_content and not model_path:
@@ -360,7 +365,8 @@ def __init__(self,
self._interpreter = (
_interpreter_wrapper.CreateWrapperFromBuffer(
model_content, op_resolver_id, custom_op_registerers_by_name,
custom_op_registerers_by_func))
custom_op_registerers_by_func,
experimental_preserve_all_tensors))
elif not model_content and not model_path:
raise ValueError('`model_path` or `model_content` must be specified.')
else:
2 changes: 1 addition & 1 deletion tensorflow/lite/python/interpreter_test.py
@@ -545,7 +545,7 @@ def testFail(self):
with self.assertRaisesRegex(
# Due to exception chaining in PY3, we can't be more specific here and check that
# the phrase 'Fail argument sent' is present.
ValueError,
ValueError, #
r'Failed to load delegate from'):
interpreter_wrapper.load_delegate(
self._delegate_file, options={'fail': 'fail'})
40 changes: 24 additions & 16 deletions tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -73,15 +73,17 @@ using python_utils::PyDecrefDeleter;

std::unique_ptr<Interpreter> CreateInterpreter(
const InterpreterWrapper::Model* model,
const tflite::MutableOpResolver& resolver) {
const tflite::MutableOpResolver& resolver, bool preserve_all_tensors) {
if (!model) {
return nullptr;
}

::tflite::python::ImportNumpy();

std::unique_ptr<Interpreter> interpreter;
if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
InterpreterBuilder builder(*model, resolver);
if (preserve_all_tensors) builder.PreserveAllTensorsExperimental();
if (builder(&interpreter) != kTfLiteOk) {
return nullptr;
}
return interpreter;
@@ -179,7 +181,7 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
std::unique_ptr<PythonErrorReporter> error_reporter,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
std::string* error_msg, bool preserve_all_tensors) {
if (!model) {
*error_msg = error_reporter->message();
return nullptr;
@@ -212,7 +214,8 @@
for (const auto& registerer : registerers_by_func) {
registerer(reinterpret_cast<uintptr_t>(resolver.get()));
}
auto interpreter = CreateInterpreter(model.get(), *resolver);
auto interpreter =
CreateInterpreter(model.get(), *resolver, preserve_all_tensors);
if (!interpreter) {
*error_msg = error_reporter->message();
return nullptr;
@@ -733,27 +736,30 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, int op_resolver_id,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
std::string* error_msg, bool preserve_all_tensors) {
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<InterpreterWrapper::Model> model =
Model::BuildFromFile(model_path, error_reporter.get());
return CreateInterpreterWrapper(
std::move(model), op_resolver_id, std::move(error_reporter),
registerers_by_name, registerers_by_func, error_msg);
return CreateInterpreterWrapper(std::move(model), op_resolver_id,
std::move(error_reporter),
registerers_by_name, registerers_by_func,
error_msg, preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, int op_resolver_id,
const std::vector<std::string>& registerers, std::string* error_msg) {
const std::vector<std::string>& registerers, std::string* error_msg,
bool preserve_all_tensors) {
return CreateWrapperCPPFromFile(model_path, op_resolver_id, registerers,
{} /*registerers_by_func*/, error_msg);
{} /*registerers_by_func*/, error_msg,
preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, int op_resolver_id,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
std::string* error_msg, bool preserve_all_tensors) {
char* buf = nullptr;
Py_ssize_t length;
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -763,16 +769,18 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
}
std::unique_ptr<InterpreterWrapper::Model> model =
Model::BuildFromBuffer(buf, length, error_reporter.get());
return CreateInterpreterWrapper(
std::move(model), op_resolver_id, std::move(error_reporter),
registerers_by_name, registerers_by_func, error_msg);
return CreateInterpreterWrapper(std::move(model), op_resolver_id,
std::move(error_reporter),
registerers_by_name, registerers_by_func,
error_msg, preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, int op_resolver_id,
const std::vector<std::string>& registerers, std::string* error_msg) {
const std::vector<std::string>& registerers, std::string* error_msg,
bool preserve_all_tensors) {
return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
error_msg);
error_msg, preserve_all_tensors);
}

PyObject* InterpreterWrapper::ResetVariableTensors() {
12 changes: 7 additions & 5 deletions tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -47,22 +47,24 @@ class InterpreterWrapper {
// SWIG caller takes ownership of pointer.
static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path, int op_resolver_id,
const std::vector<std::string>& registerers, std::string* error_msg);
const std::vector<std::string>& registerers, std::string* error_msg,
bool preserve_all_tensors);
static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path, int op_resolver_id,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
std::string* error_msg, bool preserve_all_tensors);

// SWIG caller takes ownership of pointer.
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, int op_resolver_id,
const std::vector<std::string>& registerers, std::string* error_msg);
const std::vector<std::string>& registerers, std::string* error_msg,
bool preserve_all_tensors);
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, int op_resolver_id,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
std::string* error_msg, bool preserve_all_tensors);

~InterpreterWrapper();
PyObject* AllocateTensors();
@@ -119,7 +121,7 @@ class InterpreterWrapper {
std::unique_ptr<PythonErrorReporter> error_reporter,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
std::string* error_msg, bool preserve_all_tensors);

InterpreterWrapper(std::unique_ptr<Model> model,
std::unique_ptr<PythonErrorReporter> error_reporter,
