Skip to content

Commit

Permalink
Enable passing in cached symbol table to then & else_funcion helpers
Browse files Browse the repository at this point in the history
Use this in shape inference where the SymbolTableCollection is already constructed and used. Follow up here for while & changing accessor names.

PiperOrigin-RevId: 407103519
Change-Id: If74de97afce2c88299a379c1f2478d3d94e99400
  • Loading branch information
jpienaar authored and tensorflower-gardener committed Nov 2, 2021
1 parent 09b3e01 commit 7b8ded3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
22 changes: 14 additions & 8 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -278,16 +278,22 @@ else_branch: A function that takes 'inputs' and returns a list of
let hasCanonicalizer = 1;

let extraClassDeclaration = [{
// Get the then branch function.
FuncOp then_function() {
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(
*this, then_branchAttr());
// Get the then branch function. Prefer passing in SymbolTableCollection
// to reuse cached lookups.
FuncOp then_function(::mlir::SymbolTableCollection* table = nullptr) {
if (table)
return table->lookupNearestSymbolFrom<FuncOp>(*this, then_branchAttr());
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(
*this, then_branchAttr());
}

// Get the else branch function.
FuncOp else_function() {
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(
*this, else_branchAttr());
// Get the else branch function. Prefer passing in SymbolTableCollection
// to reuse cached lookups.
FuncOp else_function(::mlir::SymbolTableCollection* table = nullptr) {
if (table)
return table->lookupNearestSymbolFrom<FuncOp>(*this, else_branchAttr());
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(
*this, else_branchAttr());
}
}];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,8 +898,8 @@ bool ShapeInference::InferShapeForCast(Operation* op) {
bool ShapeInference::InferShapeForIf(IfOp op) {
DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
bool changed = false;
auto then_results = op.then_function().getType().getResults();
auto else_results = op.else_function().getType().getResults();
auto then_results = op.then_function(&symbol_table_).getType().getResults();
auto else_results = op.else_function(&symbol_table_).getType().getResults();
for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
// If then and else types do not match, skip refinement for that result.
if (std::get<1>(it) != std::get<2>(it)) continue;
Expand Down Expand Up @@ -1720,9 +1720,10 @@ FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedFunctions(
ModuleOp module = op->getParentOfType<ModuleOp>();
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
DCOMMENT("Propagating shapes into If");
return PropagateShapeToFunctions(
module, if_op.input().getTypes(),
{if_op.then_function(), if_op.else_function()}, max_iterations);
return PropagateShapeToFunctions(module, if_op.input().getTypes(),
{if_op.then_function(&symbol_table_),
if_op.else_function(&symbol_table_)},
max_iterations);
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
SmallVector<FuncOp, 4> branches;
case_op.get_branch_functions(branches);
Expand Down

0 comments on commit 7b8ded3

Please sign in to comment.