Skip to content

Commit

Permalink
Add terminated info
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Dec 6, 2024
1 parent d44147b commit f859afe
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 34 deletions.
43 changes: 28 additions & 15 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,37 +53,45 @@ object Nullables:
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null
* after the tree finishes executing normally (non-exceptionally),
* after the tree finishes executing normally (non-exceptionally),
* plus a set of variable references that are ever assigned to null,
* and may therefore be null if execution of the tree is interrupted
* by an exception.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)

def terminatedInfo = NotNullInfo(null, retracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.diff(that.retracted).union(that.asserted),
this.retracted.union(that.retracted))
else
val newAsserted =
if this.asserted == null || that.asserted == null then null
else this.asserted.diff(that.retracted).union(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
* the nullability info of the branches of an if or match.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))

def withRetracted(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted, this.retracted.union(that.retracted))
val newAsserted =
if this.asserted == null then that.asserted
else if that.asserted == null then this.asserted
else this.asserted.intersect(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)
end NotNullInfo

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && retracted.isEmpty then empty
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
end NotNullInfo

Expand Down Expand Up @@ -227,7 +235,7 @@ object Nullables:
*/
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
case info :: infos1 =>
if info.asserted.contains(ref) then true
if info.asserted != null && info.asserted.contains(ref) then true
else if info.retracted.contains(ref) then false
else infos1.impliesNotNull(ref)
case _ =>
Expand All @@ -243,7 +251,9 @@ object Nullables:
/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(info.asserted.filter(_.symbol.is(Mutable)))
(ms, info) => ms.union(
if info.asserted == null then Set.empty
else info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down Expand Up @@ -516,7 +526,10 @@ object Nullables:
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
&& ctx.notNullInfos.impliesNotNull(ref)

val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
val retractedVars = ctx.notNullInfos.flatMap(info =>
if info.asserted == null then Set.empty
else info.asserted.filter(isRetracted)
).toSet
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
end whileContext

Expand Down
27 changes: 8 additions & 19 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1551,13 +1551,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isNothingType then
elsePathInfo.withRetracted(thenPathInfo)
else if result.elsep.tpe.isNothingType then
thenPathInfo.withRetracted(elsePathInfo)
else thenPathInfo.alt(elsePathInfo)
)
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
end typedIf

/** Decompose function prototype into a list of parameter prototypes and a result
Expand Down Expand Up @@ -2154,14 +2148,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}

private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
var nnInfo = initInfo
if cases.nonEmpty then
val (nothingCases, normalCases) = cases.partition(_.body.tpe.isNothingType)
nnInfo = nothingCases.foldLeft(nnInfo):
(nni, c) => nni.withRetracted(c.notNullInfo)
if normalCases.nonEmpty then
nnInfo = nnInfo.seq(normalCases.map(_.notNullInfo).reduce(_.alt(_)))
nnInfo
initInfo.seq(cases.map(_.notNullInfo).reduce(_.alt(_)))
else initInfo

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
var caseCtx = ctx
Expand Down Expand Up @@ -2251,7 +2240,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
val expr1 = typed(tree.expr, bind1.symbol.info)
assignType(cpy.Labeled(tree)(bind1, expr1))
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
}

/** Type a case of a type match */
Expand Down Expand Up @@ -2301,7 +2290,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// Hence no adaptation is possible, and we assume WildcardType as prototype.
(from, proto)
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
assignType(cpy.Return(tree)(expr1, from))
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
end typedReturn

def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
Expand Down Expand Up @@ -2388,15 +2377,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedThrow(tree: untpd.Throw)(using Context): Tree =
val expr1 = typed(tree.expr, defn.ThrowableType)
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
val res = Throw(expr1).withSpan(tree.span)
var res = Throw(expr1).withSpan(tree.span)
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
Typed(res,
res = Typed(res,
TypeTree(
AnnotatedType(res.tpe,
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
else res
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)

def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
val elemProto = pt.stripNull().elemType match {
Expand Down

0 comments on commit f859afe

Please sign in to comment.