Skip to content

Commit

Permalink
Resolve properly provided symbol names for types defined at top level…
Browse files Browse the repository at this point in the history
… package (pantsbuild#16690)
  • Loading branch information
alonsodomin authored Aug 30, 2022
1 parent dc973fe commit 6cd1147
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.circe.syntax._
import scala.meta._
import scala.meta.transversers.Traverser

import scala.collection.SortedSet
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.reflect.NameTransformer

Expand All @@ -25,14 +26,15 @@ case class AnImport(
)

case class Analysis(
providedSymbols: Vector[Analysis.ProvidedSymbol],
providedSymbolsEncoded: Vector[Analysis.ProvidedSymbol],
providedSymbols: SortedSet[Analysis.ProvidedSymbol],
providedSymbolsEncoded: SortedSet[Analysis.ProvidedSymbol],
importsByScope: HashMap[String, ArrayBuffer[AnImport]],
consumedSymbolsByScope: HashMap[String, HashSet[String]],
scopes: Vector[String]
)
object Analysis {
case class ProvidedSymbol(name: String, recursive: Boolean)
implicit val providedSymbolOrdering: Ordering[ProvidedSymbol] = Ordering.by(_.name)
}

case class ProvidedSymbol(
Expand Down Expand Up @@ -80,7 +82,7 @@ class SourceAnalysisTraverser extends Traverser {
case Type.Name(name) => Vector(name)
case Type.Select(qual, Type.Name(name)) => {
val qualName = extractName(qual)
Vector(s"${qualName}.${name}")
Vector(qualifyName(qualName, name))
}
case Type.Apply(tpe, args) =>
extractNamesFromTypeTree(tpe) ++ args.toVector.flatMap(extractNamesFromTypeTree(_))
Expand Down Expand Up @@ -404,17 +406,17 @@ class SourceAnalysisTraverser extends Traverser {
case node => super.apply(node)
}

def gatherProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = {
def gatherProvidedSymbols(): SortedSet[Analysis.ProvidedSymbol] = {
providedSymbolsByScope
.flatMap({ case (scopeName, symbolsForScope) =>
symbolsForScope.map { case (symbolName, symbol) =>
Analysis.ProvidedSymbol(s"${scopeName}.${symbolName}", symbol.recursive)
Analysis.ProvidedSymbol(qualifyName(scopeName, symbolName), symbol.recursive)
}.toVector
})
.toVector
.to(SortedSet)
}

def gatherEncodedProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = {
def gatherEncodedProvidedSymbols(): SortedSet[Analysis.ProvidedSymbol] = {
providedSymbolsByScope
.flatMap({ case (scopeName, symbolsForScope) =>
val encodedSymbolsForScope = symbolsForScope.flatMap({
Expand All @@ -433,9 +435,9 @@ class SourceAnalysisTraverser extends Traverser {
}
})

encodedSymbolsForScope.map(symbol => symbol.copy(name = s"${scopeName}.${symbol.name}"))
encodedSymbolsForScope.map(symbol => symbol.copy(name = qualifyName(scopeName, symbol.name)))
})
.toVector
.to(SortedSet)
}

def toAnalysis: Analysis = {
Expand All @@ -447,6 +449,11 @@ class SourceAnalysisTraverser extends Traverser {
scopes = scopes.toVector
)
}

private def qualifyName(qualifier: String, name: String): String = {
if (qualifier.length > 0) s"$qualifier.$name"
else name
}
}

object ScalaParser {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,40 @@ def test_package_object_extends_trait(rule_runner: RuleRunner) -> None:
assert sorted(analysis.fully_qualified_consumed_symbols()) == ["foo.Trait", "foo.bar.Trait"]


def test_types_at_toplevel_package(rule_runner: RuleRunner) -> None:
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""\
trait Foo
class Bar
object Quxx
"""
),
)

expected_symbols = [
ScalaProvidedSymbol("Foo", False),
ScalaProvidedSymbol("Bar", False),
ScalaProvidedSymbol("Quxx", False),
]

expected_symbols_encoded = expected_symbols.copy()
expected_symbols_encoded.extend(
[ScalaProvidedSymbol("Quxx$", False), ScalaProvidedSymbol("Quxx$.MODULE$", False)]
)

def by_name(symbol: ScalaProvidedSymbol) -> str:
return symbol.name

assert analysis.provided_symbols == FrozenOrderedSet(sorted(expected_symbols, key=by_name))
assert analysis.provided_symbols_encoded == FrozenOrderedSet(
sorted(expected_symbols_encoded, key=by_name)
)


def test_type_constaint(rule_runner: RuleRunner) -> None:
analysis = _analyze(
rule_runner,
Expand Down

0 comments on commit 6cd1147

Please sign in to comment.