javascript: Nodejs tests batch support (pantsbuild#18742)
Add a `batch_compatibility_tag` field to `javascript_test` to support
running tests in batches via the nodejs package manager test runners.
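
For example, opting a set of `javascript_tests` targets into batched runs might
look like this in a BUILD file (an illustrative sketch based on the integration
test in this commit; the tag value is arbitrary and user-chosen, and tests only
share a batch when they also share an owning package and extra env vars):

    javascript_tests(
        name="tests",
        # Tests carrying the same tag may run in a single test-runner
        # invocation instead of one nodejs process per test file.
        batch_compatibility_tag="default",
    )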
tobni authored May 1, 2023
1 parent 0047da1 commit 1b63d39
Showing 7 changed files with 353 additions and 69 deletions.
110 changes: 92 additions & 18 deletions src/python/pants/backend/javascript/goals/test.py
@@ -3,20 +3,28 @@
from __future__ import annotations

import dataclasses
from collections import defaultdict
from dataclasses import dataclass
from pathlib import PurePath
-from typing import Any, Iterable
+from typing import Iterable

from pants.backend.javascript import install_node_package, nodejs_project_environment
from pants.backend.javascript.install_node_package import (
InstalledNodePackage,
InstalledNodePackageRequest,
)
from pants.backend.javascript.nodejs_project_environment import NodeJsProjectEnvironmentProcess
-from pants.backend.javascript.package_json import NodePackageTestScriptField, NodeTestScript
+from pants.backend.javascript.package_json import (
+    NodePackageNameField,
+    NodePackageTestScriptField,
+    NodeTestScript,
+    OwningNodePackage,
+    OwningNodePackageRequest,
+)
from pants.backend.javascript.subsystems.nodejstest import NodeJSTest
from pants.backend.javascript.target_types import (
JSSourceField,
JSTestBatchCompatibilityTagField,
JSTestExtraEnvVarsField,
JSTestSourceField,
JSTestTimeoutField,
@@ -37,6 +45,7 @@
from pants.core.target_types import AssetSourceField
from pants.core.util_rules import source_files
from pants.core.util_rules.distdir import DistDir
from pants.core.util_rules.partitions import Partition, PartitionerType, Partitions
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.env_vars import EnvironmentVars, EnvironmentVarsRequest
from pants.engine.fs import DigestSubset, GlobExpansionConjunction
@@ -48,19 +57,21 @@
from pants.engine.target import (
Dependencies,
SourcesField,
Target,
TransitiveTargets,
TransitiveTargetsRequest,
)
from pants.engine.unions import UnionRule
from pants.util.dirutil import fast_relpath
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel
from pants.util.strutil import pluralize


@dataclass(frozen=True)
class JSCoverageData(CoverageData):
snapshot: Snapshot
-    address: Address
+    addresses: tuple[Address, ...]
output_files: tuple[str, ...]
output_directories: tuple[str, ...]
working_directory: str
@@ -74,6 +85,7 @@ class JSCoverageDataCollection(CoverageDataCollection[JSCoverageData]):
class JSTestFieldSet(TestFieldSet):
required_fields = (JSTestSourceField,)

batch_compatibility_tag: JSTestBatchCompatibilityTagField
source: JSTestSourceField
dependencies: Dependencies
timeout: JSTestTimeoutField
@@ -84,21 +96,70 @@ class JSTestRequest(TestRequest):
tool_subsystem = NodeJSTest
field_set_type = JSTestFieldSet

partitioner_type = PartitionerType.CUSTOM
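    # PartitionerType.CUSTOM opts out of the default one-field-set-per-partition
    # behavior; grouping is done by the `partition_nodejs_tests` rule below.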


@dataclass(frozen=True)
class TestMetadata:
extra_env_vars: tuple[str, ...]
owning_target: Target
compatibility_tag: str | None = None

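    # Stops pytest from collecting this class as a test, despite the "Test" prefix.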
__test__ = False

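    # Labels the batch (owning package name plus tag) where the test goal reports partitions.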
@property
def description(self) -> str:
return f'{self.owning_target[NodePackageNameField].value} {self.compatibility_tag or ""}'


@rule(desc="Partition NodeJS tests", level=LogLevel.DEBUG)
async def partition_nodejs_tests(
request: JSTestRequest.PartitionRequest[JSTestFieldSet],
) -> Partitions[JSTestFieldSet, TestMetadata]:
partitions = []
compatible_tests = defaultdict(list)
owning_packages = await MultiGet(
Get(OwningNodePackage, OwningNodePackageRequest(field_set.address))
for field_set in request.field_sets
)
for field_set, owning_package in zip(request.field_sets, owning_packages):
metadata = TestMetadata(
extra_env_vars=field_set.extra_env_vars.sorted(),
owning_target=owning_package.ensure_owner(),
compatibility_tag=field_set.batch_compatibility_tag.value,
)

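        # An unset tag opts the test out of batching: it gets its own partition.
        # Tagged tests are grouped by equal metadata, so a batch never mixes env
        # vars, owning packages, or tags.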
if not metadata.compatibility_tag:
partitions.append(Partition((field_set,), metadata))
else:
compatible_tests[metadata].append(field_set)

for metadata, field_sets in compatible_tests.items():
partitions.append(Partition(tuple(field_sets), metadata))

return Partitions(partitions)


@rule(level=LogLevel.DEBUG, desc="Run javascript tests")
async def run_javascript_tests(
-    batch: JSTestRequest.Batch[JSTestFieldSet, Any],
+    batch: JSTestRequest.Batch[JSTestFieldSet, TestMetadata],
test: TestSubsystem,
test_extra_env: TestExtraEnv,
) -> TestResult:
-    field_set = batch.single_element
-    installation_get = Get(InstalledNodePackage, InstalledNodePackageRequest(field_set.address))
-    transitive_tgts_get = Get(TransitiveTargets, TransitiveTargetsRequest([field_set.address]))
+    field_sets = batch.elements
+    metadata = batch.partition_metadata
+    installation_get = Get(
+        InstalledNodePackage,
+        InstalledNodePackageRequest(metadata.owning_target.address),
+    )
+    transitive_tgts_get = Get(
+        TransitiveTargets, TransitiveTargetsRequest(field_set.address for field_set in field_sets)
+    )

-    field_set_source_files_get = Get(SourceFiles, SourceFilesRequest([field_set.source]))
-    target_env_vars_get = Get(
-        EnvironmentVars, EnvironmentVarsRequest(field_set.extra_env_vars.sorted())
+    field_set_source_files_get = Get(
+        SourceFiles, SourceFilesRequest(field_set.source for field_set in field_sets)
     )
+    target_env_vars_get = Get(EnvironmentVars, EnvironmentVarsRequest(metadata.extra_env_vars))
installation, transitive_tgts, field_set_source_files, target_env_vars = await MultiGet(
installation_get, transitive_tgts_get, field_set_source_files_get, target_env_vars_get
)
@@ -128,6 +189,17 @@ def relative_package_dir(file: str) -> str:
output_directories.extend(test_script.coverage_output_directories)
entry_point = test_script.coverage_entry_point or entry_point

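    # The batch timeout is the sum of the members' individual timeouts, or None
    # if no member sets one.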
timeout_seconds: int | None = None
for field_set in field_sets:
timeout = field_set.timeout.calculate_from_global_options(test)
if timeout:
if timeout_seconds:
timeout_seconds += timeout
else:
timeout_seconds = timeout
file_description = field_sets[0].address.spec
if len(field_sets) > 1:
file_description += f"+ {pluralize(len(field_sets) - 1, 'other file')}"
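    # Yields e.g. "<first test address>+ 2 other files" for a three-test batch.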
process = await Get(
Process,
NodeJsProjectEnvironmentProcess(
@@ -139,11 +211,11 @@ def relative_package_dir(file: str) -> str:
*sorted(relative_package_dir(file) for file in field_set_source_files.files),
*coverage_args,
),
description=f"Running npm test for {field_set.address.spec}.",
description=f"Running npm tests for {file_description}.",
input_digest=merged_digest,
level=LogLevel.INFO,
extra_env=FrozenDict(**test_extra_env.env, **target_env_vars),
-            timeout_seconds=field_set.timeout.calculate_from_global_options(test),
+            timeout_seconds=timeout_seconds,
output_files=tuple(
installation.join_relative_workspace_directory(file) for file in output_files or ()
),
@@ -168,14 +240,14 @@ def relative_package_dir(file: str) -> str:
)
coverage_data = JSCoverageData(
coverage_snapshot,
-        field_set.address,
+        tuple(field_set.address for field_set in field_sets),
output_files=test_script.coverage_output_files,
output_directories=test_script.coverage_output_directories,
working_directory=installation.project_env.relative_workspace_directory(),
)

-    return TestResult.from_fallible_process_result(
-        result, field_set.address, test.output, coverage_data=coverage_data
+    return TestResult.from_batched_fallible_process_result(
+        result, batch, test.output, coverage_data=coverage_data
)


@@ -210,7 +282,9 @@ async def collect_coverage_reports(
snapshots = await MultiGet(get for _, _, get in gets_per_data)
return CoverageReports(
tuple(
-            _get_report(nodejs_test, dist_dir, snapshot, data.address, file, data.working_directory)
+            _get_report(
+                nodejs_test, dist_dir, snapshot, data.addresses, file, data.working_directory
+            )
for (file, data), snapshot in zip(
((file, report) for file, report, _ in gets_per_data), snapshots
)
@@ -222,13 +296,13 @@ def _get_report(
nodejs_test: NodeJSTest,
dist_dir: DistDir,
snapshot: Snapshot,
-    addresses: Address,
+    addresses: tuple[Address, ...],
file: str,
working_directory: str,
) -> FilesystemCoverageReport:
# It is up to the user to configure the output coverage reports.
file_path = PurePath(file)
-    output_dir = nodejs_test.render_coverage_output_dir(dist_dir, address)
+    output_dir = nodejs_test.render_coverage_output_dir(dist_dir, addresses)
return FilesystemCoverageReport(
coverage_insufficient=False,
result_snapshot=snapshot,
85 changes: 75 additions & 10 deletions src/python/pants/backend/javascript/goals/test_integration_test.py
@@ -11,7 +11,12 @@

from pants.backend.javascript import package_json
from pants.backend.javascript.goals import test
-from pants.backend.javascript.goals.test import JSCoverageData, JSTestFieldSet, JSTestRequest
+from pants.backend.javascript.goals.test import (
+    JSCoverageData,
+    JSTestFieldSet,
+    JSTestRequest,
+    TestMetadata,
+)
from pants.backend.javascript.package_json import PackageJsonTarget
from pants.backend.javascript.target_types import (
JSSourcesGeneratorTarget,
@@ -21,6 +26,7 @@
from pants.build_graph.address import Address
from pants.core.goals.test import TestResult, get_filtered_environment
from pants.engine.rules import QueryRule
from pants.engine.target import Target
from pants.testutil.rule_runner import RuleRunner


@@ -116,13 +122,66 @@ def test_jest_tests_are_successful(
}
)
tgt = rule_runner.get_target(Address("foo/src/tests", relative_file_path="index.test.js"))
-    result = rule_runner.request(
-        TestResult, [JSTestRequest.Batch("", (JSTestFieldSet.create(tgt),), None)]
-    )
+    package = rule_runner.get_target(Address("foo", generated_name="pkg"))
+    result = rule_runner.request(TestResult, [given_request_for(tgt, package=package)])
assert "Test Suites: 1 passed, 1 total" in result.stderr
assert result.exit_code == 0


def test_batched_jest_tests_are_successful(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"foo/BUILD": "package_json()",
"foo/package.json": given_package_json(
test_script={"test": "NODE_OPTIONS=--experimental-vm-modules jest"},
runner={"jest": "^29.5"},
),
"foo/package-lock.json": (
Path(__file__).parent / "jest_resources/package-lock.json"
).read_text(),
"foo/src/BUILD": "javascript_sources()",
"foo/src/index.mjs": _SOURCE_TO_TEST,
"foo/src/tests/BUILD": "javascript_tests(name='tests', batch_compatibility_tag='default')",
"foo/src/tests/index.test.js": textwrap.dedent(
"""\
/**
* @jest-environment node
*/
import { expect } from "@jest/globals"
import { add } from "../index.mjs"
test('adds 1 + 2 to equal 3', () => {
expect(add(1, 2)).toBe(3);
});
"""
),
"foo/src/tests/another.test.js": textwrap.dedent(
"""\
/**
* @jest-environment node
*/
import { expect } from "@jest/globals"
import { add } from "../index.mjs"
test('adds 2 + 3 to equal 5', () => {
expect(add(2, 3)).toBe(5);
});
"""
),
}
)
tgt_1 = rule_runner.get_target(Address("foo/src/tests", relative_file_path="index.test.js"))
tgt_2 = rule_runner.get_target(Address("foo/src/tests", relative_file_path="another.test.js"))
package = rule_runner.get_target(Address("foo", generated_name="pkg"))
result = rule_runner.request(TestResult, [given_request_for(tgt_1, tgt_2, package=package)])
assert "Test Suites: 2 passed, 2 total" in result.stderr
assert result.exit_code == 0


def test_mocha_tests_are_successful(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
@@ -150,9 +209,8 @@ def test_mocha_tests_are_successful(rule_runner: RuleRunner) -> None:
}
)
tgt = rule_runner.get_target(Address("foo/src/tests", relative_file_path="index.test.mjs"))
-    result = rule_runner.request(
-        TestResult, [JSTestRequest.Batch("", (JSTestFieldSet.create(tgt),), None)]
-    )
+    package = rule_runner.get_target(Address("foo", generated_name="pkg"))
+    result = rule_runner.request(TestResult, [given_request_for(tgt, package=package)])
assert "1 passing" in result.stdout
assert result.exit_code == 0

Expand Down Expand Up @@ -201,10 +259,17 @@ def test_jest_test_with_coverage_reporting(rule_runner: RuleRunner) -> None:
}
)
tgt = rule_runner.get_target(Address("foo/src/tests", relative_file_path="index.test.js"))
-    result = rule_runner.request(
-        TestResult, [JSTestRequest.Batch("", (JSTestFieldSet.create(tgt),), None)]
-    )
+    package = rule_runner.get_target(Address("foo", generated_name="pkg"))
+    result = rule_runner.request(TestResult, [given_request_for(tgt, package=package)])
assert result.coverage_data

rule_runner.write_digest(cast(JSCoverageData, result.coverage_data).snapshot.digest)
assert Path(rule_runner.build_root, ".coverage/clover.xml").exists()


def given_request_for(*js_test: Target, package: Target) -> JSTestRequest.Batch:
return JSTestRequest.Batch(
"",
tuple(JSTestFieldSet.create(tgt) for tgt in js_test),
TestMetadata(tuple(), package),
)
(Diff for the remaining 5 changed files not shown.)
