From f65bc3d8e3445afc996db3fea8c6d2577aefec3b Mon Sep 17 00:00:00 2001 From: Benjy Weinberger Date: Fri, 17 Mar 2017 13:30:04 +0200 Subject: [PATCH] Refactor the new SelectInterpreter and GatherSources tasks. (#4337) Basically just breaks up the big execute() methods into several helpers. This is for two reasons: - Makes it easier to test invalidation (see below). - Will enable selecting interpreters and gathering sources separately for each target and its deps (for use by the PythonEval task) side-by-side with the current functionality of selecting a single interpreter and creating a single source PEX for the entire closure. This will be implemented in a future change. Note that testing invalidation isn't an idle idea - I noticed a bug in the custom fingerprint strategy that was mixing in the target payload for no good reason. This change fixes that bug and adds a test for it. --- src/python/pants/backend/python/tasks2/BUILD | 1 + .../backend/python/tasks2/gather_sources.py | 54 ++++++++------ .../python/tasks2/select_interpreter.py | 70 +++++++++++-------- .../python/tasks2/test_select_interpreter.py | 46 ++++++++---- 4 files changed, 107 insertions(+), 64 deletions(-) diff --git a/src/python/pants/backend/python/tasks2/BUILD b/src/python/pants/backend/python/tasks2/BUILD index 6715217d106..1216455dab7 100644 --- a/src/python/pants/backend/python/tasks2/BUILD +++ b/src/python/pants/backend/python/tasks2/BUILD @@ -19,5 +19,6 @@ python_library( 'src/python/pants/util:contextutil', 'src/python/pants/util:dirutil', 'src/python/pants/util:meta', + 'src/python/pants/util:memo', ] ) diff --git a/src/python/pants/backend/python/tasks2/gather_sources.py b/src/python/pants/backend/python/tasks2/gather_sources.py index 6280f5f2561..d4576aab8d4 100644 --- a/src/python/pants/backend/python/tasks2/gather_sources.py +++ b/src/python/pants/backend/python/tasks2/gather_sources.py @@ -42,32 +42,40 @@ def prepare(cls, options, round_manager): def execute(self): targets = self.context.targets(predicate=has_python_sources) + interpreter = self.context.products.get_data(PythonInterpreter) + with self.invalidated(targets) as invalidation_check: + pex = self._build_pex_for_versioned_targets(interpreter, invalidation_check.all_vts) + self.context.products.get_data(self.PYTHON_SOURCES, lambda: pex) + + def _build_pex_for_versioned_targets(self, interpreter, versioned_targets): + if versioned_targets: + target_set_id = VersionedTargetSet.from_versioned_targets(versioned_targets).cache_key.hash + else: # If there are no relevant targets, we still go through the motions of gathering # an empty set of sources, to prevent downstream tasks from having to check # for this special case. - if invalidation_check.all_vts: - target_set_id = VersionedTargetSet.from_versioned_targets( - invalidation_check.all_vts).cache_key.hash - else: - target_set_id = 'no_targets' - - interpreter = self.context.products.get_data(PythonInterpreter) - path = os.path.join(self.workdir, target_set_id) - - # Note that we check for the existence of the directory, instead of for invalid_vts, - # to cover the empty case. - if not os.path.isdir(path): - path_tmp = path + '.tmp' - shutil.rmtree(path_tmp, ignore_errors=True) - self._build_pex(interpreter, path_tmp, invalidation_check.all_vts) - shutil.move(path_tmp, path) - - pex = PEX(os.path.realpath(path), interpreter=interpreter) - self.context.products.get_data(self.PYTHON_SOURCES, lambda: pex) - - def _build_pex(self, interpreter, path, vts): + target_set_id = 'no_targets' + source_pex_path = self._source_pex_path(target_set_id) + # Note that we check for the existence of the directory, instead of for invalid_vts, + # to cover the empty case. + if not os.path.isdir(source_pex_path): + # Note that we use the same interpreter for all targets: We know the interpreter + # is compatible (since it's compatible with all targets in play). + self._safe_build_pex(interpreter, source_pex_path, [vt.target for vt in versioned_targets]) + return PEX(source_pex_path, interpreter=interpreter) + + def _safe_build_pex(self, interpreter, path, targets): + path_tmp = path + '.tmp' + shutil.rmtree(path_tmp, ignore_errors=True) + self._build_pex(interpreter, path_tmp, targets) + shutil.move(path_tmp, path) + + def _build_pex(self, interpreter, path, targets): builder = PEXBuilder(path=path, interpreter=interpreter, copy=True) - for vt in vts: - dump_sources(builder, vt.target, self.context.log) + for target in targets: + dump_sources(builder, target, self.context.log) builder.freeze() + + def _source_pex_path(self, target_set_id): + return os.path.realpath(os.path.join(self.workdir, target_set_id)) diff --git a/src/python/pants/backend/python/tasks2/select_interpreter.py b/src/python/pants/backend/python/tasks2/select_interpreter.py index 72e3245eb2f..079f24027f2 100644 --- a/src/python/pants/backend/python/tasks2/select_interpreter.py +++ b/src/python/pants/backend/python/tasks2/select_interpreter.py @@ -18,6 +18,7 @@ from pants.python.python_repos import PythonRepos from pants.task.task import Task from pants.util.dirutil import safe_mkdir_for +from pants.util.memo import memoized_method class PythonInterpreterFingerprintStrategy(DefaultFingerprintHashingMixin, FingerprintStrategy): @@ -31,7 +32,6 @@ def compute_fingerprint(self, python_target): if not hash_elements_for_target: return None hasher = hashlib.sha1() - hasher.update(python_target.payload.fingerprint()) for element in hash_elements_for_target: hasher.update(element) return hasher.hexdigest() @@ -60,36 +60,48 @@ def execute(self): invalidation_check.all_vts).cache_key.hash else: target_set_id = 'no_targets' - interpreter_path_file = os.path.join(self.workdir, target_set_id, 'interpreter.path') + interpreter_path_file = self._interpreter_path_file(target_set_id) if not os.path.exists(interpreter_path_file): - interpreter_cache = PythonInterpreterCache(PythonSetup.global_instance(), - PythonRepos.global_instance(), - logger=self.context.log.debug) - - # Cache setup's requirement fetching can hang if run concurrently by another pants proc. - self.context.acquire_lock() - try: - interpreter_cache.setup() - finally: - self.context.release_lock() - - interpreter = interpreter_cache.select_interpreter_for_targets(python_tgts) - safe_mkdir_for(interpreter_path_file) - with open(interpreter_path_file, 'w') as outfile: - outfile.write(b'{}\t{}\n'.format(interpreter.binary, str(interpreter.identity))) - for dist, location in interpreter.extras.items(): - dist_name, dist_version = dist - outfile.write(b'{}\t{}\t{}\n'.format(dist_name, dist_version, location)) + self._create_interpreter_path_file(interpreter_path_file, python_tgts) if not interpreter: - with open(interpreter_path_file, 'r') as infile: - lines = infile.readlines() - binary, identity = lines[0].strip().split('\t') - extras = {} - for line in lines[1:]: - dist_name, dist_version, location = line.strip().split('\t') - extras[(dist_name, dist_version)] = location - - interpreter = PythonInterpreter(binary, PythonIdentity.from_path(identity), extras) + interpreter = self._get_interpreter(interpreter_path_file) self.context.products.get_data(PythonInterpreter, lambda: interpreter) + + @memoized_method + def _interpreter_cache(self): + interpreter_cache = PythonInterpreterCache(PythonSetup.global_instance(), + PythonRepos.global_instance(), + logger=self.context.log.debug) + # Cache setup's requirement fetching can hang if run concurrently by another pants proc. + self.context.acquire_lock() + try: + interpreter_cache.setup() + finally: + self.context.release_lock() + return interpreter_cache + + def _create_interpreter_path_file(self, interpreter_path_file, targets): + interpreter_cache = self._interpreter_cache() + interpreter = interpreter_cache.select_interpreter_for_targets(targets) + safe_mkdir_for(interpreter_path_file) + with open(interpreter_path_file, 'w') as outfile: + outfile.write(b'{}\t{}\n'.format(interpreter.binary, str(interpreter.identity))) + for dist, location in interpreter.extras.items(): + dist_name, dist_version = dist + outfile.write(b'{}\t{}\t{}\n'.format(dist_name, dist_version, location)) + + def _interpreter_path_file(self, target_set_id): + return os.path.join(self.workdir, target_set_id, 'interpreter.path') + + @staticmethod + def _get_interpreter(interpreter_path_file): + with open(interpreter_path_file, 'r') as infile: + lines = infile.readlines() + binary, identity = lines[0].strip().split('\t') + extras = {} + for line in lines[1:]: + dist_name, dist_version, location = line.strip().split('\t') + extras[(dist_name, dist_version)] = location + return PythonInterpreter(binary, PythonIdentity.from_path(identity), extras) diff --git a/tests/python/pants_test/backend/python/tasks2/test_select_interpreter.py b/tests/python/pants_test/backend/python/tasks2/test_select_interpreter.py index 93580a37ae0..5ddef1e18d6 100644 --- a/tests/python/pants_test/backend/python/tasks2/test_select_interpreter.py +++ b/tests/python/pants_test/backend/python/tasks2/test_select_interpreter.py @@ -30,28 +30,30 @@ def setUp(self): def fake_interpreter(id_str): return PythonInterpreter('/fake/binary', PythonIdentity.from_id_string(id_str)) - def fake_target(spec, compatibility=None, dependencies=None): - return self.make_target(spec=spec, target_type=PythonLibrary, sources=[], - dependencies=dependencies, compatibility=compatibility) - self.fake_interpreters = [ fake_interpreter('FakePython 2 77 777'), fake_interpreter('FakePython 2 88 888'), fake_interpreter('FakePython 2 99 999') ] - self.tgt1 = fake_target('tgt1') - self.tgt2 = fake_target('tgt2', compatibility=['FakePython>2.77.777']) - self.tgt3 = fake_target('tgt3', compatibility=['FakePython>2.88.888']) - self.tgt4 = fake_target('tgt4', compatibility=['FakePython<2.99.999']) - self.tgt20 = fake_target('tgt20', dependencies=[self.tgt2]) - self.tgt30 = fake_target('tgt30', dependencies=[self.tgt3]) - self.tgt40 = fake_target('tgt40', dependencies=[self.tgt4]) + self.tgt1 = self._fake_target('tgt1') + self.tgt2 = self._fake_target('tgt2', compatibility=['FakePython>2.77.777']) + self.tgt3 = self._fake_target('tgt3', compatibility=['FakePython>2.88.888']) + self.tgt4 = self._fake_target('tgt4', compatibility=['FakePython<2.99.999']) + self.tgt20 = self._fake_target('tgt20', dependencies=[self.tgt2]) + self.tgt30 = self._fake_target('tgt30', dependencies=[self.tgt3]) + self.tgt40 = self._fake_target('tgt40', dependencies=[self.tgt4]) + + def _fake_target(self, spec, compatibility=None, sources=None, dependencies=None): + return self.make_target(spec=spec, target_type=PythonLibrary, sources=sources or [], + dependencies=dependencies, compatibility=compatibility) - def _select_interpreter(self, target_roots): + def _select_interpreter(self, target_roots, should_invalidate=None): """Return the version string of the interpreter selected for the target roots.""" context = self.context(target_roots=target_roots) task = self.create_task(context) + if should_invalidate is not None: + task._create_interpreter_path_file = mock.MagicMock(wraps=task._create_interpreter_path_file) # Mock out the interpreter cache setup, so we don't actually look for real interpreters # on the filesystem. @@ -62,6 +64,12 @@ def se(me, *args, **kwargs): mock_resolve.side_effect = se task.execute() + if should_invalidate is not None: + if should_invalidate: + task._create_interpreter_path_file.assert_called_once() + else: + task._create_interpreter_path_file.assert_not_called() + interpreter = context.products.get_data(PythonInterpreter) self.assertTrue(isinstance(interpreter, PythonInterpreter)) return interpreter.version_string @@ -75,9 +83,23 @@ def test_interpreter_selection(self): self.assertEquals('FakePython-2.88.888', self._select_interpreter([self.tgt20])) self.assertEquals('FakePython-2.99.999', self._select_interpreter([self.tgt30])) self.assertEquals('FakePython-2.77.777', self._select_interpreter([self.tgt40])) + self.assertEquals('FakePython-2.99.999', self._select_interpreter([self.tgt2, self.tgt3])) self.assertEquals('FakePython-2.88.888', self._select_interpreter([self.tgt2, self.tgt4])) with self.assertRaises(TaskError) as cm: self._select_interpreter([self.tgt3, self.tgt4]) self.assertIn('Unable to detect a suitable interpreter for compatibilities: ' 'FakePython<2.99.999 && FakePython>2.88.888', str(cm.exception)) + + def test_interpreter_selection_invalidation(self): + tgta = self._fake_target('tgta', compatibility=['FakePython>2.77.777'], + dependencies=[self.tgt3]) + self.assertEquals('FakePython-2.99.999', + self._select_interpreter([tgta], should_invalidate=True)) + + # A new target with different sources, but identical compatibility, shouldn't invalidate. + self.create_file('tgtb/foo/bar/baz.py', 'fake content') + tgtb = self._fake_target('tgtb', compatibility=['FakePython>2.77.777'], + dependencies=[self.tgt3], sources=['foo/bar/baz.py']) + self.assertEquals('FakePython-2.99.999', + self._select_interpreter([tgtb], should_invalidate=False))