Skip to content

Commit

Permalink
Change elimination pattern for select ops to check input and result…
Browse files Browse the repository at this point in the history
… tensors instead of both inputs for type equivalence; recent broadcast optimizations broke the previous pattern.

PiperOrigin-RevId: 378537661
Change-Id: I91f1d52df8ace9d0e177385bdf3fc4adca909f3e
  • Loading branch information
lrdxgm authored and tensorflower-gardener committed Jun 10, 2021
1 parent 3f7a6ed commit 84262cc
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -830,23 +830,19 @@ class AllElementsAreBool<string val> : Constraint<CPred<
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
// select(true_tensor, A, B) -> A
def Optimize#SelectOp#True : Pat<
(SelectOp (ConstantOp $constant),
$input1,
$input2),
(SelectOp:$result (ConstantOp $constant),
$input1,
$input2),
(replaceWithValue $input1),
[(HaveSameType $input1, $input2),
(IsTailOfShape $input1, $constant),
(IsTailOfShape $constant, $input1),
[(HaveSameType $input1, $result),
(AllElementsAreBool<"true"> $constant)]>;
// select(false_tensor, A, B) -> B
def Optimize#SelectOp#False : Pat<
(SelectOp (ConstantOp $constant),
$input1,
$input2),
(SelectOp:$result (ConstantOp $constant),
$input1,
$input2),
(replaceWithValue $input2),
[(HaveSameType $input1, $input2),
(IsTailOfShape $input1, $constant),
(IsTailOfShape $constant, $input1),
[(HaveSameType $input2, $result),
(AllElementsAreBool<"false"> $constant)]>;
}

Expand Down

0 comments on commit 84262cc

Please sign in to comment.