Skip to content

Commit

Permalink
[lite] Update analyze-variables pass to allow resource to be passed t…
Browse files Browse the repository at this point in the history
…o TFLite control flow ops (If, While, CallOnce) these pass the resources for other ops inside the functions, if any of the other ops that use the resources are not supported we shouldn't enable variables. This should be handled since we inspect all ops in the graph.

PiperOrigin-RevId: 399222824
Change-Id: I0546d8abc9f72d3d9f0247e72910b6bfaf24f339
  • Loading branch information
karimnosseir authored and tensorflower-gardener committed Sep 27, 2021
1 parent cde2ced commit 9f3428b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/lite/tests/analyze-variables.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ module {

// -----

// CHECK: module attributes {tfl._legalize_tfl_variables = false}
// CHECK: module attributes {tfl._legalize_tfl_variables = true}
module {
func @main() -> tensor<i32> {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
Expand Down
15 changes: 6 additions & 9 deletions tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ bool IsSupportedTFLiteResourceOp(Operation* op) {
TF::LookupTableSizeV2Op>(op);
}

// Returns true if 'op' is a TFLite control flow operation.
bool IsTFLiteControlFlowOp(Operation* op) {
return llvm::isa<TFL::WhileOp, TFL::IfOp, TFL::CallOnceOp>(op);
// Returns true if 'op' is TF/TFLite control flow op that can accept resource
// type. Usually these ops are just pass through, they call another subgraph and
// pass the operands to.
bool IsSupportedTFLiteControlFlow(Operation* op) {
return llvm::isa<TFL::WhileOp, TFL::IfOp, TF::IfOp, TFL::CallOnceOp>(op);
}
} // namespace

Expand Down Expand Up @@ -67,15 +69,10 @@ class AnalyzeVariablesPass
module.walk([&](Operation* op) {
// Skip ops that are supported natively by TFLite.
if (IsSupportedTFLiteResourceOp(op)) return WalkResult::advance();
if (IsSupportedTFLiteControlFlow(op)) return WalkResult::advance();

// Check for ops that are legalized to TFLite.
if (op->getDialect()->getNamespace() == "tfl") {
// TODO(b/189370197): Enable control flow ops after updating
// checks to handle them.
if (IsTFLiteControlFlowOp(op)) {
legalize_to_tfl = false;
return WalkResult::interrupt();
}
return WalkResult::advance();
}
// Check for ops that are not legalized to TFLite.
Expand Down

0 comments on commit 9f3428b

Please sign in to comment.