Skip to content

Commit

Permalink
Merge pull request scala#15544 from dwijnand/gadt/unsound-cast
Browse files Browse the repository at this point in the history
Use GADT constraints in maximiseType
  • Loading branch information
abgruszecki authored Jul 12, 2022
2 parents 6efd92d + 32826d8 commit 2ea400a
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 32 deletions.
42 changes: 19 additions & 23 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ sealed abstract class GadtConstraint extends Showable {
/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type

def symbols: List[Symbol]

def fresh: GadtConstraint

/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
Expand Down Expand Up @@ -193,12 +195,7 @@ final class ProperGadtConstraint private(
case null => null
// TODO: Improve flow typing so that ascription becomes redundant
case tv: TypeVar =>
def retrieveBounds: TypeBounds =
bounds(tv.origin) match {
case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) =>
TypeAlias(reverseMapping(tpr).nn.typeRef)
case tb => tb
}
def retrieveBounds: TypeBounds = externalize(bounds(tv.origin)).bounds
retrieveBounds
//.showing(i"gadt bounds $sym: $result", gadts)
//.ensuring(containsNoInternalTypes(_))
Expand All @@ -222,6 +219,8 @@ final class ProperGadtConstraint private(
res
}

override def symbols: List[Symbol] = mapping.keys

override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
mapping,
Expand All @@ -247,13 +246,7 @@ final class ProperGadtConstraint private(
override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = TypeComparer.isSameType(tp1, tp2)

override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds =
val externalizeMap = new TypeMap {
def apply(tp: Type): Type = tp match {
case tpr: TypeParamRef => externalize(tpr)
case tp => mapOver(tp)
}
}
externalizeMap(constraint.nonParamBounds(param)).bounds
externalize(constraint.nonParamBounds(param)).bounds

override def fullLowerBound(param: TypeParamRef)(using Context): Type =
constraint.minLower(param).foldLeft(nonParamBounds(param).lo) {
Expand All @@ -270,27 +263,28 @@ final class ProperGadtConstraint private(

// ---- Private ----------------------------------------------------------

private def externalize(param: TypeParamRef)(using Context): Type =
reverseMapping(param) match {
private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match
case param: TypeParamRef => reverseMapping(param) match
case sym: Symbol => sym.typeRef
case null => param
}
case null => param
case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap))
case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp)

private class ExternalizeMap(using Context) extends TypeMap:
def apply(tp: Type): Type = externalize(tp, this)(using mapCtx)

private def tvarOrError(sym: Symbol)(using Context): TypeVar =
mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN

private def containsNoInternalTypes(
tp: Type,
acc: TypeAccumulator[Boolean] | Null = null
)(using Context): Boolean = tp match {
private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match {
case tpr: TypeParamRef => !reverseMapping.contains(tpr)
case tv: TypeVar => !reverseMapping.contains(tv.origin)
case tp =>
(if (acc != null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
(if (theAcc != null) theAcc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
}

private class ContainsNoInternalTypesAccumulator(using Context) extends TypeAccumulator[Boolean] {
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp)
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp, this)
}

// ---- Debug ------------------------------------------------------------
Expand Down Expand Up @@ -325,6 +319,8 @@ final class ProperGadtConstraint private(

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ trait Applications extends Compatibility {
// Constraining only fails if the pattern cannot possibly match,
// but useless pattern checks detect more such cases, so we simply rely on them instead.
withMode(Mode.GadtConstraintInference)(TypeComparer.constrainPatternType(unapplyArgType, selType))
val patternBound = maximizeType(unapplyArgType, tree.span)
val patternBound = maximizeType(unapplyArgType, unapplyFn.span.endPos)
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
unapplyArgType
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import collection.mutable

import scala.annotation.internal.sharable

import config.Printers.gadts

object Inferencing {

import tpd._
Expand Down Expand Up @@ -408,10 +406,15 @@ object Inferencing {
Stats.record("maximizeType")
val vs = variances(tp)
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
vs foreachBinding { (tvar, v) =>
if !tvar.isInstantiated then
if (v == 1) tvar.instantiate(fromBelow = false)
else if (v == -1) tvar.instantiate(fromBelow = true)
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
// Eg neg/i14983 the C in Node[+C] occurs in GADT bound X >: List[C] so maximising to Node[Any] is unsound
// Eg pos/precise-pattern-type the T in Tree[-T] doesn't occur in any GADT bound so can maximise to Tree[Type]
val safeToInstantiate = v != 0 && gadtBounds.forall(!tvar.occursIn(_))
if safeToInstantiate then tvar.instantiate(fromBelow = v == -1)
else {
val bounds = TypeComparer.fullBounds(tvar.origin)
if bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) then
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3764,9 +3764,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
res
} =>
// Insert an explicit cast, so that -Ycheck in later phases succeeds.
// I suspect, but am not 100% sure that this might affect inferred types,
// if the expected type is a supertype of the GADT bound. It would be good to come
// up with a test case for this.
// The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
val target =
if tree.tpe.isSingleton then
val conj = AndType(tree.tpe, pt)
Expand Down
23 changes: 23 additions & 0 deletions tests/neg/i14983.co-contra.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
case class Showing[-C](show: C => String)

sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[-C](l: Showing[C]) extends Tree[Showing[C]]

object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v
case Node(x) =>
// tree: Tree[X] vs Node[C] aka Tree[Showing[C]]
// PTC: X >: Showing[C]
// max: Node[C] to Node[Nothing], instantiating C := Nothing, which makes X >: Showing[Nothing]
// adapt: Showing[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
Showing[String](_ + " boom!") // error: Found: Showing[String] Required: X where: X is a type in method meth with bounds >: Showing[C$1]
// after fix:
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: Showing[C$1]
// adapt: Showing[String] <: X = Fail, because String !<: C$1

def main(args: Array[String]): Unit =
val tree = Node(Showing[Int](_.toString))
val res = meth(tree)
println(res.show(42)) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
15 changes: 15 additions & 0 deletions tests/neg/i14983.contra.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
sealed trait Show[-A]
final case class Pure[-B](showB: B => String) extends Show[B]
final case class Many[-C](showL: List[C] => String) extends Show[List[C]]

object Test:
def meth[X](show: Show[X]): X => String = show match
case Pure(showB) => showB
case Many(showL) =>
val res = (xs: List[String]) => xs.head.length.toString
res // error: Found: List[String] => String Required: X => String where: X is a type in method meth with bounds <: List[C$1]

def main(args: Array[String]): Unit =
val show = Many((is: List[Int]) => (is.head + 1).toString)
val fn = meth(show)
assert(fn(List(42)) == "43") // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
22 changes: 22 additions & 0 deletions tests/neg/i14983.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[+C](l: List[C]) extends Tree[List[C]]

// The original test case, minimised.
object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X
case Node(x) =>
// tree: Tree[X] vs Node[C] aka Tree[List[C]]
// PTC: X >: List[C]
// max: Node[C] => Node[Any], instantiating C := Any, which makes X >: List[Any]
// adapt: List[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
List("boom") // error: Found: List[String] Required: X where: X is a type in method meth with bounds >: List[C$1]
// after fix:
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: List[C$1]
// adapt: List[String] <: X = Fail, because String !<: C$1

def main(args: Array[String]): Unit =
val tree = Node(List(42))
val res = meth(tree)
assert(res.head == 42) // was: ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
14 changes: 14 additions & 0 deletions tests/run/i14983.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[+C](l: List[C]) extends Tree[List[C]]

// A version of the original test case that is sound so should typecheck.
object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X <:< X
case Node(x) => x // ok: Tree[X] vs Node[C], PTC: X >: List[C], max: Node[C] => Node[C$1], x: C$1 <:< X, w/ GADT cast

def main(args: Array[String]): Unit =
val tree = Node(List(42))
val res = meth(tree)
assert(res.head == 42) // ok

0 comments on commit 2ea400a

Please sign in to comment.