Skip to content

Commit

Permalink
Submerge the Variance extractor function into the solve functions.
Browse files Browse the repository at this point in the history
In the `solvedTypes` and the `solve` functions, there is a third
parameter to give the specific variances, in the context of the
resolution, of each parameter which goes in the second parameter.
In fact, this `variances` list is always a `map` of a function,
which is different in each call, on the second list of symbols.

We replace the third parameter from being the list of variances to
being the function that is used to get that third list, and
thus merge the application of that list in each step of the foreach.

This has these benefits:

- We avoid allocating the list of variances, particularly for the case
  in which we are just using a constant function to Invariant,
  before the call.
- Since the only relevant information is whether or not a type
  parameter is contravariant, which is one bit (boolean), we
  use a BitSet to store that information.
- To use a BitSet, we need indices, so in the solve we replace the use
  of map and foreach, by the utility foreachWithIndex.
- By using a Variance.Extractor instead of a Function1, as required by
  the List.map function, we can avoid allocations of Variance objects,
  and use instead the underlying integer value.

There could be a small performance prejudice: the double-nested loop
of the solve method could compute the variances up to N times.
  • Loading branch information
diesalbla authored and retronym committed Mar 21, 2019
1 parent 6990ee3 commit 78cf068
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ trait Validators {
checkMacroImplResultTypeMismatch(atpeToRtpe(aret), rret)

val maxLubDepth = lubDepth(aparamss.flatten map (_.tpe)) max lubDepth(rparamss.flatten map (_.tpe))
val atargs = solvedTypes(atvars, atparams, atparams map varianceInType(aret), upper = false, maxLubDepth)
val atargs = solvedTypes(atvars, atparams, varianceInType(aret), upper = false, maxLubDepth)
val boundsOk = typer.silent(_.infer.checkBounds(macroDdef, NoPrefix, NoSymbol, atparams, atargs, ""))
boundsOk match {
case SilentResultValue(true) => // do nothing, success
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/scala/tools/nsc/typechecker/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ trait Implicits {
if (StatisticsStatics.areSomeColdStatsEnabled) statistics.incCounter(matchesPtInstMismatch1)
false
} else {
val targs = solvedTypes(tvars, allUndetparams, allUndetparams map varianceInType(wildPt), upper = false, lubDepth(tpInstantiated :: wildPt :: Nil))
val targs = solvedTypes(tvars, allUndetparams, varianceInType(wildPt), upper = false, lubDepth(tpInstantiated :: wildPt :: Nil))
val adjusted = adjustTypeArgs(allUndetparams, tvars, targs)
val tpSubst = deriveTypeWithWildcards(adjusted.undetParams)(tp.instantiateTypeParams(adjusted.okParams, adjusted.okArgs))
if(!matchesPt(tpSubst, wildPt, adjusted.undetParams)) {
Expand Down Expand Up @@ -796,7 +796,7 @@ trait Implicits {
if (tvars.nonEmpty)
typingLog("solve", ptLine("tvars" -> tvars, "tvars.constr" -> tvars.map(_.constr)))

val targs = solvedTypes(tvars, undetParams, undetParams map varianceInType(pt), upper = false, lubDepth(itree3.tpe :: pt :: Nil))
val targs = solvedTypes(tvars, undetParams, varianceInType(pt), upper = false, lubDepth(itree3.tpe :: pt :: Nil))

// #2421: check that we correctly instantiated type parameters outside of the implicit tree:
checkBounds(itree3, NoPrefix, NoSymbol, undetParams, targs, "inferred ")
Expand Down
19 changes: 9 additions & 10 deletions src/compiler/scala/tools/nsc/typechecker/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,20 @@ trait Infer extends Checkable {
*
* @param tvars All type variables to be instantiated.
* @param tparams The type parameters corresponding to `tvars`
* @param variances The variances of type parameters; need to reverse
* @param getVariance Function to extract variances of type parameters; we need to reverse
* solution direction for all contravariant variables.
* @param upper When `true` search for max solution else min.
* @throws NoInstance
*/
def solvedTypes(tvars: List[TypeVar], tparams: List[Symbol], variances: List[Variance], upper: Boolean, depth: Depth): List[Type] = {
def solvedTypes(tvars: List[TypeVar], tparams: List[Symbol], getVariance: Variance.Extractor[Symbol], upper: Boolean, depth: Depth): List[Type] = {
if (tvars.isEmpty) Nil else {
printTyping("solving for " + parentheses(map2(tparams, tvars)((p, tv) => s"${p.name}: $tv")))
// !!! What should be done with the return value of "solve", which is at present ignored?
// The historical commentary says "no panic, it's good enough to just guess a solution,
// we'll find out later whether it works", meaning don't issue an error here when types
// don't conform to bounds. That means you can never trust the results of implicit search.
// For an example where this was not being heeded, scala/bug#2421.
solve(tvars, tparams, variances, upper, depth)
solve(tvars, tparams, getVariance, upper, depth)
tvars map instantiate
}
}
Expand Down Expand Up @@ -377,7 +377,7 @@ trait Infer extends Checkable {
case mt: MethodType if mt.isImplicit && isFullyDefined(pt) => MethodType(mt.params, AnyTpe)
case _ => restpe
}
def solve() = solvedTypes(tvars, tparams, tparams map varianceInType(variance), upper = false, lubDepth(restpe :: pt :: Nil))
def solve() = solvedTypes(tvars, tparams, varianceInType(variance), upper = false, lubDepth(restpe :: pt :: Nil))

if (conforms)
try solve() catch { case _: NoInstance => null }
Expand Down Expand Up @@ -535,7 +535,7 @@ trait Infer extends Checkable {
"argument expression's type is not compatible with formal parameter type" + foundReqMsg(tp1, pt1))
}
}
val targs = solvedTypes(tvars, tparams, tparams map varianceInTypes(formals), upper = false, lubDepth(formals) max lubDepth(argtpes))
val targs = solvedTypes(tvars, tparams, varianceInTypes(formals), upper = false, lubDepth(formals) max lubDepth(argtpes))
// Can warn about inferring Any/AnyVal as long as they don't appear
// explicitly anywhere amongst the formal, argument, result, or expected type.
// ...or lower bound of a type param, since they're asking for it.
Expand Down Expand Up @@ -1016,13 +1016,12 @@ trait Infer extends Checkable {
try {
// debuglog("TVARS "+ (tvars map (_.constr)))
// look at the argument types of the primary constructor corresponding to the pattern
val variances =
if (ctorTp.paramTypes.isEmpty) undetparams map varianceInType(ctorTp)
else undetparams map varianceInTypes(ctorTp.paramTypes)
val varianceFun: Variance.Extractor[Symbol] =
if (ctorTp.paramTypes.isEmpty) varianceInType(ctorTp) else varianceInTypes(ctorTp.paramTypes)

// Note: this is the only place where solvedTypes (or, indirectly, solve) is called
// with upper = true.
val targs = solvedTypes(tvars, undetparams, variances, upper = true, lubDepth(resTp :: pt :: Nil))
val targs = solvedTypes(tvars, undetparams, varianceFun, upper = true, lubDepth(resTp :: pt :: Nil))
// checkBounds(tree, NoPrefix, NoSymbol, undetparams, targs, "inferred ")
// no checkBounds here. If we enable it, test bug602 fails.
// TODO: reinstate checkBounds, return params that fail to meet their bounds to undetparams
Expand Down Expand Up @@ -1091,7 +1090,7 @@ trait Infer extends Checkable {
val tvars1 = tvars map (_.cloneInternal)
// Note: right now it's not clear that solving is complete, or how it can be made complete!
// So we should come back to this and investigate.
solve(tvars1, tvars1 map (_.origin.typeSymbol), tvars1 map (_ => Variance.Covariant), upper = false, Depth.AnyDepth)
solve(tvars1, tvars1.map(_.origin.typeSymbol), (_ => Variance.Covariant), upper = false, Depth.AnyDepth)
}

// this is quite nasty: it destructively changes the info of the syms of e.g., method type params
Expand Down
4 changes: 1 addition & 3 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2930,10 +2930,8 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
// use function type subtyping, not method type subtyping (the latter is invariant in argument types)
fun.tpe <:< functionType(samInfoWithTVars.paramTypes, samInfoWithTVars.finalResultType)

val variances = tparams map varianceInType(sam.info)

// solve constraints tracked by tvars
val targs = solvedTypes(tvars, tparams, variances, upper = false, lubDepth(sam.info :: Nil))
val targs = solvedTypes(tvars, tparams, varianceInType(sam.info), upper = false, lubDepth(sam.info :: Nil))

debuglog(s"sam infer: $pt --> ${appliedType(samTyCon, targs)} by ${fun.tpe} <:< $samInfoWithTVars --> $targs for $tparams")

Expand Down
2 changes: 1 addition & 1 deletion src/reflect/scala/reflect/internal/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2871,7 +2871,7 @@ trait Types
val tvars = quantifiedFresh map (tparam => TypeVar(tparam))
val underlying1 = underlying.instantiateTypeParams(quantified, tvars) // fuse subst quantified -> quantifiedFresh -> tvars
op(underlying1) && {
solve(tvars, quantifiedFresh, quantifiedFresh map (_ => Invariant), upper = false, depth) &&
solve(tvars, quantifiedFresh, (_ => Invariant), upper = false, depth) &&
isWithinBounds(NoPrefix, NoSymbol, quantifiedFresh, tvars map (_.inst))
}
}
Expand Down
33 changes: 21 additions & 12 deletions src/reflect/scala/reflect/internal/tpe/TypeConstraints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package tpe

import scala.collection.{ generic }
import generic.Clearable
import scala.collection.mutable.BitSet

private[internal] trait TypeConstraints {
self: SymbolTable =>
Expand Down Expand Up @@ -195,32 +196,40 @@ private[internal] trait TypeConstraints {

/** Solve constraint collected in types `tvars`.
*
* @param tvars All type variables to be instantiated.
* @param tparams The type parameters corresponding to `tvars`
* @param variances The variances of type parameters; need to reverse
* @param tvars All type variables to be instantiated.
* @param tparams The type parameters corresponding to `tvars`
* @param getVariance Function to extract variances of type parameters; we need to reverse
* solution direction for all contravariant variables.
* @param upper When `true` search for max solution else min.
* @param upper When `true` search for max solution else min.
*/
def solve(tvars: List[TypeVar], tparams: List[Symbol], variances: List[Variance], upper: Boolean, depth: Depth): Boolean = {
def solve(tvars: List[TypeVar], tparams: List[Symbol], getVariance: Variance.Extractor[Symbol], upper: Boolean, depth: Depth): Boolean = {
assert(tvars.corresponds(tparams)((tvar, tparam) => tvar.origin.typeSymbol eq tparam), (tparams, tvars.map(_.origin.typeSymbol)))
val areContravariant: BitSet = BitSet.empty
foreachWithIndex(tparams){(tparam, ix) =>
if (getVariance(tparam).isContravariant) areContravariant += ix
}

def solveOne(tvar: TypeVar, tparam: Symbol, variance: Variance) {
def solveOne(tvar: TypeVar, ix: Int): Unit = {
val tparam = tvar.origin.typeSymbol
val isContravariant = areContravariant(ix)
if (tvar.constr.inst == NoType) {
val up = if (variance.isContravariant) !upper else upper
val up = if (isContravariant) !upper else upper
tvar.constr.inst = null
val bound: Type = if (up) tparam.info.upperBound else tparam.info.lowerBound
//Console.println("solveOne0(tv, tp, v, b)="+(tvar, tparam, variance, bound))
var cyclic = bound contains tparam
foreach3(tvars, tparams, variances)((tvar2, tparam2, variance2) => {
foreachWithIndex(tvars){ (tvar2, jx) =>
val tparam2 = tvar2.origin.typeSymbol
val ok = (tparam2 != tparam) && (
(bound contains tparam2)
|| up && (tparam2.info.lowerBound =:= tparam.tpeHK)
|| !up && (tparam2.info.upperBound =:= tparam.tpeHK)
)
if (ok) {
if (tvar2.constr.inst eq null) cyclic = true
solveOne(tvar2, tparam2, variance2)
solveOne(tvar2, jx)
}
})
}
if (!cyclic) {
if (up) {
if (bound.typeSymbol != AnyClass) {
Expand Down Expand Up @@ -260,7 +269,7 @@ private[internal] trait TypeConstraints {
if (depth.isAnyDepth) lub(tvar.constr.loBounds)
else lub(tvar.constr.loBounds, depth)
}
)
)

debuglog(s"$tvar setInst $newInst")
tvar setInst newInst
Expand All @@ -269,7 +278,7 @@ private[internal] trait TypeConstraints {
}

// println("solving "+tvars+"/"+tparams+"/"+(tparams map (_.info)))
foreach3(tvars, tparams, variances)(solveOne)
foreachWithIndex(tvars)(solveOne)

def logBounds(tv: TypeVar) = log {
val what = if (!tv.instValid) "is invalid" else s"does not conform to bounds: ${tv.constr}"
Expand Down
2 changes: 1 addition & 1 deletion src/reflect/scala/reflect/internal/util/Collections.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ trait Collections {
xss.isEmpty || xss.head.isEmpty && flattensToEmpty(xss.tail)
}

final def foreachWithIndex[A, B](xs: List[A])(f: (A, Int) => Unit) {
final def foreachWithIndex[A](xs: List[A])(f: (A, Int) => Unit) {
var index = 0
var ys = xs
while (!ys.isEmpty) {
Expand Down

0 comments on commit 78cf068

Please sign in to comment.