Skip to content

Commit

Permalink
Refactor the new SelectInterpreter and GatherSources tasks. (pantsbui…
Browse files Browse the repository at this point in the history
…ld#4337)

Basically just breaks up the big execute() methods
into several helpers.

This is for two reasons:

- Makes it easier to test invalidation (see below).
- Will enable selecting interpreters and gathering sources
  separately for each target and its deps (for use by the
  PythonEval task) side-by-side with the current functionality
  of selecting a single interpreter and creating a single
  source PEX for the entire closure. This will be implemented
  in a future change.

Note that testing invalidation isn't an idle idea - I noticed
a bug in the custom fingerprint strategy that was mixing in
the target payload for no good reason.  This change fixes
that bug and adds a test for it.
  • Loading branch information
benjyw authored Mar 17, 2017
1 parent 98e0bac commit f65bc3d
Showing 4 changed files with 107 additions and 64 deletions.
1 change: 1 addition & 0 deletions src/python/pants/backend/python/tasks2/BUILD
Original file line number Diff line number Diff line change
@@ -19,5 +19,6 @@ python_library(
'src/python/pants/util:contextutil',
'src/python/pants/util:dirutil',
'src/python/pants/util:meta',
'src/python/pants/util:memo',
]
)
54 changes: 31 additions & 23 deletions src/python/pants/backend/python/tasks2/gather_sources.py
Original file line number Diff line number Diff line change
@@ -42,32 +42,40 @@ def prepare(cls, options, round_manager):

def execute(self):
targets = self.context.targets(predicate=has_python_sources)
interpreter = self.context.products.get_data(PythonInterpreter)

with self.invalidated(targets) as invalidation_check:
pex = self._build_pex_for_versioned_targets(interpreter, invalidation_check.all_vts)
self.context.products.get_data(self.PYTHON_SOURCES, lambda: pex)

def _build_pex_for_versioned_targets(self, interpreter, versioned_targets):
if versioned_targets:
target_set_id = VersionedTargetSet.from_versioned_targets(versioned_targets).cache_key.hash
else:
# If there are no relevant targets, we still go through the motions of gathering
# an empty set of sources, to prevent downstream tasks from having to check
# for this special case.
if invalidation_check.all_vts:
target_set_id = VersionedTargetSet.from_versioned_targets(
invalidation_check.all_vts).cache_key.hash
else:
target_set_id = 'no_targets'

interpreter = self.context.products.get_data(PythonInterpreter)
path = os.path.join(self.workdir, target_set_id)

# Note that we check for the existence of the directory, instead of for invalid_vts,
# to cover the empty case.
if not os.path.isdir(path):
path_tmp = path + '.tmp'
shutil.rmtree(path_tmp, ignore_errors=True)
self._build_pex(interpreter, path_tmp, invalidation_check.all_vts)
shutil.move(path_tmp, path)

pex = PEX(os.path.realpath(path), interpreter=interpreter)
self.context.products.get_data(self.PYTHON_SOURCES, lambda: pex)

def _build_pex(self, interpreter, path, vts):
target_set_id = 'no_targets'
source_pex_path = self._source_pex_path(target_set_id)
# Note that we check for the existence of the directory, instead of for invalid_vts,
# to cover the empty case.
if not os.path.isdir(source_pex_path):
# Note that we use the same interpreter for all targets: We know the interpreter
# is compatible (since it's compatible with all targets in play).
self._safe_build_pex(interpreter, source_pex_path, [vt.target for vt in versioned_targets])
return PEX(source_pex_path, interpreter=interpreter)

def _safe_build_pex(self, interpreter, path, targets):
path_tmp = path + '.tmp'
shutil.rmtree(path_tmp, ignore_errors=True)
self._build_pex(interpreter, path_tmp, targets)
shutil.move(path_tmp, path)

def _build_pex(self, interpreter, path, targets):
builder = PEXBuilder(path=path, interpreter=interpreter, copy=True)
for vt in vts:
dump_sources(builder, vt.target, self.context.log)
for target in targets:
dump_sources(builder, target, self.context.log)
builder.freeze()

