From b87ff4b949ef86dcb7ea48209503430327f4d996 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 8 Apr 2024 10:08:33 +0200 Subject: [PATCH 1/2] Fix inline match on blocks with multiple statements Only the last expression of the block is considered as the inlined scrutinee. Otherwise we may not reduce as much as we should. We also need to make sure that side effects and bindings in the scrutinee are not duplicated. Fixes #18151 --- .../dotty/tools/dotc/inlines/Inliner.scala | 101 +++++++++++------- .../backend/jvm/InlineBytecodeTests.scala | 32 ++++++ tests/pos/i18151a.scala | 10 ++ tests/pos/i18151b.scala | 10 ++ tests/pos/i18151c.scala | 39 +++++++ 5 files changed, 154 insertions(+), 38 deletions(-) create mode 100644 tests/pos/i18151a.scala create mode 100644 tests/pos/i18151b.scala create mode 100644 tests/pos/i18151c.scala diff --git a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala index 1b4d985c7c4c..a9c9568d0d31 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala @@ -860,46 +860,71 @@ class Inliner(val call: tpd.Tree)(using Context): case _ => sel.tpe } val selType = if (sel.isEmpty) wideSelType else selTyped(sel) - reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match { - case Some((caseBindings, rhs0)) => - // drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match) - // note that any actually necessary casts will be reinserted by the typing pass below - val rhs1 = rhs0 match { - case Block(stats, t) if t.span.isSynthetic => - t match { - case Typed(expr, _) => - Block(stats, expr) - case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast => - Block(stats, expr) - case _ => - rhs0 + + /** Make an Inlined that has no bindings. */ + def flattenInlineBlock(tree: Tree): Tree = { + def inlineBlock(call: Tree, stats: List[Tree], expr: Tree): Block = + def inlinedTree(tree: Tree) = Inlined(call, Nil, tree).withSpan(tree.span) + val stats1 = stats.map: + case stat: ValDef => cpy.ValDef(stat)(rhs = inlinedTree(stat.rhs)) + case stat: DefDef => cpy.DefDef(stat)(rhs = inlinedTree(stat.rhs)) + case stat => inlinedTree(stat) + cpy.Block(tree)(stats1, flattenInlineBlock(inlinedTree(expr))) + + tree match + case tree @ Inlined(call, bindings, expr) if !bindings.isEmpty => + inlineBlock(call, bindings, expr) + case tree @ Inlined(call, Nil, Block(stats, expr)) => + inlineBlock(call, stats, expr) + case _ => + tree + } + + def reduceInlineMatchExpr(sel: Tree): Tree = flattenInlineBlock(sel) match + case Block(stats, expr) => + cpy.Block(sel)(stats, reduceInlineMatchExpr(expr)) + case _ => + reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match { + case Some((caseBindings, rhs0)) => + // drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match) + // note that any actually necessary casts will be reinserted by the typing pass below + val rhs1 = rhs0 match { + case Block(stats, t) if t.span.isSynthetic => + t match { + case Typed(expr, _) => + Block(stats, expr) + case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast => + Block(stats, expr) + case _ => + rhs0 + } + case _ => rhs0 } - case _ => rhs0 - } - val rhs2 = rhs1 match { - case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr) - case _ => constToLiteral(rhs1) + val rhs2 = rhs1 match { + case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr) + case _ => constToLiteral(rhs1) + } + val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2) + val rhs = seq(usedBindings, rhs3) + inlining.println(i"""--- reduce: + |$tree + |--- to: + |$rhs""") + typedExpr(rhs, pt) + case None => + def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard" + def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}" + val msg = + if (tree.selector.isEmpty) + em"""cannot reduce summonFrom with + | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" + else + em"""cannot reduce inline match with + | scrutinee: $sel : ${selType} + | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" + errorTree(tree, msg) } - val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2) - val rhs = seq(usedBindings, rhs3) - inlining.println(i"""--- reduce: - |$tree - |--- to: - |$rhs""") - typedExpr(rhs, pt) - case None => - def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard" - def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}" - val msg = - if (tree.selector.isEmpty) - em"""cannot reduce summonFrom with - | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" - else - em"""cannot reduce inline match with - | scrutinee: $sel : ${selType} - | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" - errorTree(tree, msg) - } + reduceInlineMatchExpr(sel) } override def newLikeThis(nestingLevel: Int): Typer = new InlineTyper(initialErrorCount, nestingLevel) diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index fcbc738f2934..7172e19184cb 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -785,4 +785,36 @@ class InlineBytecodeTests extends DottyBytecodeTest { } } + @Test def inline_match_scrutinee_with_side_effect = { + val source = """class Test: + | inline def inlineTest(): Int = + | inline { + | println("scrutinee") + | (1, 2) + | } match + | case (e1, e2) => e1 + e2 + | + | def test: Int = inlineTest() + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List( + Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"), + Ldc(LDC, "scrutinee"), + Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), + Op(ICONST_3), + Op(IRETURN), + ) + + assert(instructions == expected, + "`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected)) + + } + } + } diff --git a/tests/pos/i18151a.scala b/tests/pos/i18151a.scala new file mode 100644 index 000000000000..6be2c5c23a30 --- /dev/null +++ b/tests/pos/i18151a.scala @@ -0,0 +1,10 @@ +case class El[A](attr: String, child: String) + +transparent inline def inlineTest(): String = + inline { + val el: El[Any] = El("1", "2") + El[Any](el.attr, el.child) + } match + case El(attr, child) => attr + child + +def test: Unit = inlineTest() diff --git a/tests/pos/i18151b.scala b/tests/pos/i18151b.scala new file mode 100644 index 000000000000..01d2aaee972a --- /dev/null +++ b/tests/pos/i18151b.scala @@ -0,0 +1,10 @@ +case class El[A](val attr: String, val child: String) + +transparent inline def tmplStr(inline t: El[Any]): String = + inline t match + case El(attr, child) => attr + child + +def test: Unit = tmplStr { + val el = El("1", "2") + El[Any](el.attr, null) +} diff --git a/tests/pos/i18151c.scala b/tests/pos/i18151c.scala new file mode 100644 index 000000000000..a46ec9dd927c --- /dev/null +++ b/tests/pos/i18151c.scala @@ -0,0 +1,39 @@ +import scala.compiletime.* +import scala.compiletime.ops.any.ToString + +trait Attr +case object EmptyAttr extends Attr +transparent inline def attrStr(inline a: Attr): String = inline a match + case EmptyAttr => "" +transparent inline def attrStrHelper(inline a: Attr): String = inline a match + case EmptyAttr => "" +trait TmplNode +case class El[T <: String & Singleton, A <: Attr, C <: Tmpl](val tag: T, val attr: A, val child: C) + extends TmplNode +case class Sib[L <: Tmpl, R <: Tmpl](left: L, right: R) extends TmplNode +type TmplSingleton = String | Char | Int | Long | Float | Double | Boolean +type Tmpl = TmplNode | Unit | (TmplSingleton & Singleton) +transparent inline def tmplStr(inline t: Tmpl): String = inline t match + case El(tag, attr, child) => inline attrStr(attr) match + case "" => "<" + tag + ">" + tmplStr(child) + case x => "<" + tag + " " + x + ">" + tmplStr(child) + case Sib(left, right) => inline tmplStr(right) match + case "" => tmplStr(left) + case right => tmplStrHelper(left) + right + case () => "" + case s: (t & TmplSingleton) => constValue[ToString[t]] +transparent inline def tmplStrHelper(inline t: Tmpl): String = inline t match + case El(tag, attr, child) => inline (tmplStr(child), attrStr(attr)) match + case ("", "") => "<" + tag + "/>" + case (child, "") => "<" + tag + ">" + child + "" + case ("", attr) => "<" + tag + " " + attr + "/>" + case (child, attr) => "<" + tag + " " + attr + ">" + child + "" + case Sib(left, right) => tmplStrHelper(left) + tmplStrHelper(right) + case () => "" + case s: (t & TmplSingleton) => constValue[ToString[t]] +transparent inline def el(tag: String & Singleton): El[tag.type, EmptyAttr.type, Unit] = + El(tag, EmptyAttr, ()) +extension [T <: String & Singleton, A <: Attr, C <: Tmpl](el: El[T, A, C]) + transparent inline def >>[C2 <: Tmpl](child: C2) = El(el.tag, el.attr, el.child ++ child) + +extension [L <: Tmpl](left: L) transparent inline def ++[R <: Tmpl](right: R) = Sib(left, right) From 530f77526d475d49f7ca394e21b70e8fb8c80389 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Wed, 10 Apr 2024 10:27:04 +0200 Subject: [PATCH 2/2] Update compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala Co-authored-by: Jan Chyb <48855024+jchyb@users.noreply.github.com> --- compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index 7172e19184cb..b490d55bb43f 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -812,7 +812,7 @@ class InlineBytecodeTests extends DottyBytecodeTest { ) assert(instructions == expected, - "`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected)) + "`i was not properly inlined in `test`\n" + diffInstructions(instructions, expected)) } }