Skip to content

Commit

Permalink
Refactor NotNullInfo to record every reference which is retracted once.
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Dec 6, 2024
1 parent c61897d commit 585dda9
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 15 deletions.
32 changes: 23 additions & 9 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,35 +52,49 @@ object Nullables:
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null, plus a set of
* variable references that are not known (anymore) to be not null
/** A set of val or var references that are known to be not null,
* a set of variable references that are not known (anymore) to be not null,
* plus a set of variables that are known to be not null at any point.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef], onceRetracted: Set[TermRef]):
assert((asserted & retracted).isEmpty)
assert(retracted.subsetOf(onceRetracted))

def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)
def retractedInfo = NotNullInfo(Set(), retracted, onceRetracted)

def onceRetractedInfo = NotNullInfo(Set(), onceRetracted, onceRetracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.union(that.asserted).diff(that.retracted),
this.retracted.union(that.retracted).diff(that.asserted))
this.retracted.union(that.retracted).diff(that.asserted),
this.onceRetracted.union(that.onceRetracted))

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))
NotNullInfo(
this.asserted.intersect(that.asserted),
this.retracted.union(that.retracted),
this.onceRetracted.union(that.onceRetracted))

def withOnceRetracted(that: NotNullInfo): NotNullInfo =
if that.isEmpty then this
else NotNullInfo(this.asserted, this.retracted, this.onceRetracted.union(that.onceRetracted))

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
val empty = new NotNullInfo(Set(), Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
apply(asserted, retracted, retracted)
def apply(asserted: Set[TermRef], retracted: Set[TermRef], onceRetracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && onceRetracted.isEmpty then empty
else new NotNullInfo(asserted, retracted, onceRetracted)
end NotNullInfo

/** A pair of not-null sets, depending on whether a condition is `true` or `false` */
Expand Down
15 changes: 12 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1552,8 +1552,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
if result.thenp.tpe.isRef(defn.NothingClass) then
elsePathInfo.withOnceRetracted(thenPathInfo)
else if result.elsep.tpe.isRef(defn.NothingClass) then
thenPathInfo.withOnceRetracted(elsePathInfo)
else thenPathInfo.alt(elsePathInfo)
)
end typedIf
Expand Down Expand Up @@ -2350,10 +2352,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}: @unchecked
val cases2 = cases2x.asInstanceOf[List[CaseDef]]

var nni = expr2.notNullInfo.retractedInfo
// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
var nni = expr2.notNullInfo.onceRetractedInfo
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
// Therefore, the code in the finallizer and after the try block can only rely on the retracted
// info from the cases' body.
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
nni = nni.seq(finalizer1.notNullInfo)

assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
}

Expand Down
6 changes: 3 additions & 3 deletions tests/explicit-nulls/neg/i21380c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test4: Int =
case npe: NullPointerException => x = ""
case _ => x = ""
x.length // error
// Although the catch block here is exhaustive,
// it is possible that the exception is thrown and not caught.
// Therefore, the code after the try block can only rely on the retracted info.
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
// and some exceptions are thrown and not caught. Therefore, the code in the finallizer and
// after the try block can only rely on the retracted info from the cases' body.

def test5: Int =
var x: String | Null = null
Expand Down
62 changes: 62 additions & 0 deletions tests/explicit-nulls/neg/i21619.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
def test1: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case e: Exception =>
x.replace("", "") // error

def test2: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case e: Exception =>
x = "e"
x.replace("", "") // error

def test3: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case e: Exception =>
finally
x = "f"
x.replace("", "") // ok

def test4: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
try
if i == 1 then
x = null
throw new Exception()
else
x = ""
catch
case _ =>
x = ""
catch
case _ =>
x.replace("", "") // error

0 comments on commit 585dda9

Please sign in to comment.