Skip to content

Commit

Permalink
Add improvements to for comprehensions
Browse files Browse the repository at this point in the history
- Allow `for`-comprehensions to start with aliases desugaring them into
  valdefs in a new block
- Desugar aliases into simple valdefs, instead of patterns when they are
  not followed by a guard
- Add an experimental language flag that enables the new desugaring
  method
  • Loading branch information
KacperFKorban committed Jul 22, 2024
1 parent bd0aa52 commit 9c3e454
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 45 deletions.
161 changes: 117 additions & 44 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D
import typer.{Namer, Checking}
import util.{Property, SourceFile, SourcePosition, SrcPos, Chars}
import config.{Feature, Config}
import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled}
import config.SourceVersion.*
import collection.mutable
import reporting.*
Expand Down Expand Up @@ -1807,7 +1808,7 @@ object desugar {
*
* 1.
*
* for (P <- G) E ==> G.foreach (P => E)
* for (P <- G) do E ==> G.foreach (P => E)
*
* Here and in the following (P => E) is interpreted as the function (P => E)
* if P is a variable pattern and as the partial function { case P => E } otherwise.
Expand All @@ -1816,11 +1817,11 @@ object desugar {
*
* for (P <- G) yield P ==> G
*
* if P is a variable or a tuple of variables and G is not a withFilter.
* If P is a variable or a tuple of variables and G is not a withFilter.
*
* for (P <- G) yield E ==> G.map (P => E)
*
* otherwise
* Otherwise
*
* 3.
*
Expand All @@ -1830,25 +1831,48 @@ object desugar {
*
* 4.
*
* for (P <- G; E; ...) ...
* =>
* for (P <- G.filter (P => E); ...) ...
* for (P <- G; if E; ...) ...
* ==>
* for (P <- G.withFilter (P => E); ...) ...
*
* 5. For any N:
*
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
* for (P <- G; P_1 = E_1; ... P_N = E_N; rest)
* ==>
* for (TupleN(P_1, P_2, ... P_N) <-
* for (x_1 @ P_1 <- G) yield {
* val x_2 @ P_2 = E_2
* G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) if rest contains (<-)
* G.map (P => for (P_1 = E_1; ... P_N = E_N; ...)) otherwise
*
* 6. For any N:
*
* for (P <- G; P_1 = E_1; ... P_N = E_N; if E; ...)
* ==>
* for (TupleN(P, P_1, ... P_N) <-
* for (x @ P <- G) yield {
* val x_1 @ P_1 = E_2
* ...
* val x_N & P_N = E_N
* TupleN(x_1, ..., x_N)
* } ...)
* val x_N @ P_N = E_N
* TupleN(x, x_1, ..., x_N)
* }; if E; ...)
*
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
* and the variable constituting P_i is used instead of x_i
*
* 7. For any N:
*
* for (P_1 = E_1; ... P_N = E_N; ...)
* ==>
* {
* val x_N @ P_N = E_N
* for (...)
* }
*
* 8.
* for () yield E ==> E
*
* (Where empty for-comprehensions are excluded by the parser)
*
* If the aliases are not followed by a guard, otherwise an error.
*
* @param mapName The name to be used for maps (either map or foreach)
* @param flatMapName The name to be used for flatMaps (either flatMap or foreach)
* @param enums The enumerators in the for expression
Expand Down Expand Up @@ -1973,37 +1997,86 @@ object desugar {
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
case _ => false

enums match {
case (gen: GenFrom) :: Nil =>
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case _ =>
EmptyTree //may happen for erroneous input
if betterForsEnabled then
enums match {
case Nil => body
case (gen: GenFrom) :: Nil =>
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: rest
if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) =>
val cont = makeFor(mapName, flatMapName, rest, body)
val selectName =
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
else mapName
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case GenAlias(_, _) :: _ =>
val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias])
val pats = valeqs.map { case GenAlias(pat, _) => pat }
val rhss = valeqs.map { case GenAlias(_, rhs) => rhs }
val (defpats, ids) = pats.map(makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
Block(pdefs, makeFor(mapName, flatMapName, rest, body))
case _ =>
EmptyTree //may happen for erroneous input
}
else {
enums match {
case (gen: GenFrom) :: Nil =>
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case _ =>
EmptyTree //may happen for erroneous input
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ object Feature:
val modularity = experimental("modularity")
val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
val betterFors = experimental("betterFors")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
defn.languageExperimentalFeatures
Expand Down Expand Up @@ -125,6 +126,8 @@ object Feature:
def clauseInterleavingEnabled(using Context) =
sourceVersion.isAtLeast(`3.6`) || enabled(clauseInterleaving)

def betterForsEnabled(using Context) = enabled(betterFors)

def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals)

def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ object StdNames {
val asInstanceOfPM: N = "$asInstanceOf$"
val assert_ : N = "assert"
val assume_ : N = "assume"
val betterFors: N = "betterFors"
val box: N = "box"
val break: N = "break"
val build : N = "build"
Expand Down
18 changes: 17 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2891,7 +2891,11 @@ object Parsers {

/** Enumerators ::= Generator {semi Enumerator | Guard}
*/
def enumerators(): List[Tree] = generator() :: enumeratorsRest()
def enumerators(): List[Tree] =
if in.featureEnabled(Feature.betterFors) then
aliasesUntilGenerator() ++ enumeratorsRest()
else
generator() :: enumeratorsRest()

def enumeratorsRest(): List[Tree] =
if (isStatSep) {
Expand Down Expand Up @@ -2933,6 +2937,18 @@ object Parsers {
GenFrom(pat, subExpr(), checkMode)
}

def aliasesUntilGenerator(): List[Tree] =
if in.token == CASE then generator() :: Nil
else {
val pat = pattern1()
if in.token == EQUALS then
atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: {
if (isStatSep) in.nextToken()
aliasesUntilGenerator()
}
else generatorRest(pat, casePat = false) :: Nil
}

/** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr
* | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr
* | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr
Expand Down
6 changes: 6 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ object language:
@compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements")
object quotedPatternsWithPolymorphicFunctions

/** Experimental support for improvements in `for` comprehensions
*
* @see [[https://dotty.epfl.ch/docs/reference/experimental/better-fors]]
*/
@compileTimeOnly("`betterFors` can only be used at compile time in import statements")
object betterFors
end experimental

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
12 changes: 12 additions & 0 deletions tests/run/better-fors.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
List((1,3), (1,4), (2,3), (2,4))
List((1,2,3), (1,2,4))
List((1,3), (1,4), (2,3), (2,4))
List((2,3), (2,4))
List((2,3), (2,4))
List((1,2), (2,4))
List(1, 2, 3)
List((2,3,6))
List(6)
List(3, 6)
List(6)
List(2)
105 changes: 105 additions & 0 deletions tests/run/better-fors.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import scala.language.experimental.betterFors

def for1 =
for {
a = 1
b <- List(a, 2)
c <- List(3, 4)
} yield (b, c)

def for2 =
for
a = 1
b = 2
c <- List(3, 4)
yield (a, b, c)

def for3 =
for {
a = 1
b <- List(a, 2)
c = 3
d <- List(c, 4)
} yield (b, d)

def for4 =
for {
a = 1
b <- List(a, 2)
if b > 1
c <- List(3, 4)
} yield (b, c)

def for5 =
for {
a = 1
b <- List(a, 2)
c = 3
if b > 1
d <- List(c, 4)
} yield (b, d)

def for6 =
for {
a = 1
b = 2
c <- for {
x <- List(a, b)
y = x * 2
} yield (x, y)
} yield c

def for7 =
for {
a <- List(1, 2, 3)
} yield a

def for8 =
for {
a <- List(1, 2)
b = a + 1
if b > 2
c = b * 2
if c < 8
} yield (a, b, c)

def for9 =
for {
a <- List(1, 2)
b = a * 2
if b > 2
} yield a + b

def for10 =
for {
a <- List(1, 2)
b = a * 2
} yield a + b

def for11 =
for {
a <- List(1, 2)
b = a * 2
if b > 2 && b % 2 == 0
} yield a + b

def for12 =
for {
a <- List(1, 2)
if a > 1
} yield a

object Test extends App {
println(for1)
println(for2)
println(for3)
println(for4)
println(for5)
println(for6)
println(for7)
println(for8)
println(for9)
println(for10)
println(for11)
println(for12)
}
Loading

0 comments on commit 9c3e454

Please sign in to comment.