Skip to content

Commit

Permalink
Factor up shared test partitioning code. (pantsbuild#5416)
Browse files Browse the repository at this point in the history
This factors the handling of `--fast`/`--no-fast` partitioning as well
as the attendant results summarization and caching and invalidation
handling to a mixin that `PytestRun` and `JUnitRun` share.

Fixes pantsbuild#5307
  • Loading branch information
jsirois authored Feb 1, 2018
1 parent 6098667 commit a6a98fe
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 302 deletions.
168 changes: 34 additions & 134 deletions src/python/pants/backend/jvm/tasks/junit_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@
from pants.backend.jvm.tasks.reports.junit_html_report import JUnitHtmlReport, NoJunitHtmlReport
from pants.base.build_environment import get_buildroot
from pants.base.deprecated import deprecated_conditional
from pants.base.exceptions import ErrorWhileTesting, TargetDefinitionException, TaskError
from pants.base.exceptions import TargetDefinitionException, TaskError
from pants.base.workunit import WorkUnitLabel
from pants.build_graph.files import Files
from pants.build_graph.target import Target
from pants.build_graph.target_scopes import Scopes
from pants.invalidation.cache_manager import VersionedTargetSet
from pants.java.distribution.distribution import DistributionLocator
from pants.java.executor import SubprocessExecutor
from pants.java.junit.junit_xml_parser import RegistryOfTests, Test, parse_failed_targets
from pants.process.lock import OwnerPrintingInterProcessFileLock
from pants.task.testrunner_task_mixin import TestResult, TestRunnerTaskMixin
from pants.task.testrunner_task_mixin import PartitionedTestRunnerTaskMixin, TestResult
from pants.util import desktop
from pants.util.argutil import ensure_arg, remove_arg
from pants.util.contextutil import environment_as, temporary_dir
Expand Down Expand Up @@ -131,7 +130,7 @@ def iter_possible_tests(self, context):
yield Test(classname=self._classname, methodname=self._methodname)


class JUnitRun(TestRunnerTaskMixin, JvmToolTaskMixin, JvmTask):
class JUnitRun(PartitionedTestRunnerTaskMixin, JvmToolTaskMixin, JvmTask):
"""
:API: public
"""
Expand All @@ -146,10 +145,6 @@ def implementation_version(cls):
def register_options(cls, register):
super(JUnitRun, cls).register_options(register)

register('--fast', type=bool, default=True, fingerprint=True,
help='Run all tests in a single junit invocation. If turned off, each test target '
'will run in its own junit invocation, which will be slower, but isolates '
'tests from process-wide state created by tests in other targets.')
register('--batch-size', advanced=True, type=int, default=cls._BATCH_ALL, fingerprint=True,
help='Run at most this many tests in a single test process.')
register('--test', type=list, fingerprint=True,
Expand All @@ -174,11 +169,6 @@ def register_options(cls, register):
help='Set the working directory. If no argument is passed, use the build root. '
'If cwd is set on a target, it will supersede this option. It is an error to '
'use this option in combination with `--chroot`')
register('--chroot', advanced=True, fingerprint=True, type=bool, default=False,
help='Run tests in a chroot. Any loose files tests depend on via `{}` dependencies '
'will be copied to the chroot. If cwd is set on a target, it will supersede this'
'option. It is an error to use this option in combination with `--cwd`'
.format(Files.alias()))
register('--strict-jvm-version', type=bool, advanced=True, fingerprint=True,
help='If true, will strictly require running junits with the same version of java as '
'the platform -target level. Otherwise, the platform -target level will be '
Expand Down Expand Up @@ -234,13 +224,12 @@ def __init__(self, *args, **kwargs):
options = self.get_options()
self._tests_to_run = options.test
self._batch_size = options.batch_size
self._fail_fast = options.fail_fast

if options.cwd and options.chroot:
if options.cwd and self.run_tests_in_chroot:
raise self.OptionError('Cannot set both `cwd` ({}) and ask for a `chroot` at the same time.'
.format(options.cwd))

if options.chroot:
if self.run_tests_in_chroot:
self._working_dir = None
else:
self._working_dir = options.cwd or get_buildroot()
Expand All @@ -252,7 +241,7 @@ def __init__(self, *args, **kwargs):
self._legacy_report_layout = options.legacy_report_layout

@memoized_method
def _args(self, output_dir):
def _args(self, fail_fast, output_dir):
args = self.args[:]

options = self.get_options()
Expand All @@ -263,7 +252,7 @@ def _args(self, output_dir):
else:
args.append('-output-mode=NONE')

if self._fail_fast:
if fail_fast:
args.append('-fail-fast')
args.append('-outdir')
args.append(output_dir)
Expand Down Expand Up @@ -398,15 +387,11 @@ def _chroot(self, targets, workdir):
)
yield chroot

@property
def _per_target(self):
return not self.get_options().fast

@property
def _batched(self):
return self._batch_size != self._BATCH_ALL

def _run_junit(self, test_targets, output_dir, coverage):
def run_tests(self, fail_fast, test_targets, output_dir, coverage):
test_registry = self._collect_test_targets(test_targets)
if test_registry.empty:
return TestResult.rc(0)
Expand Down Expand Up @@ -447,7 +432,7 @@ def parse_error_handler(parse_error):
distribution = JvmPlatform.preferred_jvm_distribution([platform], self._strict_jvm_version)

# Override cmdline args with values from junit_test() target that specify concurrency:
args = self._args(batch_output_dir) + [u'-xmlreport']
args = self._args(fail_fast, batch_output_dir) + [u'-xmlreport']

if concurrency is not None:
args = remove_arg(args, '-default-parallel')
Expand Down Expand Up @@ -495,7 +480,7 @@ def parse_error_handler(parse_error):
self.report_all_info_for_single_test(self.options_scope, test_target,
test_name, test_info)

if result != 0 and self._fail_fast:
if result != 0 and fail_fast:
break

if result == 0:
Expand Down Expand Up @@ -573,119 +558,34 @@ def _validate_target(self, target):
msg = 'JUnitTests target must include a non-empty set of sources.'
raise TargetDefinitionException(target, msg)

@staticmethod
def _vts_for_partition(invalidation_check):
return VersionedTargetSet.from_versioned_targets(invalidation_check.all_vts)

def check_artifact_cache_for(self, invalidation_check):
# We generate artifacts, namely coverage reports, that cover the full target set.
return [self._vts_for_partition(invalidation_check)]

@staticmethod
def _collect_files(directory):
def collect_files(self, output_dir, coverage):
def files_iter():
for dir_path, _, file_names in os.walk(directory):
for dir_path, _, file_names in os.walk(output_dir):
for filename in file_names:
yield os.path.join(dir_path, filename)
return list(files_iter())

def _iter_partitions(self, targets, output_dir):
if self._per_target:
for target in targets:
yield (target,), os.path.join(output_dir, target.id)
else:
if targets:
yield tuple(targets), output_dir

def _execute(self, all_targets):
with self._isolation(all_targets) as (output_dir, reports, coverage):
results = {}
failure = False
for (partition, partition_output_dir) in self._iter_partitions(self._get_test_targets(),
output_dir):
try:
rv = self._run_partition(test_targets=partition,
output_dir=partition_output_dir,
coverage=coverage)
except ErrorWhileTesting as e:
rv = TestResult.from_error(e)

results[partition] = rv
if not rv.success:
failure = True
if self._fail_fast:
break

for partition in sorted(results):
rv = results[partition]
if len(partition) == 1 or rv.success:
log = self.context.log.info if rv.success else self.context.log.error
for target in partition:
log('{0:80}.....{1:>10}'.format(target.address.reference(), rv))
else:
# There is not much useful we can display in summary for a multi-target partition with
# failures without parsing those failures to link them to individual targets; ie: targets
# 2 and 8 failed in this partition of 10 targets.
# TODO(John Sirois): Punting here works since we have in practice just 2 partitionings:
# 1. All targets in singleton partitions
# 2. All targets in 1 partition
# If we get to the point where we have multiple partitions with multiple targets, some
# sort of summary for the multi-target partitions will probably be needed.
pass

msgs = [str(_rv) for _rv in results.values() if not _rv.success]
failed_targets = [target
for _rv in results.values() if not _rv.success
for target in _rv.failed_targets]
if len(failed_targets) > 0:
error = ErrorWhileTesting('\n'.join(msgs), failed_targets=failed_targets)
elif failure:
# A low-level test execution failure occurred before tests were run.
error = TaskError()
else:
error = None

reports.generate(output_dir, exc=error)
if error:
raise error

def _run_partition(self, test_targets, output_dir, coverage):
with self.invalidated(targets=test_targets,
# Re-run tests when the code they test (and depend on) changes.
invalidate_dependents=True) as invalidation_check:

invalid_test_tgts = [invalid_test_tgt
for vts in invalidation_check.invalid_vts
for invalid_test_tgt in vts.targets]

# Processing proceeds through:
# 1.) output -> output_dir
# 2.) [iff all == invalid] output_dir -> cache: We do this manually for now.
# 3.) [iff invalid == 0 and all > 0] cache -> workdir: Done transparently by `invalidated`.

# 1.) Write all results that will be potentially cached to output_dir.
result = self._run_junit(invalid_test_tgts, output_dir, coverage).checked()

cache_vts = self._vts_for_partition(invalidation_check)
if invalidation_check.all_vts == invalidation_check.invalid_vts:
# 2.) All tests in the partition were invalid, cache successful test results.
if result.success and self.artifact_cache_writes_enabled():
self.update_artifact_cache([(cache_vts, self._collect_files(output_dir))])
elif not invalidation_check.invalid_vts:
# 3.) The full partition was valid, our results will have been staged for/by caching
# if not already local.
pass
@contextmanager
def partitions(self, per_target, all_targets, test_targets):
with self._isolation(per_target, all_targets) as (output_dir, reports, coverage):
if per_target:
def iter_partitions():
for test_target in test_targets:
partition = (test_target,)
args = (os.path.join(output_dir, test_target.id), coverage)
yield partition, args
else:
# The partition was partially invalid.

# We don't cache results; so others will need to re-run this partition.
# NB: We will presumably commit this change now though and so others will get this
# partition in a state that executes successfully; so when the 1st of the others
# executes against this partition; they will hit `all_vts == invalid_vts` and
# cache the results. That 1st of others is hopefully CI!
cache_vts.force_invalidate()

return result
def iter_partitions():
if test_targets:
partition = tuple(test_targets)
args = (output_dir, coverage)
yield partition, args

try:
yield iter_partitions
finally:
_, error, _ = sys.exc_info()
reports.generate(output_dir, exc=error)

class Reports(object):
def __init__(self, junit_html_report, coverage):
Expand All @@ -707,9 +607,9 @@ def _maybe_open_report(self, report_file_path):
raise TaskError(e)

@contextmanager
def _isolation(self, all_targets):
def _isolation(self, per_target, all_targets):
run_dir = '_runs'
mode_dir = 'isolated' if self._per_target else 'combined'
mode_dir = 'isolated' if per_target else 'combined'
batch_dir = str(self._batch_size) if self._batched else 'all'
output_dir = os.path.join(self.workdir,
run_dir,
Expand Down
Loading

0 comments on commit a6a98fe

Please sign in to comment.