diff --git a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala index 1b4d985c7c4c..a9c9568d0d31 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala @@ -860,46 +860,71 @@ class Inliner(val call: tpd.Tree)(using Context): case _ => sel.tpe } val selType = if (sel.isEmpty) wideSelType else selTyped(sel) - reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match { - case Some((caseBindings, rhs0)) => - // drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match) - // note that any actually necessary casts will be reinserted by the typing pass below - val rhs1 = rhs0 match { - case Block(stats, t) if t.span.isSynthetic => - t match { - case Typed(expr, _) => - Block(stats, expr) - case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast => - Block(stats, expr) - case _ => - rhs0 + + /** Make an Inlined that has no bindings. */ + def flattenInlineBlock(tree: Tree): Tree = { + def inlineBlock(call: Tree, stats: List[Tree], expr: Tree): Block = + def inlinedTree(tree: Tree) = Inlined(call, Nil, tree).withSpan(tree.span) + val stats1 = stats.map: + case stat: ValDef => cpy.ValDef(stat)(rhs = inlinedTree(stat.rhs)) + case stat: DefDef => cpy.DefDef(stat)(rhs = inlinedTree(stat.rhs)) + case stat => inlinedTree(stat) + cpy.Block(tree)(stats1, flattenInlineBlock(inlinedTree(expr))) + + tree match + case tree @ Inlined(call, bindings, expr) if !bindings.isEmpty => + inlineBlock(call, bindings, expr) + case tree @ Inlined(call, Nil, Block(stats, expr)) => + inlineBlock(call, stats, expr) + case _ => + tree + } + + def reduceInlineMatchExpr(sel: Tree): Tree = flattenInlineBlock(sel) match + case Block(stats, expr) => + cpy.Block(sel)(stats, reduceInlineMatchExpr(expr)) + case _ => + reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match { + case Some((caseBindings, rhs0)) => + // drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match) + // note that any actually necessary casts will be reinserted by the typing pass below + val rhs1 = rhs0 match { + case Block(stats, t) if t.span.isSynthetic => + t match { + case Typed(expr, _) => + Block(stats, expr) + case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast => + Block(stats, expr) + case _ => + rhs0 + } + case _ => rhs0 } - case _ => rhs0 - } - val rhs2 = rhs1 match { - case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr) - case _ => constToLiteral(rhs1) + val rhs2 = rhs1 match { + case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr) + case _ => constToLiteral(rhs1) + } + val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2) + val rhs = seq(usedBindings, rhs3) + inlining.println(i"""--- reduce: + |$tree + |--- to: + |$rhs""") + typedExpr(rhs, pt) + case None => + def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard" + def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}" + val msg = + if (tree.selector.isEmpty) + em"""cannot reduce summonFrom with + | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" + else + em"""cannot reduce inline match with + | scrutinee: $sel : ${selType} + | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" + errorTree(tree, msg) } - val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2) - val rhs = seq(usedBindings, rhs3) - inlining.println(i"""--- reduce: - |$tree - |--- to: - |$rhs""") - typedExpr(rhs, pt) - case None => - def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard" - def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}" - val msg = - if (tree.selector.isEmpty) - em"""cannot reduce summonFrom with - | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" - else - em"""cannot reduce inline match with - | scrutinee: $sel : ${selType} - | patterns : ${tree.cases.map(patStr).mkString("\n ")}""" - errorTree(tree, msg) - } + reduceInlineMatchExpr(sel) } override def newLikeThis(nestingLevel: Int): Typer = new InlineTyper(initialErrorCount, nestingLevel) diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index fcbc738f2934..b490d55bb43f 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -785,4 +785,36 @@ class InlineBytecodeTests extends DottyBytecodeTest { } } + @Test def inline_match_scrutinee_with_side_effect = { + val source = """class Test: + | inline def inlineTest(): Int = + | inline { + | println("scrutinee") + | (1, 2) + | } match + | case (e1, e2) => e1 + e2 + | + | def test: Int = inlineTest() + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List( + Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"), + Ldc(LDC, "scrutinee"), + Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), + Op(ICONST_3), + Op(IRETURN), + ) + + assert(instructions == expected, + "`i was not properly inlined in `test`\n" + diffInstructions(instructions, expected)) + + } + } + } diff --git a/tests/pos/i18151a.scala b/tests/pos/i18151a.scala new file mode 100644 index 000000000000..6be2c5c23a30 --- /dev/null +++ b/tests/pos/i18151a.scala @@ -0,0 +1,10 @@ +case class El[A](attr: String, child: String) + +transparent inline def inlineTest(): String = + inline { + val el: El[Any] = El("1", "2") + El[Any](el.attr, el.child) + } match + case El(attr, child) => attr + child + +def test: Unit = inlineTest() diff --git a/tests/pos/i18151b.scala b/tests/pos/i18151b.scala new file mode 100644 index 000000000000..01d2aaee972a --- /dev/null +++ b/tests/pos/i18151b.scala @@ -0,0 +1,10 @@ +case class El[A](val attr: String, val child: String) + +transparent inline def tmplStr(inline t: El[Any]): String = + inline t match + case El(attr, child) => attr + child + +def test: Unit = tmplStr { + val el = El("1", "2") + El[Any](el.attr, null) +} diff --git a/tests/pos/i18151c.scala b/tests/pos/i18151c.scala new file mode 100644 index 000000000000..a46ec9dd927c --- /dev/null +++ b/tests/pos/i18151c.scala @@ -0,0 +1,39 @@ +import scala.compiletime.* +import scala.compiletime.ops.any.ToString + +trait Attr +case object EmptyAttr extends Attr +transparent inline def attrStr(inline a: Attr): String = inline a match + case EmptyAttr => "" +transparent inline def attrStrHelper(inline a: Attr): String = inline a match + case EmptyAttr => "" +trait TmplNode +case class El[T <: String & Singleton, A <: Attr, C <: Tmpl](val tag: T, val attr: A, val child: C) + extends TmplNode +case class Sib[L <: Tmpl, R <: Tmpl](left: L, right: R) extends TmplNode +type TmplSingleton = String | Char | Int | Long | Float | Double | Boolean +type Tmpl = TmplNode | Unit | (TmplSingleton & Singleton) +transparent inline def tmplStr(inline t: Tmpl): String = inline t match + case El(tag, attr, child) => inline attrStr(attr) match + case "" => "<" + tag + ">" + tmplStr(child) + case x => "<" + tag + " " + x + ">" + tmplStr(child) + case Sib(left, right) => inline tmplStr(right) match + case "" => tmplStr(left) + case right => tmplStrHelper(left) + right + case () => "" + case s: (t & TmplSingleton) => constValue[ToString[t]] +transparent inline def tmplStrHelper(inline t: Tmpl): String = inline t match + case El(tag, attr, child) => inline (tmplStr(child), attrStr(attr)) match + case ("", "") => "<" + tag + "/>" + case (child, "") => "<" + tag + ">" + child + "" + case ("", attr) => "<" + tag + " " + attr + "/>" + case (child, attr) => "<" + tag + " " + attr + ">" + child + "" + case Sib(left, right) => tmplStrHelper(left) + tmplStrHelper(right) + case () => "" + case s: (t & TmplSingleton) => constValue[ToString[t]] +transparent inline def el(tag: String & Singleton): El[tag.type, EmptyAttr.type, Unit] = + El(tag, EmptyAttr, ()) +extension [T <: String & Singleton, A <: Attr, C <: Tmpl](el: El[T, A, C]) + transparent inline def >>[C2 <: Tmpl](child: C2) = El(el.tag, el.attr, el.child ++ child) + +extension [L <: Tmpl](left: L) transparent inline def ++[R <: Tmpl](right: R) = Sib(left, right)