Skip to content

Commit

Permalink
Refactor adjustTypeArgs, giving its result record a name
Browse files Browse the repository at this point in the history
  • Loading branch information
retronym committed Mar 21, 2019
1 parent f967dcd commit cad96e2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 59 deletions.
9 changes: 5 additions & 4 deletions src/compiler/scala/tools/nsc/typechecker/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,9 @@ trait Implicits {
false
} else {
val targs = solvedTypes(tvars, allUndetparams, allUndetparams map varianceInType(wildPt), upper = false, lubDepth(tpInstantiated :: wildPt :: Nil))
val AdjustedTypeArgs(okParams, okArgs) = adjustTypeArgs(allUndetparams, tvars, targs)
val remainingUndet = allUndetparams diff okParams
val tpSubst = deriveTypeWithWildcards(remainingUndet)(tp.instantiateTypeParams(okParams, okArgs))
val adjusted = adjustTypeArgs(allUndetparams, tvars, targs)
val remainingUndet = allUndetparams diff adjusted.okParams
val tpSubst = deriveTypeWithWildcards(remainingUndet)(tp.instantiateTypeParams(adjusted.okParams, adjusted.okArgs))
if(!matchesPt(tpSubst, wildPt, remainingUndet)) {
if (StatisticsStatics.areSomeColdStatsEnabled) statistics.incCounter(matchesPtInstMismatch2)
false
Expand Down Expand Up @@ -820,7 +820,8 @@ trait Implicits {
// filter out failures from type inference, don't want to remove them from undetParams!
// we must be conservative in leaving type params in undetparams
// prototype == WildcardType: want to remove all inferred Nothings
val AdjustedTypeArgs(okParams, okArgs) = adjustTypeArgs(undetParams, tvars, targs)
val adjusted = adjustTypeArgs(undetParams, tvars, targs)
import adjusted.{okParams, okArgs}

val subst: TreeTypeSubstituter =
if (okParams.isEmpty) EmptyTreeTypeSubstituter
Expand Down
91 changes: 36 additions & 55 deletions src/compiler/scala/tools/nsc/typechecker/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
package scala.tools.nsc
package typechecker

import scala.collection.{ mutable, immutable }
import scala.collection.mutable.ListBuffer
import scala.collection.{immutable, mutable}
import scala.util.control.ControlThrowable
import symtab.Flags._
import scala.reflect.internal.Depth
Expand Down Expand Up @@ -447,27 +448,35 @@ trait Infer extends Checkable {
* @return map from tparams to inferred arg, if inference was successful, tparams that map to None are considered left undetermined
* type parameters that are inferred as `scala.Nothing` and that are not covariant in `restpe` are taken to be undetermined
*/
def adjustTypeArgs(tparams: List[Symbol], tvars: List[TypeVar], targs: List[Type], restpe: Type = WildcardType): AdjustedTypeArgs.Result = {
val buf = AdjustedTypeArgs.Result.newBuilder[Symbol, Option[Type]]
def adjustTypeArgs(tparams: List[Symbol], tvars: List[TypeVar], targs: List[Type], restpe: Type = WildcardType): AdjustedTypeArgs = {
val okParams = ListBuffer[Symbol]()
val okArgs = ListBuffer[Type]()
val undetParams = ListBuffer[Symbol]()
val allArgs = ListBuffer[Type]()

foreach3(tparams, tvars, targs) { (tparam, tvar, targ) =>
val retract = (
targ.typeSymbol == NothingClass // only retract Nothings
&& (restpe.isWildcard || !varianceInType(restpe)(tparam).isPositive) // don't retract covariant occurrences
)

buf += ((tparam,
if (retract) None
else Some(
if (targ.typeSymbol == RepeatedParamClass) targ.baseType(SeqClass)
if (retract) {
undetParams += tparam
allArgs += NothingTpe
} else {
val arg =
if (targ.typeSymbol == RepeatedParamClass) targ.baseType(SeqClass)
else if (targ.typeSymbol == JavaRepeatedParamClass) targ.baseType(ArrayClass)
// this infers Foo.type instead of "object Foo" (see also widenIfNecessary)
else if (targ.typeSymbol.isModuleClass || tvar.constr.avoidWiden) targ
else targ.widen
)
))
okParams += tparam
okArgs += arg
allArgs += arg
}
}
buf.result()

new AdjustedTypeArgs(tparams, okParams.toList, okArgs.toList, undetParams.toList, allArgs.toList)
}

/** Return inferred type arguments, given type parameters, formal parameters,
Expand All @@ -487,7 +496,7 @@ trait Infer extends Checkable {
* @throws NoInstance
*/
def methTypeArgs(fn: Tree, tparams: List[Symbol], formals: List[Type], restpe: Type,
argtpes: List[Type], pt: Type): AdjustedTypeArgs.Result = {
argtpes: List[Type], pt: Type): AdjustedTypeArgs = {
val tvars = tparams map freshVar
if (!sameLength(formals, argtpes))
throw new NoInstance("parameter lists differ in length")
Expand Down Expand Up @@ -703,12 +712,13 @@ trait Infer extends Checkable {
)
def tryInstantiating(args: List[Type]) = falseIfNoInstance {
val restpe = mt resultType args
val AdjustedTypeArgs.Undets(okparams, okargs, leftUndet) = methTypeArgs(EmptyTree, undetparams, formals, restpe, args, pt)
val restpeInst = restpe.instantiateTypeParams(okparams, okargs)
val adjusted = methTypeArgs(EmptyTree, undetparams, formals, restpe, args, pt)
import adjusted.{okParams, okArgs, undetParams}
val restpeInst = restpe.instantiateTypeParams(okParams, okArgs)
// #2665: must use weak conformance, not regular one (follow the monomorphic case above)
exprTypeArgs(leftUndet, restpeInst, pt, useWeaklyCompatible = true) match {
exprTypeArgs(undetParams, restpeInst, pt, useWeaklyCompatible = true) match {
case null => false
case _ => isWithinBounds(NoPrefix, NoSymbol, okparams, okargs)
case _ => isWithinBounds(NoPrefix, NoSymbol, okParams, okArgs)
}
}
def typesCompatible(args: List[Type]) = undetparams match {
Expand Down Expand Up @@ -911,15 +921,16 @@ trait Infer extends Checkable {
substExpr(tree, tparams, targsStrict, pt)
List()
} else {
val AdjustedTypeArgs.Undets(okParams, okArgs, leftUndet) = adjustTypeArgs(tparams, tvars, targsStrict)
val adjusted = adjustTypeArgs(tparams, tvars, targsStrict)
import adjusted.{okParams, okArgs, undetParams}
def solved_s = map2(okParams, okArgs)((p, a) => s"$p=$a") mkString ","
def undet_s = leftUndet match {
def undet_s = undetParams match {
case Nil => ""
case ps => ps.mkString(", undet=", ",", "")
}
printTyping(tree, s"infer solved $solved_s$undet_s")
substExpr(tree, okParams, okArgs, pt)
leftUndet
undetParams
}
}

Expand Down Expand Up @@ -956,15 +967,15 @@ trait Infer extends Checkable {
val argtpes = tupleIfNecessary(formals, args map (x => elimAnonymousClass(x.tpe.deconst)))
val restpe = fn.tpe.resultType(argtpes)

val AdjustedTypeArgs.AllArgsAndUndets(okparams, okargs, allargs, leftUndet) =
methTypeArgs(fn, undetparams, formals, restpe, argtpes, pt)
val adjusted = methTypeArgs(fn, undetparams, formals, restpe, argtpes, pt)
import adjusted.{okParams, okArgs, allArgs, undetParams}

if (checkBounds(fn, NoPrefix, NoSymbol, undetparams, allargs, "inferred ")) {
val treeSubst = new TreeTypeSubstituter(okparams, okargs)
if (checkBounds(fn, NoPrefix, NoSymbol, undetparams, allArgs, "inferred ")) {
val treeSubst = new TreeTypeSubstituter(okParams, okArgs)
treeSubst traverseTrees fn :: args
notifyUndetparamsInferred(okparams, okargs)
notifyUndetparamsInferred(okParams, okArgs)

leftUndet match {
undetParams match {
case Nil => Nil
case xs =>
// #3890
Expand Down Expand Up @@ -1427,35 +1438,5 @@ trait Infer extends Checkable {
}
}

/** [Martin] Can someone comment this please? I have no idea what it's for
* and the code is not exactly readable.
*/
object AdjustedTypeArgs {
val Result = mutable.LinkedHashMap
type Result = mutable.LinkedHashMap[Symbol, Option[Type]]

def unapply(m: Result): Some[(List[Symbol], List[Type])] = Some(toLists(
(m collect {case (p, Some(a)) => (p, a)}).unzip ))

object Undets {
def unapply(m: Result): Some[(List[Symbol], List[Type], List[Symbol])] = Some(toLists{
val (ok, nok) = m.map{case (p, a) => (p, a.getOrElse(null))}.partition(_._2 ne null)
val (okArgs, okTparams) = ok.unzip
(okArgs, okTparams, nok.keys)
})
}

object AllArgsAndUndets {
def unapply(m: Result): Some[(List[Symbol], List[Type], List[Type], List[Symbol])] = Some(toLists{
val (ok, nok) = m.map{case (p, a) => (p, a.getOrElse(null))}.partition(_._2 ne null)
val (okArgs, okTparams) = ok.unzip
(okArgs, okTparams, m.values.map(_.getOrElse(NothingTpe)), nok.keys)
})
}

private def toLists[A1, A2](pxs: (Iterable[A1], Iterable[A2])) = (pxs._1.toList, pxs._2.toList)
private def toLists[A1, A2, A3](pxs: (Iterable[A1], Iterable[A2], Iterable[A3])) = (pxs._1.toList, pxs._2.toList, pxs._3.toList)
private def toLists[A1, A2, A3, A4](pxs: (Iterable[A1], Iterable[A2], Iterable[A3], Iterable[A4])) = (pxs._1.toList, pxs._2.toList, pxs._3.toList, pxs._4.toList)
}

case class AdjustedTypeArgs(tparams: List[Symbol], okParams: List[Symbol], okArgs: List[Type], undetParams: List[Symbol], allArgs: List[Type])
}

0 comments on commit cad96e2

Please sign in to comment.