Skip to content

Commit

Permalink
[jvm] Hook up the ability to run JUnit tests written in Scala (pantsb…
Browse files Browse the repository at this point in the history
…uild#13868)

The `scala_junit_tests` target was added in pantsbuild#13130, and became compilable from the JUnit `@rules` in pantsbuild#13519. The final piece to enable actual usage is to allow the JUnit `@rules` to match non-Java sources.

This change adds the `JunitTestSourceField` marker, which is extended by `JavaJunitTestSourceField` and `ScalaJunitTestSourceField`, and tests that JUnit tests can be written in Scala.

[ci skip-rust]
  • Loading branch information
stuhood authored Dec 13, 2021
1 parent d9bf077 commit b99c9e5
Show file tree
Hide file tree
Showing 13 changed files with 106 additions and 45 deletions.
2 changes: 1 addition & 1 deletion pants.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ java_parser = "src/python/pants/backend/java/dependency_inference/java_parser.lo
scala_parser = "src/python/pants/backend/scala/dependency_inference/scala_parser.lockfile"

[junit]
lockfile = "src/python/pants/backend/java/test/junit.default.lockfile.txt"
lockfile = "src/python/pants/jvm/junit.default.lockfile.txt"

[google-java-format]
lockfile = "src/python/pants/backend/java/lint/google_java_format/google_java_format.default.lockfile.txt"
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/backend/experimental/java/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
JunitTestTarget,
)
from pants.backend.java.target_types import rules as target_types_rules
from pants.backend.java.test import junit
from pants.jvm import classpath, jdk_rules
from pants.jvm import util_rules as jvm_util_rules
from pants.jvm.dependency_inference import symbol_mapper
from pants.jvm.goals import coursier
from pants.jvm.resolve import coursier_fetch, coursier_setup, jvm_tool
from pants.jvm.target_types import JvmArtifact
from pants.jvm.test import junit


def target_types():
Expand Down
10 changes: 2 additions & 8 deletions src/python/pants/backend/experimental/scala/register.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
# Copyright 2021 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from pants.backend.java.package import deploy_jar # TODO: Should move to the JVM package.
from pants.backend.java.target_types import ( # TODO: All of these should move to the JVM package.
DeployJar,
JunitTestsGeneratorTarget,
JunitTestTarget,
)
from pants.backend.java.test import junit # TODO: Should move to the JVM package.
from pants.backend.java.target_types import DeployJar # TODO: Should move to the JVM package.
from pants.backend.scala.compile import scalac
from pants.backend.scala.dependency_inference import rules as dep_inf_rules
from pants.backend.scala.goals import check, repl, tailor
Expand All @@ -22,13 +17,12 @@
from pants.jvm.goals import coursier
from pants.jvm.resolve import coursier_fetch, coursier_setup, jvm_tool
from pants.jvm.target_types import JvmArtifact
from pants.jvm.test import junit


def target_types():
return [
DeployJar,
JunitTestTarget,
JunitTestsGeneratorTarget,
JvmArtifact,
ScalaJunitTestTarget,
ScalaJunitTestsGeneratorTarget,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
JunitTestsGeneratorTarget,
)
from pants.backend.java.target_types import rules as java_target_rules
from pants.backend.java.test.junit import rules as junit_rules
from pants.core.util_rules import config_files, source_files
from pants.core.util_rules.external_tool import rules as external_tool_rules
from pants.engine.addresses import Address, Addresses, UnparsedAddressInputs
Expand All @@ -33,6 +32,7 @@
from pants.jvm.resolve.coursier_fetch import rules as coursier_fetch_rules
from pants.jvm.resolve.coursier_setup import rules as coursier_setup_rules
from pants.jvm.target_types import JvmArtifact
from pants.jvm.test.junit import rules as junit_rules
from pants.jvm.testutil import maybe_skip_jdk_test
from pants.jvm.util_rules import rules as util_rules
from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/backend/java/subsystems/junit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class JUnit(JvmToolBase):
"org.junit.jupiter:junit-jupiter-engine:{version}",
"org.junit.vintage:junit-vintage-engine:{version}",
)
default_lockfile_resource = ("pants.backend.java.test", "junit.default.lockfile.txt")
default_lockfile_url = git_url("src/python/pants/backend/java/test/junit.default.lockfile.txt")
default_lockfile_resource = ("pants.jvm.test", "junit.default.lockfile.txt")
default_lockfile_url = git_url("src/python/pants/jvm/test/junit.default.lockfile.txt")

