diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 62ceb506831671..36e205a3720d90 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -830,23 +830,19 @@ class AllElementsAreBool : Constraint A def Optimize#SelectOp#True : Pat< - (SelectOp (ConstantOp $constant), - $input1, - $input2), + (SelectOp:$result (ConstantOp $constant), + $input1, + $input2), (replaceWithValue $input1), - [(HaveSameType $input1, $input2), - (IsTailOfShape $input1, $constant), - (IsTailOfShape $constant, $input1), + [(HaveSameType $input1, $result), (AllElementsAreBool<"true"> $constant)]>; // select(false_tensor, A, B) -> B def Optimize#SelectOp#False : Pat< - (SelectOp (ConstantOp $constant), - $input1, - $input2), + (SelectOp:$result (ConstantOp $constant), + $input1, + $input2), (replaceWithValue $input2), - [(HaveSameType $input1, $input2), - (IsTailOfShape $input1, $constant), - (IsTailOfShape $constant, $input1), + [(HaveSameType $input2, $result), (AllElementsAreBool<"false"> $constant)]>; }