From e3cf637cc72ba10f34bd6eb2d11ac6242e819b72 Mon Sep 17 00:00:00 2001
From: Stu Hood
Date: Thu, 29 Aug 2019 09:07:55 -0700
Subject: [PATCH] Split out a double-check-cache job for jvm/rsc compile. (#8221)

### Problem

#8190 moved cache writing out of the completion of the zinc and rsc jobs and into a dependent job. But at the same time, we also had multiple attempts to "double check" the cache happening concurrently, because both the zinc and rsc jobs checked it, and that race could lead to partial cache entries being extracted.

### Solution

Since we can't actually cancel or coordinate the concurrent work, we can't safely double-check the cache once either job has started. So instead, this change extracts the cache double-check into its own job, which both the zinc and rsc jobs depend on.
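For illustration, here is a minimal sketch of the new layering. This is a hypothetical, pared-down `Job` stand-in, not the real execution-graph class (which, as in the diff below, also carries a size estimate and `on_success`/`on_failure` hooks): one `double_check_cache(...)` job per target, with both of that target's compile jobs gating on it.

```python
# A simplified model of the job layering introduced by this change. `Job` here
# is a hypothetical stand-in for the execution-graph job type used by JvmCompile.
class Job:
    def __init__(self, key, fn, dependencies):
        self.key = key                    # unique key in the execution graph
        self.fn = fn                      # the work to run
        self.dependencies = dependencies  # keys of jobs that must finish first


def jobs_for_target(spec, dep_keys, check_cache, rsc_work, zinc_work):
    """Create one shared cache double-check job, plus compile jobs that gate on it."""
    doublecheck_key = 'double_check_cache({})'.format(spec)
    # The double-check job has no real inputs: depending on the target's compile
    # dependencies simply delays it until compilation would otherwise begin.
    doublecheck_job = Job(doublecheck_key, check_cache, list(dep_keys))
    # Both compile jobs depend on the single double-check job, so the cache is
    # probed exactly once per target, never concurrently by rsc and zinc.
    rsc_job = Job('rsc({})'.format(spec), rsc_work, [doublecheck_key] + list(dep_keys))
    zinc_job = Job('zinc({})'.format(spec), zinc_work, [doublecheck_key] + list(dep_keys))
    return [doublecheck_job, rsc_job, zinc_job]


for job in jobs_for_target('scala/classpath:scala_lib', [],
                           check_cache=lambda: None, rsc_work=lambda: None,
                           zinc_work=lambda: None):
    print(job.key, '<-', job.dependencies)
```

The printed keys mirror the `double_check_cache(...)`/`rsc(...)`/`zinc[...](...)` nodes asserted in the updated `test_rsc_compile.py` expectations below.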
---
 .../jvm/tasks/jvm_compile/jvm_compile.py      | 61 +++++++-----
 .../jvm/tasks/jvm_compile/rsc/rsc_compile.py  | 99 ++++++++++---------
 .../tasks/jvm_compile/rsc/test_rsc_compile.py | 49 ++++++++-
 3 files changed, 135 insertions(+), 74 deletions(-)

diff --git a/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py b/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py
index e276e8008d8..86d0257ca9f 100644
--- a/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py
+++ b/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py
@@ -711,26 +711,31 @@ def _upstream_analysis(self, compile_contexts, classpath_entries):
     else:
       yield compile_context.classes_dir.path, compile_context.analysis_file

+  def exec_graph_double_check_cache_key_for_target(self, target):
+    return 'double_check_cache({})'.format(target.address.spec)
+
   def exec_graph_key_for_target(self, compile_target):
     return "compile({})".format(compile_target.address.spec)

   def _create_compile_jobs(self, compile_contexts, invalid_targets, invalid_vts, classpath_product):
     class Counter:
-      def __init__(self, size, initial=0):
+      def __init__(self, size=0):
         self.size = size
-        self.count = initial
+        self.count = 0

       def __call__(self):
         self.count += 1
         return self.count

+      def increment_size(self, by=1):
+        self.size += by
+
       def format_length(self):
         return len(str(self.size))

-    counter = Counter(len(invalid_vts))
     jobs = []
+    counter = Counter()

-    jobs.extend(self.pre_compile_jobs(counter))
     invalid_target_set = set(invalid_targets)
     for ivts in invalid_vts:
       # Invalidated targets are a subset of relevant targets: get the context for this one.
@@ -738,26 +743,32 @@ def format_length(self):
       invalid_dependencies = self._collect_invalid_compile_dependencies(compile_target,
                                                                         invalid_target_set)

-      jobs.extend(
-          self.create_compile_jobs(compile_target, compile_contexts, invalid_dependencies, ivts,
-            counter, classpath_product))
+      new_jobs, new_count = self.create_compile_jobs(
+        compile_target, compile_contexts, invalid_dependencies, ivts, counter, classpath_product)
+      jobs.extend(new_jobs)
+      counter.increment_size(by=new_count)

-    counter.size = len(jobs)
     return jobs

-  def pre_compile_jobs(self, counter):
-    """Override this to provide jobs that are not related to particular targets.
-
-    This is only called when there are invalid targets."""
-    return []
-
   def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_dependencies, ivts,
                           counter, classpath_product):
+    """Return a list of jobs, and a count of those jobs that represent meaningful ("countable") work."""
     context_for_target = all_compile_contexts[compile_target]
     compile_context = self.select_runtime_context(context_for_target)

-    job = Job(self.exec_graph_key_for_target(compile_target),
+    compile_deps = [self.exec_graph_key_for_target(target) for target in invalid_dependencies]
+
+    # The cache-checking job doesn't technically have any dependencies, but we want to delay it
+    # until immediately before we would otherwise try compiling, so we indicate that it depends on
+    # all compile dependencies.
+    double_check_cache_job = Job(self.exec_graph_double_check_cache_key_for_target(compile_target),
+                                 functools.partial(self._default_double_check_cache_for_vts, ivts),
+                                 compile_deps)
+    # The compile job depends on the cache-check job. This decomposition is necessary in order to
+    # support more complex situations where compilation runs multiple jobs in parallel, and wants
+    # to double-check the cache before starting any of them.
+    compile_job = Job(self.exec_graph_key_for_target(compile_target),
               functools.partial(
                 self._default_work_for_vts,
                 ivts,
@@ -766,15 +777,15 @@ def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_depe
                 counter,
                 all_compile_contexts,
                 classpath_product),
-              [self.exec_graph_key_for_target(target) for target in invalid_dependencies],
+              [double_check_cache_job.key] + compile_deps,
               self._size_estimator(compile_context.sources),
               # If compilation and analysis work succeeds, validate the vts.
               # Otherwise, fail it.
               on_success=ivts.update,
               on_failure=ivts.force_invalidate)
-    return [job]
+    return ([double_check_cache_job, compile_job], 1)

-  def check_cache(self, vts, counter):
+  def check_cache(self, vts):
     """Manually checks the artifact cache (usually immediately before compilation).

     Returns true if the cache was hit successfully, indicating that no compilation is necessary.
@@ -790,7 +801,6 @@ def check_cache(self, vts, counter):
         'Cache returned unexpected target: {} vs {}'.format(cached_vts, [vts])
       )
     self.context.log.info('Hit cache during double check for {}'.format(vts.target.address.spec))
-    counter()
     return True

   def should_compile_incrementally(self, vts, ctx):
@@ -916,13 +926,18 @@ def _get_jvm_distribution(self):
       self.HERMETIC: lambda: self._HermeticDistribution('.jdk', local_distribution),
     })()

+  def _default_double_check_cache_for_vts(self, vts):
+    # Double-check the cache before beginning compilation.
+    if self.check_cache(vts):
+      vts.update()
+
   def _default_work_for_vts(self, vts, ctx, input_classpath_product_key, counter, all_compile_contexts, output_classpath_product):
     progress_message = ctx.target.address.spec

-    # Double check the cache before beginning compilation
-    hit_cache = self.check_cache(vts, counter)
-
-    if not hit_cache:
+    # See whether the cache double-check job hit the cache: if so, no-op; otherwise, compile.
+    if vts.valid:
+      counter()
+    else:
       # Compute the compile classpath for this target.
       dependency_cp_entries = self._zinc.compile_classpath_entries(
         input_classpath_product_key,

diff --git a/src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py b/src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py
index de7a9a4dc8a..3a4cb74d4c9 100644
--- a/src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py
+++ b/src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py
@@ -292,25 +292,6 @@ def _zinc_key_for_target(self, target, workflow):
   def _write_to_cache_key_for_target(self, target):
     return 'write_to_cache({})'.format(target.address.spec)

-  def _check_cache_before_work(self, work_str, vts, ctx, counter, debug = False, work_fn = lambda: None):
-    hit_cache = self.check_cache(vts, counter)
-
-    if not hit_cache:
-      counter_val = str(counter()).rjust(counter.format_length(), ' ')
-      counter_str = '[{}/{}] '.format(counter_val, counter.size)
-      log_fn = self.context.log.debug if debug else self.context.log.info
-      log_fn(
-        counter_str,
-        f'{work_str} ',
-        items_to_report_element(ctx.sources, '{} source'.format(self.name())),
-        ' in ',
-        items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
-        ' (',
-        ctx.target.address.spec,
-        ').')
-
-      work_fn()
-
   def create_compile_jobs(self,
                           compile_target,
                           compile_contexts,
@@ -323,7 +304,19 @@ def work_for_vts_rsc(vts, ctx):
       target = ctx.target
       tgt, = vts.targets

-      def work_fn():
+      # If we didn't hit the cache in the cache double-check job, run rsc.
+      if not vts.valid:
+        counter_val = str(counter()).rjust(counter.format_length(), ' ')
+        counter_str = '[{}/{}] '.format(counter_val, counter.size)
+        self.context.log.info(
+          counter_str,
+          'Rsc-ing ',
+          items_to_report_element(ctx.sources, '{} source'.format(self.name())),
+          ' in ',
+          items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
+          ' (',
+          ctx.target.address.spec,
+          ').')
         # This does the following
         # - Collect the rsc classpath elements, including zinc compiles of rsc incompatible targets
         #   and rsc compiles of rsc compatible targets.
@@ -391,16 +384,11 @@ def nonhermetic_digest_classpath():
           'rsc'
         )

-      # Double check the cache before beginning compilation
-      self._check_cache_before_work('Rsc-ing', vts, ctx, counter, work_fn=work_fn)
-
       # Update the products with the latest classes.
       self.register_extra_products_from_contexts([ctx.target], compile_contexts)

-    def work_for_vts_write_to_cache(vts, ctx):
-      self._check_cache_before_work('Writing to cache for', vts, ctx, counter, debug=True)
-
     ### Create Jobs for ExecutionGraph
+    cache_doublecheck_jobs = []
     rsc_jobs = []
     zinc_jobs = []

@@ -410,6 +398,8 @@ def work_for_vts_write_to_cache(vts, ctx):
       rsc_compile_context = merged_compile_context.rsc_cc
       zinc_compile_context = merged_compile_context.zinc_cc

+      cache_doublecheck_key = self.exec_graph_double_check_cache_key_for_target(compile_target)
+
       def all_zinc_rsc_invalid_dep_keys(invalid_deps):
         """Get the rsc key for an rsc-and-zinc target, or the zinc key for a zinc-only target."""
         for tgt in invalid_deps:
@@ -420,6 +410,14 @@ def all_zinc_rsc_invalid_dep_keys(invalid_deps):
           # Rely on the results of zinc compiles for zinc-compatible targets
           yield self._key_for_target_as_dep(tgt, tgt_rsc_cc.workflow)

+      def make_cache_doublecheck_job(dep_keys):
+        # As in JvmCompile.create_compile_jobs, we create a cache-double-check job that all "real"
+        # work depends on. It depends on completion of the same dependencies as the rsc job in
+        # order to run as late as possible, while still running before rsc or zinc.
+        return Job(cache_doublecheck_key,
+                   functools.partial(self._default_double_check_cache_for_vts, ivts),
+                   dependencies=list(dep_keys))
+
       def make_rsc_job(target, dep_targets):
         return Job(
           key=self._rsc_key_for_target(target),
@@ -432,7 +430,7 @@ def make_rsc_job(target, dep_targets):
           ),
           # The rsc jobs depend on other rsc jobs, and on zinc jobs for targets that are not
           # processed by rsc.
-          dependencies=list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
+          dependencies=[cache_doublecheck_key] + list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
           size=self._size_estimator(rsc_compile_context.sources),
         )
@@ -453,7 +451,7 @@ def make_zinc_job(target, input_product_key, output_products, dep_keys):
             counter,
             compile_contexts,
             CompositeProductAdder(*output_products)),
-          dependencies=list(dep_keys),
+          dependencies=[cache_doublecheck_key] + list(dep_keys),
           size=self._size_estimator(zinc_compile_context.sources),
         )

@@ -470,6 +468,19 @@ def record(k, v):
       record('workflow', workflow.value)
       record('execution_strategy', self.execution_strategy)

+      # Create the cache double-check job.
+      workflow.resolve_for_enum_variant({
+        'zinc-only': lambda: cache_doublecheck_jobs.append(
+          make_cache_doublecheck_job(list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
+        ),
+        'zinc-java': lambda: cache_doublecheck_jobs.append(
+          make_cache_doublecheck_job(list(only_zinc_invalid_dep_keys(invalid_dependencies)))
+        ),
+        'rsc-and-zinc': lambda: cache_doublecheck_jobs.append(
+          make_cache_doublecheck_job(list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
+        ),
+      })()
+
       # Create the rsc job.
       # Currently, rsc only supports outlining scala.
       workflow.resolve_for_enum_variant({
@@ -519,25 +530,19 @@ def record(k, v):
         )),
       })()

-      all_jobs = rsc_jobs + zinc_jobs
-
-      if all_jobs:
-        write_to_cache_job = Job(
-          key=self._write_to_cache_key_for_target(compile_target),
-          fn=functools.partial(
-            work_for_vts_write_to_cache,
-            ivts,
-            rsc_compile_context,
-          ),
-          dependencies=[job.key for job in all_jobs],
-          run_asap=True,
-          # If compilation and analysis work succeeds, validate the vts.
-          # Otherwise, fail it.
-          on_success=ivts.update,
-          on_failure=ivts.force_invalidate)
-        all_jobs.append(write_to_cache_job)
-
-      return all_jobs
+      compile_jobs = rsc_jobs + zinc_jobs
+
+      # Create a job that depends on all of the real work having completed, and which eagerly
+      # writes to the cache by calling `ivts.update()`.
+      write_to_cache_job = Job(
+        key=self._write_to_cache_key_for_target(compile_target),
+        fn=ivts.update,
+        dependencies=[job.key for job in compile_jobs],
+        run_asap=True,
+        on_failure=ivts.force_invalidate)
+
+      all_jobs = cache_doublecheck_jobs + rsc_jobs + zinc_jobs + [write_to_cache_job]
+      return (all_jobs, len(compile_jobs))

 class RscZincMergedCompileContexts(datatype([
     ('rsc_cc', RscCompileContext),

diff --git a/tests/python/pants_test/backend/jvm/tasks/jvm_compile/rsc/test_rsc_compile.py b/tests/python/pants_test/backend/jvm/tasks/jvm_compile/rsc/test_rsc_compile.py
index 00693f24163..43f0c57ec5d 100644
--- a/tests/python/pants_test/backend/jvm/tasks/jvm_compile/rsc/test_rsc_compile.py
+++ b/tests/python/pants_test/backend/jvm/tasks/jvm_compile/rsc/test_rsc_compile.py
@@ -72,12 +72,17 @@ def test_force_compiler_tags(self):
       classpath_product=None)

     dependee_graph = self.construct_dependee_graph_str(jobs, task)
-    print(dependee_graph)

     self.assertEqual(dedent("""
+                     double_check_cache(java/classpath:java_lib) <- {
+                       zinc[zinc-java](java/classpath:java_lib)
+                     }
                      zinc[zinc-java](java/classpath:java_lib) <- {
                        write_to_cache(java/classpath:java_lib)
                      }
                      write_to_cache(java/classpath:java_lib) <- {}
+                     double_check_cache(scala/classpath:scala_lib) <- {
+                       zinc[zinc-only](scala/classpath:scala_lib)
+                     }
                      zinc[zinc-only](scala/classpath:scala_lib) <- {
                        write_to_cache(scala/classpath:scala_lib)
                      }
@@ -115,12 +120,17 @@ def test_no_dependencies_between_scala_and_java_targets(self):
       classpath_product=None)

     dependee_graph = self.construct_dependee_graph_str(jobs, task)
-    print(dependee_graph)

     self.assertEqual(dedent("""
+                     double_check_cache(java/classpath:java_lib) <- {
+                       zinc[zinc-java](java/classpath:java_lib)
+                     }
                      zinc[zinc-java](java/classpath:java_lib) <- {
                        write_to_cache(java/classpath:java_lib)
                      }
                      write_to_cache(java/classpath:java_lib) <- {}
+                     double_check_cache(scala/classpath:scala_lib) <- {
+                       zinc[zinc-only](scala/classpath:scala_lib)
+                     }
                      zinc[zinc-only](scala/classpath:scala_lib) <- {
                        write_to_cache(scala/classpath:scala_lib)
                      }
@@ -152,12 +162,15 @@ def test_default_workflow_of_zinc_only_zincs_scala(self):
       classpath_product=None)

     dependee_graph = self.construct_dependee_graph_str(jobs, task)
-    print(dependee_graph)

     self.assertEqual(dedent("""
+                     double_check_cache(scala/classpath:scala_lib) <- {
+                       zinc[zinc-only](scala/classpath:scala_lib)
+                     }
                      zinc[zinc-only](scala/classpath:scala_lib) <- {
                        write_to_cache(scala/classpath:scala_lib)
                      }
-                     write_to_cache(scala/classpath:scala_lib) <- {}""").strip(),
+                     write_to_cache(scala/classpath:scala_lib) <- {}
+                     """).strip(),
       dependee_graph)

   def test_rsc_dep_for_scala_java_and_test_targets(self):
@@ -208,30 +221,49 @@ def test_rsc_dep_for_scala_java_and_test_targets(self):

     dependee_graph = self.construct_dependee_graph_str(jobs, task)

+    self.maxDiff = None
     self.assertEqual(dedent("""
+                     double_check_cache(java/classpath:java_lib) <- {
+                       zinc[zinc-java](java/classpath:java_lib)
+                     }
                      zinc[zinc-java](java/classpath:java_lib) <- {
                        write_to_cache(java/classpath:java_lib)
                      }
                      write_to_cache(java/classpath:java_lib) <- {}
+                     double_check_cache(scala/classpath:scala_lib) <- {
+                       rsc(scala/classpath:scala_lib),
+                       zinc[rsc-and-zinc](scala/classpath:scala_lib)
+                     }
                      rsc(scala/classpath:scala_lib) <- {
                        write_to_cache(scala/classpath:scala_lib),
+                       double_check_cache(scala/classpath:scala_test),
                        zinc[zinc-only](scala/classpath:scala_test)
                      }
                      zinc[rsc-and-zinc](scala/classpath:scala_lib) <- {
                        write_to_cache(scala/classpath:scala_lib)
                      }
                      write_to_cache(scala/classpath:scala_lib) <- {}
+                     double_check_cache(scala/classpath:scala_dep) <- {
+                       rsc(scala/classpath:scala_dep),
+                       zinc[rsc-and-zinc](scala/classpath:scala_dep)
+                     }
                      rsc(scala/classpath:scala_dep) <- {
+                       double_check_cache(scala/classpath:scala_lib),
                        rsc(scala/classpath:scala_lib),
                        zinc[rsc-and-zinc](scala/classpath:scala_lib),
                        write_to_cache(scala/classpath:scala_dep),
+                       double_check_cache(scala/classpath:scala_test),
                        zinc[zinc-only](scala/classpath:scala_test)
                      }
                      zinc[rsc-and-zinc](scala/classpath:scala_dep) <- {
+                       double_check_cache(java/classpath:java_lib),
                        zinc[zinc-java](java/classpath:java_lib),
                        write_to_cache(scala/classpath:scala_dep)
                      }
                      write_to_cache(scala/classpath:scala_dep) <- {}
+                     double_check_cache(scala/classpath:scala_test) <- {
+                       zinc[zinc-only](scala/classpath:scala_test)
+                     }
                      zinc[zinc-only](scala/classpath:scala_test) <- {
                        write_to_cache(scala/classpath:scala_test)
                      }
@@ -280,14 +312,23 @@ def test_scala_lib_with_java_sources_not_passed_to_rsc(self):

     dependee_graph = self.construct_dependee_graph_str(jobs, task)

     self.assertEqual(dedent("""
+                     double_check_cache(java/classpath:java_lib) <- {
+                       zinc[zinc-java](java/classpath:java_lib)
+                     }
                      zinc[zinc-java](java/classpath:java_lib) <- {
                        write_to_cache(java/classpath:java_lib)
                      }
                      write_to_cache(java/classpath:java_lib) <- {}
+                     double_check_cache(scala/classpath:scala_with_direct_java_sources) <- {
+                       zinc[zinc-java](scala/classpath:scala_with_direct_java_sources)
+                     }
                      zinc[zinc-java](scala/classpath:scala_with_direct_java_sources) <- {
                        write_to_cache(scala/classpath:scala_with_direct_java_sources)
                      }
                      write_to_cache(scala/classpath:scala_with_direct_java_sources) <- {}
+                     double_check_cache(scala/classpath:scala_with_indirect_java_sources) <- {
+                       zinc[zinc-java](scala/classpath:scala_with_indirect_java_sources)
+                     }
                      zinc[zinc-java](scala/classpath:scala_with_indirect_java_sources) <- {
                        write_to_cache(scala/classpath:scala_with_indirect_java_sources)
                      }