@classmethod
def register_options(cls, register):
Expand Down
7 changes: 4 additions & 3 deletions src/python/pants/backend/java/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.target_types import (
JunitTestSourceField,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
JvmResolveNameField,
Expand Down Expand Up @@ -56,15 +57,15 @@ class JavaGeneratorFieldSet(FieldSet):
# -----------------------------------------------------------------------------------------------


class JavaTestSourceField(JavaSourceField):
pass
class JavaJunitTestSourceField(JavaSourceField, JunitTestSourceField):
"""A JUnit test file written in Java."""


class JunitTestTarget(Target):
alias = "junit_test"
core_fields = (
*COMMON_TARGET_FIELDS,
JavaTestSourceField,
JavaJunitTestSourceField,
Dependencies,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
Expand Down
10 changes: 7 additions & 3 deletions src/python/pants/backend/scala/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
generate_file_level_targets,
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.target_types import JvmCompatibleResolveNamesField, JvmProvidesTypesField
from pants.jvm.target_types import (
JunitTestSourceField,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
)


class ScalaSourceField(SingleSourceField):
Expand All @@ -47,11 +51,11 @@ class ScalaGeneratorFieldSet(FieldSet):


# -----------------------------------------------------------------------------------------------
# `junit_test` target
# `scala_junit_test` target
# -----------------------------------------------------------------------------------------------


class ScalaTestSourceField(ScalaSourceField):
class ScalaTestSourceField(ScalaSourceField, JunitTestSourceField):
pass


Expand Down
30 changes: 20 additions & 10 deletions src/python/pants/jvm/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@

from __future__ import annotations

from abc import ABCMeta

from pants.engine.target import (
COMMON_TARGET_FIELDS,
FieldSet,
SingleSourceField,
SpecialCasedDependencies,
StringField,
StringSequenceField,
Target,
)
from pants.util.docutil import git_url

# -----------------------------------------------------------------------------------------------
# `jvm_artifact` targets
# -----------------------------------------------------------------------------------------------

_DEFAULT_PACKAGE_MAPPING_URL = git_url(
"src/python/pants/jvm/dependency_inference/jvm_artifact_mappings.py"
)
Expand Down Expand Up @@ -123,6 +128,20 @@ class JvmArtifact(Target):
)


# -----------------------------------------------------------------------------------------------
# JUnit test support field(s)
# -----------------------------------------------------------------------------------------------


class JunitTestSourceField(SingleSourceField, metaclass=ABCMeta):
"""A marker that indicates that a source field represents a JUnit test."""


# -----------------------------------------------------------------------------------------------
# Generic resolve support fields
# -----------------------------------------------------------------------------------------------


class JvmCompatibleResolveNamesField(StringSequenceField):
alias = "compatible_resolves"
required = False
Expand All @@ -141,12 +160,3 @@ class JvmResolveNameField(StringField):
"one of the resolves in `--jvm-resolves`. If not supplied, the default resolve will be "
"used, otherwise, one resolve that is compatible with all dependency targets will be used."
)


class JvmRequirementsField(SpecialCasedDependencies):
alias = "requirements"
required = True
help = (
"A sequence of addresses to targets compatible with `jvm_artifact` that specify the coordinates for "
"third-party JVM dependencies."
)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dataclasses import dataclass

from pants.backend.java.subsystems.junit import JUnit
from pants.backend.java.target_types import JavaTestSourceField
from pants.core.goals.test import TestDebugRequest, TestFieldSet, TestResult, TestSubsystem
from pants.engine.addresses import Addresses
from pants.engine.fs import Digest, DigestSubset, MergeDigests, PathGlobs, RemovePrefix, Snapshot
Expand All @@ -16,16 +15,17 @@
from pants.jvm.jdk_rules import JdkSetup
from pants.jvm.resolve.coursier_fetch import MaterializedClasspath, MaterializedClasspathRequest
from pants.jvm.resolve.jvm_tool import JvmToolLockfileRequest, JvmToolLockfileSentinel
from pants.jvm.target_types import JunitTestSourceField
from pants.util.logging import LogLevel

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class JavaTestFieldSet(TestFieldSet):
required_fields = (JavaTestSourceField,)
class JunitTestFieldSet(TestFieldSet):
required_fields = (JunitTestSourceField,)

sources: JavaTestSourceField
sources: JunitTestSourceField