def _source_pex_path(self, target_set_id):
return os.path.realpath(os.path.join(self.workdir, target_set_id))
70 changes: 41 additions & 29 deletions src/python/pants/backend/python/tasks2/select_interpreter.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
from pants.python.python_repos import PythonRepos
from pants.task.task import Task
from pants.util.dirutil import safe_mkdir_for
from pants.util.memo import memoized_method


class PythonInterpreterFingerprintStrategy(DefaultFingerprintHashingMixin, FingerprintStrategy):
@@ -31,7 +32,6 @@ def compute_fingerprint(self, python_target):
if not hash_elements_for_target:
return None
hasher = hashlib.sha1()
hasher.update(python_target.payload.fingerprint())
for element in hash_elements_for_target:
hasher.update(element)
return hasher.hexdigest()
@@ -60,36 +60,48 @@ def execute(self):
invalidation_check.all_vts).cache_key.hash
else:
target_set_id = 'no_targets'
interpreter_path_file = os.path.join(self.workdir, target_set_id, 'interpreter.path')
interpreter_path_file = self._interpreter_path_file(target_set_id)
if not os.path.exists(interpreter_path_file):
interpreter_cache = PythonInterpreterCache(PythonSetup.global_instance(),
PythonRepos.global_instance(),
logger=self.context.log.debug)

# Cache setup's requirement fetching can hang if run concurrently by another pants proc.
self.context.acquire_lock()
try:
interpreter_cache.setup()
finally:
self.context.release_lock()

interpreter = interpreter_cache.select_interpreter_for_targets(python_tgts)
safe_mkdir_for(interpreter_path_file)
with open(interpreter_path_file, 'w') as outfile:
outfile.write(b'{}\t{}\n'.format(interpreter.binary, str(interpreter.identity)))
for dist, location in interpreter.extras.items():
dist_name, dist_version = dist
outfile.write(b'{}\t{}\t{}\n'.format(dist_name, dist_version, location))
self._create_interpreter_path_file(interpreter_path_file, python_tgts)

if not interpreter:
with open(interpreter_path_file, 'r') as infile:
lines = infile.readlines()
binary, identity = lines[0].strip().split('\t')
extras = {}
for line in lines[1:]:
dist_name, dist_version, location = line.strip().split('\t')
extras[(dist_name, dist_version)] = location

interpreter = PythonInterpreter(binary, PythonIdentity.from_path(identity), extras)
interpreter = self._get_interpreter(interpreter_path_file)

self.context.products.get_data(PythonInterpreter, lambda: interpreter)

@memoized_method
def _interpreter_cache(self):
interpreter_cache = PythonInterpreterCache(PythonSetup.global_instance(),
PythonRepos.global_instance(),
logger=self.context.log.debug)
# Cache setup's requirement fetching can hang if run concurrently by another pants proc.
self.context.acquire_lock()
try:
interpreter_cache.setup()
finally:
self.context.release_lock()
return interpreter_cache

def _create_interpreter_path_file(self, interpreter_path_file, targets):
interpreter_cache = self._interpreter_cache()
interpreter = interpreter_cache.select_interpreter_for_targets(targets)
safe_mkdir_for(interpreter_path_file)
with open(interpreter_path_file, 'w') as outfile:
outfile.write(b'{}\t{}\n'.format(interpreter.binary, str(interpreter.identity)))
for dist, location in interpreter.extras.items():
dist_name, dist_version = dist
outfile.write(b'{}\t{}\t{}\n'.format(dist_name, dist_version, location))

def _interpreter_path_file(self, target_set_id):
return os.path.join(self.workdir, target_set_id, 'interpreter.path')

@staticmethod
def _get_interpreter(interpreter_path_file):
with open(interpreter_path_file, 'r') as infile:
lines = infile.readlines()
binary, identity = lines[0].strip().split('\t')
extras = {}
for line in lines[1:]:
dist_name, dist_version, location = line.strip().split('\t')
extras[(dist_name, dist_version)] = location
return PythonInterpreter(binary, PythonIdentity.from_path(identity), extras)
Original file line number Diff line number Diff line change
@@ -30,28 +30,30 @@ def setUp(self):
def fake_interpreter(id_str):
return PythonInterpreter('/fake/binary', PythonIdentity.from_id_string(id_str))

