Skip to content

Commit

Permalink
refactor box adaptation
Browse files Browse the repository at this point in the history
- special handling for the env created during box adaptation
- rewrite `adapt` to make it cleaner and easier to understand
Linyxus committed Oct 4, 2022
1 parent daf6766 commit 363f142
Showing 4 changed files with 106 additions and 109 deletions.
173 changes: 89 additions & 84 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
@@ -42,12 +42,20 @@ object CheckCaptures:
end Pre

/** A class describing environments.
* @param owner the current owner
* @param captured the caputure set containing all references to tracked free variables outside of boxes
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
* @param outer0 the next enclosing environment
* @param owner the current owner
* @param nestedInOwner true if the environment is a temporary one nested in the owner's environment,
* and does not have an actual owner symbol (this happens when doing box adaptation).
* @param captured the caputure set containing all references to tracked free variables outside of boxes
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
* @param outer0 the next enclosing environment
*/
case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer0: Env | Null):
case class Env(
owner: Symbol,
nestedInOwner: Boolean,
captured: CaptureSet,
isBoxed: Boolean,
outer0: Env | Null
):
def outer = outer0.nn

def isOutermost = outer0 == null
@@ -204,7 +212,7 @@ class CheckCaptures extends Recheck, SymTransformer:
report.error(i"$header included in allowed capture set ${res.blocking}", pos)

/** The current environment */
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, isBoxed = false, null)
private var curEnv: Env = Env(NoSymbol, false, CaptureSet.empty, isBoxed = false, null)

private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap()

@@ -249,8 +257,12 @@ class CheckCaptures extends Recheck, SymTransformer:
if !cs.isAlwaysEmpty then
forallOuterEnvsUpTo(ctx.owner.topLevelClass) { env =>
val included = cs.filter {
case ref: TermRef => env.owner.isProperlyContainedIn(ref.symbol.owner)
case ref: ThisType => env.owner.isProperlyContainedIn(ref.cls)
case ref: TermRef =>
(env.nestedInOwner || env.owner != ref.symbol.owner)
&& env.owner.isContainedIn(ref.symbol.owner)
case ref: ThisType =>
(env.nestedInOwner || env.owner != ref.cls)
&& env.owner.isContainedIn(ref.cls)
case _ => false
}
capt.println(i"Include call capture $included in ${env.owner}")
@@ -439,7 +451,7 @@ class CheckCaptures extends Recheck, SymTransformer:
if !Synthetics.isExcluded(sym) then
val saved = curEnv
val localSet = capturedVars(sym)
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, isBoxed = false, curEnv)
if !localSet.isAlwaysEmpty then curEnv = Env(sym, false, localSet, isBoxed = false, curEnv)
try super.recheckDefDef(tree, sym)
finally
interpolateVarsIn(tree.tpt)
@@ -455,7 +467,7 @@ class CheckCaptures extends Recheck, SymTransformer:
val localSet = capturedVars(cls)
for parent <- impl.parents do // (1)
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, isBoxed = false, curEnv)
if !localSet.isAlwaysEmpty then curEnv = Env(cls, false, localSet, isBoxed = false, curEnv)
try
val thisSet = cls.classInfo.selfType.captureSet.withDescription(i"of the self type of $cls")
checkSubset(localSet, thisSet, tree.srcPos) // (2)
@@ -502,7 +514,7 @@ class CheckCaptures extends Recheck, SymTransformer:
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
if tree.isTerm && pt.isBoxedCapturing then
val saved = curEnv
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = true, curEnv)
curEnv = Env(curEnv.owner, false, CaptureSet.Var(), isBoxed = true, curEnv)
try super.recheck(tree, pt)
finally curEnv = saved
else
@@ -595,12 +607,11 @@ class CheckCaptures extends Recheck, SymTransformer:
* to `expected` type.
* @param reconstruct how to rebuild the adapted function type
*/
def adaptFun(actualTp: (Type, CaptureSet), aargs: List[Type], ares: Type, expected: Type,
def adaptFun(actual: Type, aargs: List[Type], ares: Type, expected: Type,
covariant: Boolean, boxed: Boolean,
reconstruct: (List[Type], Type) => Type): (Type, CaptureSet) =
val (actual, cs0) = actualTp
val saved = curEnv
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
curEnv = Env(curEnv.owner, true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)

try
val (eargs, eres) = expected.dealias match
@@ -618,17 +629,16 @@ class CheckCaptures extends Recheck, SymTransformer:
else reconstruct(aargs1, ares1)

curEnv.captured.asVar.markSolved()
(resTp, curEnv.captured ++ cs0)
(resTp, curEnv.captured)
finally
curEnv = saved

def adaptTypeFun(
actualTp: (Type, CaptureSet), ares: Type, expected: Type,
actual: Type, ares: Type, expected: Type,
covariant: Boolean, boxed: Boolean,
reconstruct: Type => Type): (Type, CaptureSet) =
val (actual, cs0) = actualTp
val saved = curEnv
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
curEnv = Env(curEnv.owner, true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)

try
val eres = expected.dealias.stripCapturing match
@@ -642,7 +652,7 @@ class CheckCaptures extends Recheck, SymTransformer:
else reconstruct(ares1)

curEnv.captured.asVar.markSolved()
(resTp, curEnv.captured ++ cs0)
(resTp, curEnv.captured)
finally
curEnv = saved
end adaptTypeFun
@@ -651,8 +661,8 @@ class CheckCaptures extends Recheck, SymTransformer:
val arrow = if covariant then "~~>" else "<~~"
i"adapting $actual $arrow $expected"

def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
def destructCapturingType(tp: Type, reconstruct: Type => Type): ((Type, CaptureSet, Boolean), Type => Type) = tp.dealias match
def destructCapturingType(tp: Type, reconstruct: Type => Type = x => x): ((Type, CaptureSet, Boolean), Type => Type) =
tp.dealias match
case tp @ CapturingType(parent, cs) =>
if parent.dealias.isCapturingType then
destructCapturingType(parent, res => reconstruct(tp.derivedCapturingType(res, cs)))
@@ -661,72 +671,67 @@ class CheckCaptures extends Recheck, SymTransformer:
case actual =>
((actual, CaptureSet(), false), reconstruct)

if expected.isInstanceOf[WildcardType] then
actual
def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
if expected.isInstanceOf[WildcardType] then actual
else
val (actualTp, recon) = destructCapturingType(actual, x => x)
val (parent1, cs1, isBoxed1) = adaptCapturingType(actualTp, expected, covariant)
recon(CapturingType(parent1, cs1, isBoxed1))
}