class JunitToolLockfileSentinel(JvmToolLockfileSentinel):
Expand All @@ -38,7 +38,7 @@ async def run_junit_test(
jdk_setup: JdkSetup,
junit: JUnit,
test_subsystem: TestSubsystem,
field_set: JavaTestFieldSet,
field_set: JunitTestFieldSet,
) -> TestResult:
classpath, junit_classpath = await MultiGet(
Get(Classpath, Addresses([field_set.address])),
Expand Down Expand Up @@ -108,7 +108,7 @@ async def run_junit_test(

# Required by standard test rules. Do nothing for now.
@rule(level=LogLevel.DEBUG)
async def setup_junit_debug_request(_field_set: JavaTestFieldSet) -> TestDebugRequest:
async def setup_junit_debug_request(_field_set: JunitTestFieldSet) -> TestDebugRequest:
raise NotImplementedError("TestDebugResult is not implemented for JUnit (yet?).")


Expand All @@ -122,6 +122,6 @@ async def generate_junit_lockfile_request(
def rules():
return [
*collect_rules(),
UnionRule(TestFieldSet, JavaTestFieldSet),
UnionRule(TestFieldSet, JunitTestFieldSet),
UnionRule(JvmToolLockfileSentinel, JunitToolLockfileSentinel),
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from pants.backend.java.compile.javac import rules as javac_rules
from pants.backend.java.target_types import JavaSourcesGeneratorTarget, JunitTestsGeneratorTarget
from pants.backend.java.target_types import rules as target_types_rules
from pants.backend.java.test.junit import JavaTestFieldSet
from pants.backend.java.test.junit import rules as junit_rules
from pants.backend.scala.compile.scalac import rules as scalac_rules
from pants.backend.scala.target_types import ScalaJunitTestsGeneratorTarget
from pants.backend.scala.target_types import rules as scala_target_types_rules
from pants.build_graph.address import Address
from pants.core.goals.test import TestResult
from pants.core.util_rules import config_files, source_files
Expand All @@ -31,6 +32,8 @@
from pants.jvm.resolve.coursier_fetch import rules as coursier_fetch_rules
from pants.jvm.resolve.coursier_setup import rules as coursier_setup_rules
from pants.jvm.target_types import JvmArtifact
from pants.jvm.test.junit import JunitTestFieldSet
from pants.jvm.test.junit import rules as junit_rules
from pants.jvm.testutil import maybe_skip_jdk_test
from pants.jvm.util_rules import rules as util_rules
from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner
Expand All @@ -46,24 +49,27 @@ def rule_runner() -> RuleRunner:
rule_runner = RuleRunner(
preserve_tmpdirs=True,
rules=[
*config_files.rules(),
*classpath.rules(),
*config_files.rules(),
*coursier_fetch_rules(),
*coursier_setup_rules(),
*external_tool_rules(),
*source_files.rules(),
*java_util_rules(),
*javac_rules(),
*junit_rules(),
*util_rules(),
*java_util_rules(),
*scala_target_types_rules(),
*scalac_rules(),
*source_files.rules(),
*target_types_rules(),
*util_rules(),
QueryRule(CoarsenedTargets, (Addresses,)),
QueryRule(TestResult, (JavaTestFieldSet,)),
QueryRule(TestResult, (JunitTestFieldSet,)),
],
target_types=[
JvmArtifact,
JavaSourcesGeneratorTarget,
JunitTestsGeneratorTarget,
ScalaJunitTestsGeneratorTarget,
],
)
rule_runner.set_options(
Expand Down Expand Up @@ -240,7 +246,6 @@ def test_vintage_success_with_dep(rule_runner: RuleRunner) -> None:
java_sources(
name='example-lib',
)
junit_tests(
Expand Down Expand Up @@ -289,6 +294,53 @@ def test_vintage_success_with_dep(rule_runner: RuleRunner) -> None:
assert re.search(r"1 tests found", test_result.stdout) is not None


@maybe_skip_jdk_test
def test_vintage_scala_simple_success(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"coursier_resolve.lockfile": JUNIT4_RESOLVED_LOCKFILE.to_json().decode("utf-8"),
"BUILD": dedent(
"""\
jvm_artifact(
name = 'junit_junit',
group = 'junit',
artifact = 'junit',
version = '4.13.2',
)
scala_junit_tests(
name='example-test',
dependencies= [
':junit_junit',
],
)
"""
),
"SimpleTest.scala": dedent(
"""
package org.pantsbuild.example
import junit.framework.TestCase
import junit.framework.Assert._
class SimpleTest extends TestCase {
def testHello(): Unit = {
assertTrue("Hello!" == "Hello!")
}
}
"""
),
}
)

test_result = run_junit_test(rule_runner, "example-test", "SimpleTest.scala")

assert test_result.exit_code == 0
assert re.search(r"Finished:\s+testHello", test_result.stdout) is not None
assert re.search(r"1 tests successful", test_result.stdout) is not None
assert re.search(r"1 tests found", test_result.stdout) is not None


# This is hard-coded to make the test somewhat more hermetic.
# To regenerate (e.g. to update the resolved version), run the
# following in a test:
Expand Down Expand Up @@ -560,4 +612,4 @@ def run_junit_test(
tgt = rule_runner.get_target(
Address(spec_path="", target_name=target_name, relative_file_path=relative_file_path)
)
return rule_runner.request(TestResult, [JavaTestFieldSet.create(tgt)])
return rule_runner.request(TestResult, [JunitTestFieldSet.create(tgt)])

0 comments on commit b99c9e5

Please sign in to comment.