Skip to content

Commit

Permalink
[internal] scala: extract provided names from source files (pantsbuil…
Browse files Browse the repository at this point in the history
…d#13537)

Extract and report on names provided by a Scala source file. This will be used for dependency inference once the "consumed types" side of the analysis is available.
  • Loading branch information
Tom Dyas authored Nov 9, 2021
1 parent c1128d7 commit fd5c30e
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,91 @@
package org.pantsbuild.backend.scala.dependency_inference

import io.circe._, io.circe.generic.auto._, io.circe.syntax._
//import io.circe._
//import io.circe.generic.auto._
//import io.circe.syntax._
import io.circe._
import io.circe.generic.auto._
import io.circe.syntax._

import scala.meta._
import scala.meta.transversers.Traverser

import scala.collection.mutable.ArrayBuffer

case class Analysis(
`package`: String
providedNames: Vector[String]
)

class ProvidedTypesTraverser extends Traverser {
val nameParts = ArrayBuffer[String]()
val providedNames = ArrayBuffer[String]()

// Extract a qualified name from a tree.
def extractName(tree: Tree): String = {
tree match {
case Term.Select(qual, name) => s"${extractName(qual)}.${extractName(name)}"
case Term.Name(name) => name
case Type.Name(name) => name
case Pat.Var(node) => extractName(node)
case _ => ""
}
}

def recordProvidedName(name: String): Unit = {
val fullPackageName = nameParts.mkString(".")
providedNames.append(s"${fullPackageName}.${name}")
}

def withNamePart[T](namePart: String, f: () => T): T = {
nameParts.append(namePart)
val result = f()
nameParts.remove(nameParts.length - 1)
result
}

override def apply(tree: Tree): Unit = tree match {
case Pkg(ref, stats) => {
withNamePart(extractName(ref), () => super.apply(stats))
}

case Defn.Class(_mods, nameNode, _tparams, _ctor, templ) => {
val name = extractName(nameNode)
recordProvidedName(name)
withNamePart(name, () => super.apply(templ))
}

case Defn.Trait(_mods, nameNode, _tparams, _ctor, templ) => {
val name = extractName(nameNode)
recordProvidedName(name)
withNamePart(name, () => super.apply(templ))
}

case Defn.Object(_mods, nameNode, templ) => {
val name = extractName(nameNode)
recordProvidedName(name)
withNamePart(name, () => super.apply(templ))
}

case Defn.Type(_mods, nameNode, _tparams, _body) => {
val name = extractName(nameNode)
recordProvidedName(name)
}

case Defn.Val(_mods, pats, _decltpe, _rhs) => {
pats.headOption.foreach(pat => {
val name = extractName(pat)
recordProvidedName(name)
})
}

case Defn.Var(_mods, pats, _decltpe, _rhs) => {
pats.headOption.foreach(pat => {
val name = extractName(pat)
recordProvidedName(name)
})
}

case node => super.apply(node)
}
}

object ScalaParser {
def analyze(pathStr: String): Analysis = {
val path = java.nio.file.Paths.get(pathStr)
Expand All @@ -20,15 +95,18 @@ object ScalaParser {

val tree = input.parse[Source].get

// TODO: Actually pare out the package (and other fields).
Analysis("foo")
val providedNamesTraverser = new ProvidedTypesTraverser()
providedNamesTraverser.apply(tree)

Analysis(providedNames = providedNamesTraverser.providedNames.toVector)
}

def main(args: Array[String]): Unit = {
val analysis = analyze(args(0))
val outputPath = java.nio.file.Paths.get(args(0))
val analysis = analyze(args(1))

val json = analysis.asJson.noSpaces
// TODO: Write to file specified by the caler.
println(json)
java.nio.file.Files.write(outputPath, json.getBytes(),
java.nio.file.StandardOpenOption.CREATE_NEW, java.nio.file.StandardOpenOption.WRITE)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from __future__ import annotations

import json
import logging
import os
import pkgutil
Expand All @@ -12,13 +13,20 @@
AddPrefix,
CreateDigest,
Digest,
DigestContents,
Directory,
FileContent,
MergeDigests,
RemovePrefix,
)
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.process import BashBinary, FallibleProcessResult, Process, ProcessResult
from pants.engine.process import (
BashBinary,
FallibleProcessResult,
Process,
ProcessExecutionFailure,
ProcessResult,
)
from pants.engine.rules import collect_rules, rule
from pants.jvm.compile import ClasspathEntry
from pants.jvm.jdk_rules import JdkSetup
Expand All @@ -28,7 +36,9 @@
MaterializedClasspath,
MaterializedClasspathRequest,
)
from pants.option.global_options import GlobalOptions
from pants.util.logging import LogLevel
from pants.util.ordered_set import FrozenOrderedSet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,12 +93,12 @@

@dataclass(frozen=True)
class ScalaSourceDependencyAnalysis:
package: str
provided_names: FrozenOrderedSet[str]

@classmethod
def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:
return cls(
package=d["package"],
provided_names=FrozenOrderedSet(d["providedNames"]),
)


Expand Down Expand Up @@ -164,7 +174,7 @@ async def analyze_scala_source_dependencies(
argv=[
*jdk_setup.args(bash, [*tool_classpath.classpath_entries(), processorcp_relpath]),
"org.pantsbuild.backend.scala.dependency_inference.ScalaParser",
# analysis_output_path,
analysis_output_path,
source_path,
],
input_digest=merged_digest,
Expand All @@ -180,6 +190,28 @@ async def analyze_scala_source_dependencies(
return FallibleScalaSourceDependencyAnalysisResult(process_result=process_result)


@rule(level=LogLevel.DEBUG)
async def resolve_fallible_result_to_analysis(
fallible_result: FallibleScalaSourceDependencyAnalysisResult,
global_options: GlobalOptions,
) -> ScalaSourceDependencyAnalysis:
# TODO(#12725): Just convert directly to a ProcessResult like this:
# result = await Get(ProcessResult, FallibleProcessResult, fallible_result.process_result)
if fallible_result.process_result.exit_code == 0:
analysis_contents = await Get(
DigestContents, Digest, fallible_result.process_result.output_digest
)
analysis = json.loads(analysis_contents[0].content)
return ScalaSourceDependencyAnalysis.from_json_dict(analysis)
raise ProcessExecutionFailure(
fallible_result.process_result.exit_code,
fallible_result.process_result.stdout,
fallible_result.process_result.stderr,
"Scala source dependency analysis failed.",
local_cleanup=global_options.options.process_execution_local_cleanup,
)


@rule
async def setup_scala_parser_classfiles(
bash: BashBinary, jdk_setup: JdkSetup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

from pants.backend.scala import target_types
from pants.backend.scala.dependency_inference import scala_parser
from pants.backend.scala.dependency_inference.scala_parser import (
FallibleScalaSourceDependencyAnalysisResult,
)
from pants.backend.scala.dependency_inference.scala_parser import ScalaSourceDependencyAnalysis
from pants.backend.scala.target_types import ScalaSourceField, ScalaSourceTarget
from pants.build_graph.address import Address
from pants.core.util_rules import source_files
Expand All @@ -21,6 +19,7 @@
from pants.jvm.resolve.coursier_setup import rules as coursier_setup_rules
from pants.jvm.target_types import JvmDependencyLockfile
from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner
from pants.util.ordered_set import FrozenOrderedSet


@pytest.fixture
Expand All @@ -37,7 +36,7 @@ def rule_runner() -> RuleRunner:
*target_types.rules(),
*jvm_util_rules.rules(),
QueryRule(SourceFiles, (SourceFilesRequest,)),
QueryRule(FallibleScalaSourceDependencyAnalysisResult, (SourceFiles,)),
QueryRule(ScalaSourceDependencyAnalysis, (SourceFiles,)),
],
target_types=[JvmDependencyLockfile, ScalaSourceTarget],
)
Expand All @@ -58,9 +57,43 @@ def test_parser_simple(rule_runner: RuleRunner) -> None:
),
"SimpleSource.scala": textwrap.dedent(
"""
package org.pantsbuild.example
package org.pantsbuild
package example
class Foo {
class OuterClass {
val NestedVal = 3
var NestedVar = "foo"
trait NestedTrait {
}
class NestedClass {
}
type NestedType = Foo
object NestedObject {
}
}
trait OuterTrait {
val NestedVal = 3
var NestedVar = "foo"
trait NestedTrait {
}
class NestedClass {
}
type NestedType = Foo
object NestedObject {
}
}
object OuterObject {
val NestedVal = 3
var NestedVar = "foo"
trait NestedTrait {
}
class NestedClass {
}
type NestedType = Foo
object NestedObject {
}
}
"""
),
Expand All @@ -81,9 +114,32 @@ class Foo {
)

analysis = rule_runner.request(
FallibleScalaSourceDependencyAnalysisResult,
ScalaSourceDependencyAnalysis,
[source_files],
)

assert analysis.process_result.exit_code == 0
assert analysis.process_result.stdout == b"""{"package":"foo"}\n"""
assert analysis.provided_names == FrozenOrderedSet(
[
"org.pantsbuild.example.OuterClass",
"org.pantsbuild.example.OuterClass.NestedVal",
"org.pantsbuild.example.OuterClass.NestedVar",
"org.pantsbuild.example.OuterClass.NestedTrait",
"org.pantsbuild.example.OuterClass.NestedClass",
"org.pantsbuild.example.OuterClass.NestedType",
"org.pantsbuild.example.OuterClass.NestedObject",
"org.pantsbuild.example.OuterTrait",
"org.pantsbuild.example.OuterTrait.NestedVal",
"org.pantsbuild.example.OuterTrait.NestedVar",
"org.pantsbuild.example.OuterTrait.NestedTrait",
"org.pantsbuild.example.OuterTrait.NestedClass",
"org.pantsbuild.example.OuterTrait.NestedType",
"org.pantsbuild.example.OuterTrait.NestedObject",
"org.pantsbuild.example.OuterObject",
"org.pantsbuild.example.OuterObject.NestedVal",
"org.pantsbuild.example.OuterObject.NestedVar",
"org.pantsbuild.example.OuterObject.NestedTrait",
"org.pantsbuild.example.OuterObject.NestedClass",
"org.pantsbuild.example.OuterObject.NestedType",
"org.pantsbuild.example.OuterObject.NestedObject",
]
)

0 comments on commit fd5c30e

Please sign in to comment.