diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index df5b7c1501d8..4231505dce62 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D import typer.{Namer, Checking} import util.{Property, SourceFile, SourcePosition, SrcPos, Chars} import config.{Feature, Config} +import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled} import config.SourceVersion.* import collection.mutable import reporting.* @@ -1807,7 +1808,7 @@ object desugar { * * 1. * - * for (P <- G) E ==> G.foreach (P => E) + * for (P <- G) do E ==> G.foreach (P => E) * * Here and in the following (P => E) is interpreted as the function (P => E) * if P is a variable pattern and as the partial function { case P => E } otherwise. @@ -1816,11 +1817,11 @@ object desugar { * * for (P <- G) yield P ==> G * - * if P is a variable or a tuple of variables and G is not a withFilter. + * If P is a variable or a tuple of variables and G is not a withFilter. * * for (P <- G) yield E ==> G.map (P => E) * - * otherwise + * Otherwise * * 3. * @@ -1830,25 +1831,48 @@ object desugar { * * 4. * - * for (P <- G; E; ...) ... - * => - * for (P <- G.filter (P => E); ...) ... + * for (P <- G; if E; ...) ... + * ==> + * for (P <- G.withFilter (P => E); ...) ... * * 5. For any N: * - * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * for (P <- G; P_1 = E_1; ... P_N = E_N; rest) * ==> - * for (TupleN(P_1, P_2, ... P_N) <- - * for (x_1 @ P_1 <- G) yield { - * val x_2 @ P_2 = E_2 + * G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) if rest contains (<-) + * G.map (P => for (P_1 = E_1; ... P_N = E_N; ...)) otherwise + * + * 6. For any N: + * + * for (P <- G; P_1 = E_1; ... P_N = E_N; if E; ...) + * ==> + * for (TupleN(P, P_1, ... P_N) <- + * for (x @ P <- G) yield { + * val x_1 @ P_1 = E_2 * ... - * val x_N & P_N = E_N - * TupleN(x_1, ..., x_N) - * } ...) + * val x_N @ P_N = E_N + * TupleN(x, x_1, ..., x_N) + * }; if E; ...) * * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated * and the variable constituting P_i is used instead of x_i * + * 7. For any N: + * + * for (P_1 = E_1; ... P_N = E_N; ...) + * ==> + * { + * val x_N @ P_N = E_N + * for (...) + * } + * + * 8. + * for () yield E ==> E + * + * (Where empty for-comprehensions are excluded by the parser) + * + * If the aliases are not followed by a guard, otherwise an error. + * * @param mapName The name to be used for maps (either map or foreach) * @param flatMapName The name to be used for flatMaps (either flatMap or foreach) * @param enums The enumerators in the for expression @@ -1973,37 +1997,86 @@ object desugar { case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) case _ => false - enums match { - case (gen: GenFrom) :: Nil => - if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && deepEquals(gen.pat, body) - then gen.expr // avoid a redundant map with identity - else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) - case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => - val cont = makeFor(mapName, flatMapName, rest, body) - Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) - case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => - val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) - val pats = valeqs map { case GenAlias(pat, _) => pat } - val rhss = valeqs map { case GenAlias(_, rhs) => rhs } - val (defpat0, id0) = makeIdPat(gen.pat) - val (defpats, ids) = (pats map makeIdPat).unzip - val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => - val mods = defpat match - case defTree: DefTree => defTree.mods - case _ => Modifiers() - makePatDef(valeq, mods, defpat, rhs) - } - val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) - val allpats = gen.pat :: pats - val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) - makeFor(mapName, flatMapName, vfrom1 :: rest1, body) - case (gen: GenFrom) :: test :: rest => - val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) - makeFor(mapName, flatMapName, genFrom :: rest, body) - case _ => - EmptyTree //may happen for erroneous input + if betterForsEnabled then + enums match { + case Nil => body + case (gen: GenFrom) :: Nil => + if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && deepEquals(gen.pat, body) + then gen.expr // avoid a redundant map with identity + else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + case (gen: GenFrom) :: rest + if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => + val cont = makeFor(mapName, flatMapName, rest, body) + val selectName = + if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName + else mapName + Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => + val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) + val pats = valeqs map { case GenAlias(pat, _) => pat } + val rhss = valeqs map { case GenAlias(_, rhs) => rhs } + val (defpat0, id0) = makeIdPat(gen.pat) + val (defpats, ids) = (pats map makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) + val allpats = gen.pat :: pats + val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, vfrom1 :: rest1, body) + case (gen: GenFrom) :: test :: rest => + val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) + makeFor(mapName, flatMapName, genFrom :: rest, body) + case GenAlias(_, _) :: _ => + val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias]) + val pats = valeqs.map { case GenAlias(pat, _) => pat } + val rhss = valeqs.map { case GenAlias(_, rhs) => rhs } + val (defpats, ids) = pats.map(makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + Block(pdefs, makeFor(mapName, flatMapName, rest, body)) + case _ => + EmptyTree //may happen for erroneous input + } + else { + enums match { + case (gen: GenFrom) :: Nil => + Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => + val cont = makeFor(mapName, flatMapName, rest, body) + Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) + case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => + val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) + val pats = valeqs map { case GenAlias(pat, _) => pat } + val rhss = valeqs map { case GenAlias(_, rhs) => rhs } + val (defpat0, id0) = makeIdPat(gen.pat) + val (defpats, ids) = (pats map makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) + val allpats = gen.pat :: pats + val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, vfrom1 :: rest1, body) + case (gen: GenFrom) :: test :: rest => + val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, genFrom :: rest, body) + case _ => + EmptyTree //may happen for erroneous input + } } } diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 8c1021e91e38..cad9b4e76ca9 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -38,6 +38,7 @@ object Feature: val modularity = experimental("modularity") val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") + val betterFors = experimental("betterFors") def experimentalAutoEnableFeatures(using Context): List[TermName] = defn.languageExperimentalFeatures @@ -125,6 +126,8 @@ object Feature: def clauseInterleavingEnabled(using Context) = sourceVersion.isAtLeast(`3.6`) || enabled(clauseInterleaving) + def betterForsEnabled(using Context) = enabled(betterFors) + def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals) def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index d3e198a7e7a7..bbe405b46bf1 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -435,6 +435,7 @@ object StdNames { val asInstanceOfPM: N = "$asInstanceOf$" val assert_ : N = "assert" val assume_ : N = "assume" + val betterFors: N = "betterFors" val box: N = "box" val break: N = "break" val build : N = "build" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 37587868da58..f4a6b5b76aa0 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -2891,7 +2891,11 @@ object Parsers { /** Enumerators ::= Generator {semi Enumerator | Guard} */ - def enumerators(): List[Tree] = generator() :: enumeratorsRest() + def enumerators(): List[Tree] = + if in.featureEnabled(Feature.betterFors) then + aliasesUntilGenerator() ++ enumeratorsRest() + else + generator() :: enumeratorsRest() def enumeratorsRest(): List[Tree] = if (isStatSep) { @@ -2933,6 +2937,18 @@ object Parsers { GenFrom(pat, subExpr(), checkMode) } + def aliasesUntilGenerator(): List[Tree] = + if in.token == CASE then generator() :: Nil + else { + val pat = pattern1() + if in.token == EQUALS then + atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: { + if (isStatSep) in.nextToken() + aliasesUntilGenerator() + } + else generatorRest(pat, casePat = false) :: Nil + } + /** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr * | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr * | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 7db326350fa1..3e8c2ab15cd2 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -133,6 +133,12 @@ object language: @compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements") object quotedPatternsWithPolymorphicFunctions + /** Experimental support for improvements in `for` comprehensions + * + * @see [[https://dotty.epfl.ch/docs/reference/experimental/better-fors]] + */ + @compileTimeOnly("`betterFors` can only be used at compile time in import statements") + object betterFors end experimental /** The deprecated object contains features that are no longer officially suypported in Scala. diff --git a/tests/run/better-fors.check b/tests/run/better-fors.check new file mode 100644 index 000000000000..8b75db2f56ad --- /dev/null +++ b/tests/run/better-fors.check @@ -0,0 +1,12 @@ +List((1,3), (1,4), (2,3), (2,4)) +List((1,2,3), (1,2,4)) +List((1,3), (1,4), (2,3), (2,4)) +List((2,3), (2,4)) +List((2,3), (2,4)) +List((1,2), (2,4)) +List(1, 2, 3) +List((2,3,6)) +List(6) +List(3, 6) +List(6) +List(2) diff --git a/tests/run/better-fors.scala b/tests/run/better-fors.scala new file mode 100644 index 000000000000..8c0bff230632 --- /dev/null +++ b/tests/run/better-fors.scala @@ -0,0 +1,105 @@ +import scala.language.experimental.betterFors + +def for1 = + for { + a = 1 + b <- List(a, 2) + c <- List(3, 4) + } yield (b, c) + +def for2 = + for + a = 1 + b = 2 + c <- List(3, 4) + yield (a, b, c) + +def for3 = + for { + a = 1 + b <- List(a, 2) + c = 3 + d <- List(c, 4) + } yield (b, d) + +def for4 = + for { + a = 1 + b <- List(a, 2) + if b > 1 + c <- List(3, 4) + } yield (b, c) + +def for5 = + for { + a = 1 + b <- List(a, 2) + c = 3 + if b > 1 + d <- List(c, 4) + } yield (b, d) + +def for6 = + for { + a = 1 + b = 2 + c <- for { + x <- List(a, b) + y = x * 2 + } yield (x, y) + } yield c + +def for7 = + for { + a <- List(1, 2, 3) + } yield a + +def for8 = + for { + a <- List(1, 2) + b = a + 1 + if b > 2 + c = b * 2 + if c < 8 + } yield (a, b, c) + +def for9 = + for { + a <- List(1, 2) + b = a * 2 + if b > 2 + } yield a + b + +def for10 = + for { + a <- List(1, 2) + b = a * 2 + } yield a + b + +def for11 = + for { + a <- List(1, 2) + b = a * 2 + if b > 2 && b % 2 == 0 + } yield a + b + +def for12 = + for { + a <- List(1, 2) + if a > 1 + } yield a + +object Test extends App { + println(for1) + println(for2) + println(for3) + println(for4) + println(for5) + println(for6) + println(for7) + println(for8) + println(for9) + println(for10) + println(for11) + println(for12) +} diff --git a/tests/run/fors.scala b/tests/run/fors.scala index bd7de7d32263..af04beb311b1 100644 --- a/tests/run/fors.scala +++ b/tests/run/fors.scala @@ -112,6 +112,8 @@ object Test extends App { /////////////////// elimination of map /////////////////// + import scala.language.experimental.betterFors + @tailrec def pair[B](xs: List[Int], ys: List[B], n: Int): List[(Int, B)] = if n == 0 then xs.zip(ys)