Skip to content

Commit

Permalink
Add support for unapplySeq
Browse files Browse the repository at this point in the history
  • Loading branch information
liufengyun committed Jul 10, 2023
1 parent d93a214 commit 6d645b7
Showing 1 changed file with 98 additions and 9 deletions.
107 changes: 98 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import core.*
import Contexts.*
import Symbols.*
import Types.*
import Denotations.Denotation
import StdNames.*
import Names.TermName
import NameKinds.OuterSelectName
import NameKinds.SuperAccessorName

import ast.tpd.*
import util.SourcePosition
import util.{ SourcePosition, NoSourcePosition }
import config.Printers.init as printer
import reporting.StoreReporter
import reporting.trace as log
Expand Down Expand Up @@ -1176,6 +1178,16 @@ object Objects:
* @param klass The enclosing class where the type `tp` is located.
*/
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] =
// expected member types for `unapplySeq`
def lengthType = ExprType(defn.IntType)
def lengthCompareType = MethodType(List(defn.IntType), defn.IntType)
def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp))
def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp))

def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
Expand Down Expand Up @@ -1206,18 +1218,59 @@ object Objects:
case UnApply(fun, implicits, pats) =>
val fun1 = funPart(fun)
val funRef = fun1.tpe.asInstanceOf[TermRef]
val unapplyResTp = funRef.widen.finalResultType

val receiver = evalType(funRef.prefix, thisV, klass)
val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
// TODO: implicit values may appear before and/or after the scrutinee parameter.
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)

if fun.symbol.name == nme.unapplySeq then
// TODO: handle unapplySeq
()
var resultTp = unapplyResTp
var elemTp = unapplySeqTypeElemTp(resultTp)
var arity = productArity(resultTp, NoSourcePosition)
var needsGet = false
if (!elemTp.exists && arity <= 0) {
needsGet = true
resultTp = resultTp.select(nme.get).finalResultType
elemTp = unapplySeqTypeElemTp(resultTp.widen)
arity = productSelectorTypes(resultTp, NoSourcePosition).size
}

var resToMatch = unapplyRes

if needsGet then
// Get match
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)

val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
end if

if elemTp.exists then
// sequence match
evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
else
// product sequence match
val selectors = productSelectors(resultTp)
assert(selectors.length <= pats.length)
selectors.init.zip(pats).map { (sel, pat) =>
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
val seqPats = pats.drop(selectors.length - 1)
val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true)
val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
end if

else
val receiver = evalType(funRef.prefix, thisV, klass)
val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)
// distribute unapply to patterns
val unapplyResTp = funRef.widen.finalResultType
if isProductMatch(unapplyResTp, pats.length) then
// product match
val selectors = productSelectors(unapplyResTp).take(pats.length)
val selectors = productSelectors(unapplyResTp)
assert(selectors.length == pats.length)
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
Expand All @@ -1239,7 +1292,7 @@ object Objects:
val getResTp = getDenot.info.finalResultType
val selectors = productSelectors(getResTp).take(pats.length)
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
val selectRes = call(unapplyRes, sel, Nil, getResTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
end if
Expand All @@ -1259,6 +1312,42 @@ object Objects:

end evalPattern

/**
* Evaluate a sequence value against sequence patterns.
*/
def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree]): Unit =
// call .lengthCompare or .length
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
if lengthCompareDenot.exists then
call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
else
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
end if

// call .apply
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)

if isWildcardStarArg(pats.last) then
if pats.size == 1 then
// call .toSeq
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
evalPattern(toSeqRes, pats.head)
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
end if
else
// no patterns like `xs*`
for pat <- pats do evalPattern(applyRes, pat)
end evalSeqPatterns


cases.map(evalCase).join


Expand Down

0 comments on commit 6d645b7

Please sign in to comment.