def adaptCapturingType(
actual: (Type, CaptureSet, Boolean),
expected: Type,
covariant: Boolean
): (Type, CaptureSet, Boolean) =
val (parent, cs, actualIsBoxed) = actual

val needsAdaptation = actualIsBoxed != expected.isBoxedCapturing
val insertBox = needsAdaptation && covariant != actualIsBoxed

val (parent1, cs1) = parent match {
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun((parent, cs), args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun((parent, cs), rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
.toFunctionType(isJava = false, alwaysDependent = true))
case actual: MethodType =>
adaptFun((parent, cs), actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
adaptTypeFun((parent, cs), rinfo.resType, expected, covariant, insertBox,
ares1 =>
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
actual1
)
case _ =>
(parent, cs)
}
val ((parent, cs, actualIsBoxed), recon) = destructCapturingType(actual)

val needsAdaptation = actualIsBoxed != expected.isBoxedCapturing
val insertBox = needsAdaptation && covariant != actualIsBoxed

val (parent1, cs1) = parent match {
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
val (parent1, cs1) = adaptFun(parent, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
(parent1, cs1 ++ cs)
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
val (parent1, cs1) = adaptFun(parent, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
.toFunctionType(isJava = false, alwaysDependent = true))
(parent1, cs1 ++ cs)
case actual: MethodType =>
val (parent1, cs1) = adaptFun(parent, actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
(parent1, cs1 ++ cs)
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
val (parent1, cs1) = adaptTypeFun(parent, rinfo.resType, expected, covariant, insertBox,
ares1 =>
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
actual1
)
(parent1, cs1 ++ cs)
case _ =>
(parent, cs)
}

if needsAdaptation then
val criticalSet = // the set which is not allowed to have `*`
if covariant then cs1 // can't box with `*`
else expected.captureSet // can't unbox with `*`
if criticalSet.isUniversal then
// We can't box/unbox the universal capability. Leave `actual` as it is
// so we get an error in checkConforms. This tends to give better error
// messages than disallowing the root capability in `criticalSet`.
capt.println(i"cannot box/unbox $cs $parent vs $expected")
actual
if needsAdaptation then
val criticalSet = // the set which is not allowed to have `*`
if covariant then cs1 // can't box with `*`
else expected.captureSet // can't unbox with `*`
if criticalSet.isUniversal then
// We can't box/unbox the universal capability. Leave `actual` as it is
// so we get an error in checkConforms. This tends to give better error
// messages than disallowing the root capability in `criticalSet`.
capt.println(i"cannot box/unbox $actual vs $expected")
actual
else
// Disallow future addition of `*` to `criticalSet`.
criticalSet.disallowRootCapability { () =>
report.error(
em"""$actual cannot be box-converted to $expected
|since one of their capture sets contains the root capability `*`""",
pos)
}
if !insertBox then // unboxing
markFree(criticalSet, pos)
recon(CapturingType(parent1, cs1, !actualIsBoxed))
else
// Disallow future addition of `*` to `criticalSet`.
criticalSet.disallowRootCapability { () =>
report.error(
em"""$actualIsBoxed $cs $parent cannot be box-converted to $expected
|since one of their capture sets contains the root capability `*`""",
pos)
}
if !insertBox then // unboxing
markFree(cs1, pos)
(parent1, cs1, !actualIsBoxed)
else
(parent1, cs1, actualIsBoxed)
recon(CapturingType(parent1, cs1, actualIsBoxed))
}


var actualw = actual.widenDealias
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/capt1.check
Original file line number Diff line number Diff line change
@@ -40,14 +40,14 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:32:24 ----------------------------------------
32 | val z2 = h[() -> Cap](() => x) // error
| ^^^^^^^
| Found: {x} () -> {*} C
| Found: {x} () -> Cap
| Required: () -> box {*} C
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:33:5 -----------------------------------------
33 | (() => C()) // error
| ^^^^^^^^^
| Found: ? () -> {*} C
| Found: ? () -> Cap
| Required: () -> box {*} C
|
| longer explanation available when compiling with `-explain`
28 changes: 10 additions & 18 deletions tests/neg-custom-args/captures/i15772.check
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:18:2 ----------------------------------------
18 | () => // error
| ^
| Found: {x} () -> Int
| Required: () -> Int
19 | val c : {x} C = new C(x)
20 | val boxed1 : (({*} C) => Unit) -> Unit = box1(c)
21 | boxed1((cap: {*} C) => unsafe(c))
22 | 0
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:20:49 ---------------------------------------
20 | val boxed1 : (({*} C) => Unit) -> Unit = box1(c) // error
| ^^^^^^^
| Found: {c} ({*} ({c} C{arg: {*} C}) -> Unit) -> Unit
| Required: (({*} C) => Unit) -> Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:25:2 ----------------------------------------
25 | () => // error
| ^
| Found: {x} () -> Int
| Required: () -> Int
26 | val c : {x} C = new C(x)
27 | val boxed2 : Observe[{*} C] = box2(c)
28 | boxed2((cap: {*} C) => unsafe(c))
29 | 0
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:27:38 ---------------------------------------
27 | val boxed2 : Observe[{*} C] = box2(c) // error
| ^^^^^^^
| Found: {c} ({*} ({c} C{arg: {*} C}) -> Unit) -> Unit
| Required: Observe[{*} C]
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:33:37 ---------------------------------------
10 changes: 5 additions & 5 deletions tests/neg-custom-args/captures/i15772.scala
Original file line number Diff line number Diff line change
@@ -15,16 +15,16 @@ class C(val arg: {*} C) {
}

def main1(x: {*} C) : () -> Int =
() => // error
() =>
val c : {x} C = new C(x)
val boxed1 : (({*} C) => Unit) -> Unit = box1(c)
val boxed1 : (({*} C) => Unit) -> Unit = box1(c) // error
boxed1((cap: {*} C) => unsafe(c))
0

def main2(x: {*} C) : () -> Int =
() => // error
() =>
val c : {x} C = new C(x)
val boxed2 : Observe[{*} C] = box2(c)
val boxed2 : Observe[{*} C] = box2(c) // error
boxed2((cap: {*} C) => unsafe(c))
0

@@ -41,4 +41,4 @@ def main(io: {*} Any) =
val sayHello: (({io} File) => Unit) = (file: {io} File) => file.write("Hello World!\r\n")
val filesList : List[{io} File] = ???
val x = () => filesList.foreach(sayHello)
x: (() -> Unit) // error
x: (() -> Unit) // error

0 comments on commit 363f142

Please sign in to comment.