diff --git a/src/python/pants/backend/jvm/targets/java_tests.py b/src/python/pants/backend/jvm/targets/java_tests.py index 7d08cd42e02..d833ab9830a 100644 --- a/src/python/pants/backend/jvm/targets/java_tests.py +++ b/src/python/pants/backend/jvm/targets/java_tests.py @@ -15,7 +15,7 @@ class JavaTests(JvmTarget): """JUnit tests.""" def __init__(self, cwd=None, test_platform=None, payload=None, timeout=None, - extra_jvm_options=None, **kwargs): + extra_jvm_options=None, extra_env_vars=None, **kwargs): """ :param str cwd: working directory (relative to the build root) for the tests under this target. If unspecified (None), the working directory will be controlled by junit_run's --cwd. @@ -24,12 +24,22 @@ def __init__(self, cwd=None, test_platform=None, payload=None, timeout=None, unspecified, the platform will default to the same one used for compilation. :param list extra_jvm_options: A list of key value pairs of jvm options to use when running the tests. Example: ['-Dexample.property=1'] If unspecified, no extra jvm options will be added. + :param dict extra_env_vars: A map of environment variables to set when running the tests, e.g. + { 'FOOBAR': 12 }. Using `None` as the value will cause the variable to be unset. """ self.cwd = cwd payload = payload or Payload() + + if extra_env_vars is None: + extra_env_vars = {} + for key, value in extra_env_vars.items(): + if value is not None: + extra_env_vars[key] = str(value) + payload.add_fields({ 'test_platform': PrimitiveField(test_platform), - 'extra_jvm_options': PrimitiveField(tuple(extra_jvm_options or ())) + 'extra_jvm_options': PrimitiveField(tuple(extra_jvm_options or ())), + 'extra_env_vars': PrimitiveField(tuple(extra_env_vars.items())), }) self._timeout = timeout super(JavaTests, self).__init__(payload=payload, **kwargs) diff --git a/src/python/pants/backend/jvm/tasks/junit_run.py b/src/python/pants/backend/jvm/tasks/junit_run.py index a584f0f2598..ac1e133c957 100644 --- a/src/python/pants/backend/jvm/tasks/junit_run.py +++ b/src/python/pants/backend/jvm/tasks/junit_run.py @@ -30,6 +30,7 @@ from pants.java.distribution.distribution import DistributionLocator from pants.java.executor import SubprocessExecutor from pants.task.testrunner_task_mixin import TestRunnerTaskMixin +from pants.util.contextutil import environment_as from pants.util.strutil import pluralize from pants.util.xml_parser import XmlParser @@ -288,36 +289,39 @@ def _run_tests(self, tests_to_targets): classpath_prepend = () classpath_append = () - tests_by_properties = self._tests_by_properties(tests_to_targets, - self._infer_workdir, - lambda target: target.test_platform) + tests_by_properties = self._tests_by_properties( + tests_to_targets, + self._infer_workdir, + lambda target: target.test_platform, + lambda target: target.payload.extra_jvm_options, + lambda target: target.payload.extra_env_vars, + ) # the below will be None if not set, and we'll default back to runtime_classpath classpath_product = self.context.products.get_data('instrument_classpath') result = 0 - for (workdir, platform), tests in tests_by_properties.items(): - for (target_jvm_options, target_tests) in self._partition_by_jvm_options(tests_to_targets, - tests): - for batch in self._partition(target_tests): - # Batches of test classes will likely exist within the same targets: dedupe them. - relevant_targets = set(map(tests_to_targets.get, batch)) - complete_classpath = OrderedSet() - complete_classpath.update(classpath_prepend) - complete_classpath.update(self.tool_classpath('junit')) - complete_classpath.update(self.classpath(relevant_targets, - classpath_product=classpath_product)) - complete_classpath.update(classpath_append) - distribution = self.preferred_jvm_distribution([platform]) - with binary_util.safe_args(batch, self.get_options()) as batch_tests: - self.context.log.debug('CWD = {}'.format(workdir)) - self.context.log.debug('platform = {}'.format(platform)) + for (workdir, platform, target_jvm_options, target_env_vars), tests in tests_by_properties.items(): + for batch in self._partition(tests): + # Batches of test classes will likely exist within the same targets: dedupe them. + relevant_targets = set(map(tests_to_targets.get, batch)) + complete_classpath = OrderedSet() + complete_classpath.update(classpath_prepend) + complete_classpath.update(self.tool_classpath('junit')) + complete_classpath.update(self.classpath(relevant_targets, + classpath_product=classpath_product)) + complete_classpath.update(classpath_append) + distribution = self.preferred_jvm_distribution([platform]) + with binary_util.safe_args(batch, self.get_options()) as batch_tests: + self.context.log.debug('CWD = {}'.format(workdir)) + self.context.log.debug('platform = {}'.format(platform)) + with environment_as(**dict(target_env_vars)): self._executor = SubprocessExecutor(distribution) result += abs(distribution.execute_java( executor=self._executor, classpath=complete_classpath, main=JUnitRun._MAIN, - jvm_options=self.jvm_options + extra_jvm_options + target_jvm_options, + jvm_options=self.jvm_options + extra_jvm_options + list(target_jvm_options), args=self._args + batch_tests + [u'-xmlreport'], workunit_factory=self.context.new_workunit, workunit_name='run', @@ -326,8 +330,8 @@ def _run_tests(self, tests_to_targets): synthetic_jar_dir=self.workdir, )) - if result != 0 and self._fail_fast: - break + if result != 0 and self._fail_fast: + break if result != 0: failed_targets_and_tests = self._get_failed_targets(tests_to_targets) @@ -362,21 +366,6 @@ def combined_property(target): return self._tests_by_property(tests_to_targets, combined_property) - def _partition_by_jvm_options(self, tests_to_targets, tests): - """Partitions a list of tests by the jvm options to run them with. - - :param dict tests_to_targets: A mapping from each test to its target. - :param list tests: The list of tests to run. - :returns: A list of tuples where the first element is an array of jvm options and the second - is a list of tests to run with the jvm options. Each test in tests will appear in exactly - one one tuple. - """ - jvm_options_to_tests = defaultdict(list) - for test in tests: - extra_jvm_options = tests_to_targets[test].payload.extra_jvm_options - jvm_options_to_tests[extra_jvm_options].append(test) - return [(list(jvm_options), tests) for jvm_options, tests in jvm_options_to_tests.items()] - def _partition(self, tests): stride = min(self._batch_size, len(tests)) for i in range(0, len(tests), stride): diff --git a/tests/python/pants_test/backend/jvm/tasks/BUILD b/tests/python/pants_test/backend/jvm/tasks/BUILD index b1c0d7a2ca7..252805b36e8 100644 --- a/tests/python/pants_test/backend/jvm/tasks/BUILD +++ b/tests/python/pants_test/backend/jvm/tasks/BUILD @@ -304,6 +304,7 @@ python_tests( 'src/python/pants/ivy', 'src/python/pants/java/distribution:distribution', 'src/python/pants/java:executor', + 'src/python/pants/util:contextutil', 'src/python/pants/util:dirutil', 'src/python/pants/util:timeout', 'tests/python/pants_test/jvm:jvm_tool_task_test_base', diff --git a/tests/python/pants_test/backend/jvm/tasks/test_junit_run.py b/tests/python/pants_test/backend/jvm/tasks/test_junit_run.py index 00d62184696..d5e952c5b5d 100644 --- a/tests/python/pants_test/backend/jvm/tasks/test_junit_run.py +++ b/tests/python/pants_test/backend/jvm/tasks/test_junit_run.py @@ -21,6 +21,7 @@ from pants.ivy.ivy_subsystem import IvySubsystem from pants.java.distribution.distribution import DistributionLocator from pants.java.executor import SubprocessExecutor +from pants.util.contextutil import environment_as from pants.util.dirutil import safe_file_dump from pants.util.timeout import TimeoutReached from pants_test.jvm.jvm_tool_task_test_base import JvmToolTaskTestBase @@ -141,7 +142,7 @@ def test_junit_runner_timeout_fail(self): args, kwargs = mock_timeout.call_args self.assertEqual(args, (1,)) - def execute_junit_runner(self, content, **kwargs): + def execute_junit_runner(self, content, create_some_resources=True, **kwargs): # Create the temporary base test directory test_rel_path = 'tests/java/org/pantsbuild/foo' test_abs_path = self.create_dir(test_rel_path) @@ -178,15 +179,18 @@ def execute_junit_runner(self, content, **kwargs): else: target = self.create_library(test_rel_path, 'java_tests', 'foo_test', ['FooTest.java']) - # Create a synthetic resource target. - resources = self.make_target('some_resources', Resources) + target_roots = [] + if create_some_resources: + # Create a synthetic resource target. + target_roots.append(self.make_target('some_resources', Resources)) + target_roots.append(target) # Set the context with the two targets, one java_tests target and # one synthetic resources target. # The synthetic resources target is to make sure we won't regress # in the future with bug like https://github.com/pantsbuild/pants/issues/508. Note # in that bug, the resources target must be the first one in the list. - context = self.context(target_roots=[resources, target]) + context = self.context(target_roots=target_roots) # Before we run the task, we need to inject the "runtime_classpath" with # the compiled test java classes that JUnitRun will know which test @@ -296,3 +300,79 @@ def test_junit_runner_multiple_extra_jvm_options(self): """), target_name='foo:foo_test' ) + + def test_junit_runner_extra_env_vars(self): + self.make_target( + spec='foo:foo_test', + target_type=JavaTests, + sources=['FooTest.java'], + extra_env_vars={ + 'HELLO': 27, + 'THERE': 32, + }, + ) + + self.make_target( + spec='bar:bar_test', + target_type=JavaTests, + sources=['FooTest.java'], + extra_env_vars={ + 'THE_ANSWER': 42, + 'HELLO': 12, + }, + ) + + self.execute_junit_runner(dedent(""" + import org.junit.Test; + import static org.junit.Assert.assertEquals; + public class FooTest { + @Test + public void testFoo() { + assertEquals("27", System.getenv().get("HELLO")); + assertEquals("32", System.getenv().get("THERE")); + } + } + """), target_name='foo:foo_test') + + # Execute twice in a row to make sure the environment changes aren't sticky. + self.execute_junit_runner(dedent(""" + import org.junit.Test; + import static org.junit.Assert.assertEquals; + import static org.junit.Assert.assertFalse; + public class FooTest { + @Test + public void testFoo() { + assertEquals("12", System.getenv().get("HELLO")); + assertEquals("42", System.getenv().get("THE_ANSWER")); + assertFalse(System.getenv().containsKey("THERE")); + } + } + """), target_name='bar:bar_test', create_some_resources=False) + + def test_junit_runner_extra_env_vars_none(self): + with environment_as(THIS_VARIABLE="12", THAT_VARIABLE="This is a variable."): + self.make_target( + spec='foo:foo_test', + target_type=JavaTests, + sources=['FooTest.java'], + extra_env_vars={ + 'HELLO': None, + 'THERE': False, + 'THIS_VARIABLE': None + }, + ) + + self.execute_junit_runner(dedent(""" + import org.junit.Test; + import static org.junit.Assert.assertEquals; + import static org.junit.Assert.assertFalse; + public class FooTest { + @Test + public void testFoo() { + assertEquals("False", System.getenv().get("THERE")); + assertEquals("This is a variable.", System.getenv().get("THAT_VARIABLE")); + assertFalse(System.getenv().containsKey("HELLO")); + assertFalse(System.getenv().containsKey("THIS_VARIABLE")); + } + } + """), target_name='foo:foo_test')