def fake_target(spec, compatibility=None, dependencies=None):
return self.make_target(spec=spec, target_type=PythonLibrary, sources=[],
dependencies=dependencies, compatibility=compatibility)

self.fake_interpreters = [
fake_interpreter('FakePython 2 77 777'),
fake_interpreter('FakePython 2 88 888'),
fake_interpreter('FakePython 2 99 999')
]

self.tgt1 = fake_target('tgt1')
self.tgt2 = fake_target('tgt2', compatibility=['FakePython>2.77.777'])
self.tgt3 = fake_target('tgt3', compatibility=['FakePython>2.88.888'])
self.tgt4 = fake_target('tgt4', compatibility=['FakePython<2.99.999'])
self.tgt20 = fake_target('tgt20', dependencies=[self.tgt2])
self.tgt30 = fake_target('tgt30', dependencies=[self.tgt3])
self.tgt40 = fake_target('tgt40', dependencies=[self.tgt4])
self.tgt1 = self._fake_target('tgt1')
self.tgt2 = self._fake_target('tgt2', compatibility=['FakePython>2.77.777'])
self.tgt3 = self._fake_target('tgt3', compatibility=['FakePython>2.88.888'])
self.tgt4 = self._fake_target('tgt4', compatibility=['FakePython<2.99.999'])
self.tgt20 = self._fake_target('tgt20', dependencies=[self.tgt2])
self.tgt30 = self._fake_target('tgt30', dependencies=[self.tgt3])
self.tgt40 = self._fake_target('tgt40', dependencies=[self.tgt4])

def _fake_target(self, spec, compatibility=None, sources=None, dependencies=None):
return self.make_target(spec=spec, target_type=PythonLibrary, sources=sources or [],
dependencies=dependencies, compatibility=compatibility)

def _select_interpreter(self, target_roots):
def _select_interpreter(self, target_roots, should_invalidate=None):
"""Return the version string of the interpreter selected for the target roots."""
context = self.context(target_roots=target_roots)
task = self.create_task(context)
if should_invalidate is not None:
task._create_interpreter_path_file = mock.MagicMock(wraps=task._create_interpreter_path_file)

# Mock out the interpreter cache setup, so we don't actually look for real interpreters
# on the filesystem.
@@ -62,6 +64,12 @@ def se(me, *args, **kwargs):
mock_resolve.side_effect = se
task.execute()

if should_invalidate is not None:
if should_invalidate:
task._create_interpreter_path_file.assert_called_once()
else:
task._create_interpreter_path_file.assert_not_called()

interpreter = context.products.get_data(PythonInterpreter)
self.assertTrue(isinstance(interpreter, PythonInterpreter))
return interpreter.version_string
@@ -75,9 +83,23 @@ def test_interpreter_selection(self):
self.assertEquals('FakePython-2.88.888', self._select_interpreter([self.tgt20]))
self.assertEquals('FakePython-2.99.999', self._select_interpreter([self.tgt30]))
self.assertEquals('FakePython-2.77.777', self._select_interpreter([self.tgt40]))
self.assertEquals('FakePython-2.99.999', self._select_interpreter([self.tgt2, self.tgt3]))
self.assertEquals('FakePython-2.88.888', self._select_interpreter([self.tgt2, self.tgt4]))

with self.assertRaises(TaskError) as cm:
self._select_interpreter([self.tgt3, self.tgt4])
self.assertIn('Unable to detect a suitable interpreter for compatibilities: '
'FakePython<2.99.999 && FakePython>2.88.888', str(cm.exception))

def test_interpreter_selection_invalidation(self):
tgta = self._fake_target('tgta', compatibility=['FakePython>2.77.777'],
dependencies=[self.tgt3])
self.assertEquals('FakePython-2.99.999',
self._select_interpreter([tgta], should_invalidate=True))

# A new target with different sources, but identical compatibility, shouldn't invalidate.
self.create_file('tgtb/foo/bar/baz.py', 'fake content')
tgtb = self._fake_target('tgtb', compatibility=['FakePython>2.77.777'],
dependencies=[self.tgt3], sources=['foo/bar/baz.py'])
self.assertEquals('FakePython-2.99.999',
self._select_interpreter([tgtb], should_invalidate=False))

0 comments on commit f65bc3d

Please sign in to comment.