Skip to content
This repository has been archived by the owner on Dec 10, 2020. It is now read-only.

Commit

Permalink
Integrating Patrick's SimpleCodegenTask base class with WireGen.
Browse files Browse the repository at this point in the history
Refactors of other codegen classes omitted in this patch for
simplicity; I will circle back and add Jaxb and Protobufs in a
new review once this patch goes through.

There are some TODOs for integrating isolated code-generation
strategies (the real motivation behind simplifying codegen),
which I will also circle back and replace with real code after
landing this patch.

Testing Done:
test tests/python/pants_test/backend/codegen/tasks:wire_gen passes,
test tests/python/pants_test/tasks:wire_integration passes,
CI is green.

I updated the wire test cases a bit, including adding a unit test for sources_generated_by_target. I also added another .proto to the wire/elements example, and updated the corresponding integration test, to ensure codegen works properly with multiple .proto files in the same target.

Bugs closed: 1597

Reviewed at https://rbcommons.com/s/twitter/r/2274/
  • Loading branch information
gmalmquist authored and ericzundel committed May 29, 2015
1 parent 0aad08c commit d0717e9
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package org.pantsbuild.example.wire.element;

import org.pantsbuild.example.element.Compound;
import org.pantsbuild.example.element.Element;
import org.pantsbuild.example.temperature.Temperature;

Expand All @@ -19,6 +20,13 @@ public static void main(String[] args) {
Temperature boilingPoint = new Temperature.Builder().unit("celsius").number((long)357).build();
Element mercury = new Element.Builder().symbol("Hg").name("Mercury").atomic_number(80)
.melting_point(meltingPoint).boiling_point(boilingPoint).build();
Compound water = new Compound.Builder().name("Water")
.primary_element(
new Element.Builder().symbol("O").name("Oxygen").atomic_number(8).build())
.secondary_element(
new Element.Builder().symbol("H").name("Hydrogen").atomic_number(1).build())
.build();
System.out.println(mercury.toString());
System.out.println(water.toString());
}
}
3 changes: 2 additions & 1 deletion examples/src/wire/org/pantsbuild/example/element/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

java_wire_library(name='element',
sources=[
'elements.proto',
'elements.proto', # Order matters here.
'compound.proto',
],
dependencies=[
'examples/src/wire/org/pantsbuild/example/temperature',
Expand Down
14 changes: 14 additions & 0 deletions examples/src/wire/org/pantsbuild/example/element/compound.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).

package org.pantsbuild.example.element;

/**
 * Describes a compound of two elements.
 *
 * Element is referenced by its fully-qualified name and is not defined in this file.
 */
message Compound {
  // Name of the compound; the only required field.
  required string name = 1;
  // First constituent element (optional).
  optional org.pantsbuild.example.element.Element primary_element = 2;
  // Second constituent element (optional).
  optional org.pantsbuild.example.element.Element secondary_element = 3;
}

11 changes: 11 additions & 0 deletions src/python/pants/backend/codegen/tasks/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ python_library(
],
)

python_library(
name = 'simple_codegen_task',
sources = ['simple_codegen_task.py'],
dependencies = [
':common',
'src/python/pants/base:address',
'src/python/pants/base:build_environment',
],
)

python_library(
name = 'wire_gen',
sources = ['wire_gen.py'],
Expand All @@ -145,6 +155,7 @@ python_library(
':code_gen',
':protobuf_gen',
':protobuf_parse',
':simple_codegen_task',
'src/python/pants/backend/jvm/targets:java',
'src/python/pants/backend/jvm/targets:jvm',
'src/python/pants/backend/jvm/tasks:jvm_tool_task_mixin',
Expand Down
161 changes: 161 additions & 0 deletions src/python/pants/backend/codegen/tasks/simple_codegen_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
unicode_literals, with_statement)

import os

from pants.backend.core.tasks.task import Task
from pants.base.address import SyntheticAddress
from pants.base.build_environment import get_buildroot


class SimpleCodegenTask(Task):
  """A base-class for code generation for a single target language.

  Subclasses must implement `synthetic_target_type`, `is_gentarget`, `execute_codegen` and
  `sources_generated_by_target`; each of those raises NotImplementedError here. The remaining
  hooks have working defaults.
  """

  @classmethod
  def get_fingerprint_strategy(cls):
    """Override this method to use a fingerprint strategy other than the default one.

    :return: a fingerprint strategy, or None to use the default strategy.
    """
    return None

  def synthetic_target_extra_dependencies(self, target):
    """Gets any extra dependencies generated synthetic targets should have.

    This method is optional for subclasses to implement, because some code generators may have no
    extra dependencies.

    :param Target target: the Target from which we are generating a synthetic Target. E.g., 'target'
      might be a JavaProtobufLibrary, whose corresponding synthetic Target would be a JavaLibrary.
      It may not be necessary to use this parameter depending on the details of the subclass.
    :return: a list of dependencies.
    """
    return []

  @property
  def synthetic_target_type(self):
    """The type of target this codegen task generates.

    For example, the target type for JaxbGen would simply be JavaLibrary.

    :return: a type (class) that inherits from Target.
    """
    raise NotImplementedError

  def is_gentarget(self, target):
    """Predicate which determines whether the target in question is relevant to this codegen task.

    E.g., the JaxbGen task considers JaxbLibrary targets to be relevant, and nothing else.

    :param Target target: The target to check.
    :return: True if this class can generate code for the given target, False otherwise.
    """
    raise NotImplementedError

  def execute_codegen(self, invalid_targets):
    """Generates code for the given list of targets.

    :param invalid_targets: an iterable of targets (a subset of codegen_targets()).
    """
    raise NotImplementedError

  def sources_generated_by_target(self, target):
    """Predicts what source files will be generated from the given codegen target.

    :param Target target: the codegen target in question (eg a .proto library).
    :return: an iterable of strings containing the file system paths to the source files. Paths
      may be relative to codegen_workdir(target), or absolute paths under it.
    """
    raise NotImplementedError

  def codegen_targets(self):
    """Finds codegen targets in the dependency graph.

    :return: an iterable of dependency targets.
    """
    return self.context.targets(self.is_gentarget)

  def codegen_workdir(self, target):
    """The path to the directory code should be generated in.

    E.g., this might be something like /home/user/repo/.pants.d/gen/jaxb/...
    Generally, subclasses should not need to override this method. If they do, it is crucial that
    the implementation is /deterministic/ -- that is, the return value of this method should always
    be the same for the same input target.

    :return: The absolute file path.
    """
    # TODO(gm): This method will power the isolated/global strategies for what directories to put
    # generated code in, once that exists. This will work in a similar fashion to the jvm_compile
    # tasks' isolated vs global strategies, generated code per-target in a way that avoids
    # collisions.
    return self.workdir

  def execute(self):
    """Generates code for invalidated targets and injects a synthetic target per codegen target.

    A synthetic target is created for every codegen target, not only the invalid ones. Each
    dependent of the original target is made to depend on the synthetic target, and the synthetic
    target is made to depend on the original target's dependencies. If the original target is a
    target root, the synthetic target is appended to the target roots. When artifact cache writes
    are enabled, the generated sources of the invalid targets are handed to the artifact cache.
    """
    targets = self.codegen_targets()
    fingerprint_strategy = self.get_fingerprint_strategy()
    with self.invalidated(targets,
                          invalidate_dependents=True,
                          fingerprint_strategy=fingerprint_strategy) as invalidation_check:
      for vts in invalidation_check.invalid_vts:
        self.execute_codegen(vts.targets)

      invalid_vts_by_target = {vt.target: vt for vt in invalidation_check.invalid_vts}
      vts_artifactfiles_pairs = []

      # None of these depend on the target being processed, so look them up once.
      buildroot = get_buildroot()
      build_graph = self.context.build_graph
      task_name = type(self).__name__

      for target in targets:
        target_workdir = self.codegen_workdir(target)
        synthetic_name = target.id
        sources_rel_path = os.path.relpath(target_workdir, buildroot)
        spec_path = '{0}{1}'.format(task_name, sources_rel_path)
        synthetic_address = SyntheticAddress(spec_path, synthetic_name)
        # TODO(gm): sources_generated_by_target() shouldn't be necessary for the isolated codegen
        # strategy, once that exists.
        raw_generated_sources = self.sources_generated_by_target(target)
        # Make the sources robust regardless of whether subclasses return relative paths, or
        # absolute paths that live under the workdir.
        generated_sources = [src if src.startswith(target_workdir)
                             else os.path.join(target_workdir, src)
                             for src in raw_generated_sources]
        relative_generated_sources = [os.path.relpath(src, target_workdir)
                                      for src in generated_sources]

        # NOTE(review): the synthetic target is also stored on self.target, so after the loop it
        # holds the last one created. Kept in case something reads it -- confirm before removing.
        self.target = self.context.add_new_target(
          address=synthetic_address,
          target_type=self.synthetic_target_type,
          dependencies=self.synthetic_target_extra_dependencies(target),
          sources_rel_path=sources_rel_path,
          sources=relative_generated_sources,
          derived_from=target,
          provides=target.provides,
        )
        synthetic_target = self.target

        # NOTE(pl): This bypasses the convenience function (Target.inject_dependency) in order
        # to improve performance. Note that we can walk the transitive dependee subgraph once
        # for transitive invalidation rather than walking a smaller subgraph for every single
        # dependency injected.
        for dependent_address in build_graph.dependents_of(target.address):
          build_graph.inject_dependency(
            dependent=dependent_address,
            dependency=synthetic_target.address,
          )
        # NOTE(pl): See the above comment. The same note applies.
        for concrete_dependency_address in build_graph.dependencies_of(target.address):
          build_graph.inject_dependency(
            dependent=synthetic_target.address,
            dependency=concrete_dependency_address,
          )
        build_graph.walk_transitive_dependee_graph(
          build_graph.dependencies_of(target.address),
          work=lambda t: t.mark_transitive_invalidation_hash_dirty(),
        )

        if target in self.context.target_roots:
          self.context.target_roots.append(synthetic_target)
        if target in invalid_vts_by_target:
          vts_artifactfiles_pairs.append((invalid_vts_by_target[target], generated_sources))

      if self.artifact_cache_writes_enabled():
        self.update_artifact_cache(vts_artifactfiles_pairs)
88 changes: 27 additions & 61 deletions src/python/pants/backend/codegen/tasks/wire_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@
import itertools
import logging
import os
from collections import OrderedDict, defaultdict
from collections import OrderedDict

from twitter.common.collections import OrderedSet, maybe_list
from twitter.common.collections import OrderedSet

from pants.backend.codegen.targets.java_protobuf_library import JavaProtobufLibrary
from pants.backend.codegen.targets.java_wire_library import JavaWireLibrary
from pants.backend.codegen.tasks.code_gen import CodeGen
from pants.backend.codegen.tasks.protobuf_gen import check_duplicate_conflicting_protos
from pants.backend.codegen.tasks.protobuf_parse import ProtobufParse
from pants.backend.codegen.tasks.simple_codegen_task import SimpleCodegenTask
from pants.backend.jvm.targets.java_library import JavaLibrary
from pants.backend.jvm.tasks.jvm_tool_task_mixin import JvmToolTaskMixin
from pants.base.address import SyntheticAddress
from pants.base.address_lookup_error import AddressLookupError
from pants.base.build_environment import get_buildroot
from pants.base.exceptions import TaskError
Expand All @@ -31,7 +29,7 @@
logger = logging.getLogger(__name__)


class WireGen(CodeGen, JvmToolTaskMixin):
class WireGen(SimpleCodegenTask, JvmToolTaskMixin):
@classmethod
def register_options(cls, register):
super(WireGen, cls).register_options(register)
Expand All @@ -42,7 +40,23 @@ def register_options(cls, register):
def __init__(self, *args, **kwargs):
"""Generates Java files from .proto files using the Wire protobuf compiler."""
super(WireGen, self).__init__(*args, **kwargs)
self.java_out = os.path.join(self.workdir, 'gen-java')

@property
def synthetic_target_type(self):
  """The type of synthetic target this task creates: JavaLibrary."""
  return JavaLibrary

def is_gentarget(self, target):
  """Returns True only for JavaWireLibrary targets.

  :param target: the target to check.
  """
  return isinstance(target, JavaWireLibrary)

def sources_generated_by_target(self, target):
  """Lists the files calculate_genfiles() reports for each source of the given target.

  Each source (relative to the source root) is joined onto the target's base directory and
  passed to calculate_genfiles() together with the target's service_writer payload field.

  :param target: the codegen target whose sources are inspected.
  :return: a flat list of the generated file paths, grouped in source order.
  """
  return [generated
          for proto in target.sources_relative_to_source_root()
          for generated in self.calculate_genfiles(os.path.join(target.target_base, proto),
                                                   proto,
                                                   target.payload.service_writer)]

def resolve_deps(self, unresolved_deps):
deps = OrderedSet()
Expand All @@ -53,20 +67,10 @@ def resolve_deps(self, unresolved_deps):
raise self.DepLookupError('{message}\n on dependency {dep}'.format(message=e, dep=dep))
return deps

@property
def javadeps(self):
def synthetic_target_extra_dependencies(self, target):
return self.resolve_deps(self.get_options().javadeps)

def is_gentarget(self, target):
return isinstance(target, JavaWireLibrary)

def is_proto_target(self, target):
return isinstance(target, JavaProtobufLibrary)

def genlangs(self):
return {'java': lambda t: t.is_jvm}

def genlang(self, lang, targets):
def execute_codegen(self, targets):
# Invoke the generator once per target. Because the wire compiler has flags that try to reduce
# the amount of code emitted, Invoking them all together will break if one target specifies a
# service_writer and another does not, or if one specifies roots and another does not.
Expand All @@ -82,10 +86,7 @@ def genlang(self, lang, targets):
relative_sources.add(relative_source)
check_duplicate_conflicting_protos(self, sources_by_base, relative_sources, self.context.log)

if lang != 'java':
raise TaskError('Unrecognized wire gen lang: {0}'.format(lang))

args = ['--java_out={0}'.format(self.java_out)]
args = ['--java_out={0}'.format(self.codegen_workdir(target))]

# Add all params in payload to args

Expand Down Expand Up @@ -134,36 +135,6 @@ def add_to_gentargets(target):
sources_by_base[base].update(sources)
return sources_by_base

def createtarget(self, lang, gentarget, dependees):
if lang == 'java':
return self._create_java_target(gentarget, dependees)
else:
raise TaskError('Unrecognized wire gen lang: {0}'.format(lang))

def _create_java_target(self, target, dependees):
genfiles = []
for source in target.sources_relative_to_source_root():
path = os.path.join(target.target_base, source)
genfiles.extend(self.calculate_genfiles(
path,
source,
target.payload.service_writer).get('java', []))

spec_path = os.path.relpath(self.java_out, get_buildroot())
address = SyntheticAddress(spec_path, target.id)
deps = OrderedSet(self.javadeps)
tgt = self.context.add_new_target(address,
JavaLibrary,
derived_from=target,
sources=genfiles,
provides=target.provides,
dependencies=deps,
excludes=target.payload.excludes)
for dependee in dependees:
dependee.inject_dependency(tgt.address)
return tgt


def calculate_genfiles(self, path, source, service_writer):
protobuf_parse = ProtobufParse(path, source)
protobuf_parse.parse()
Expand All @@ -176,16 +147,11 @@ def calculate_genfiles(self, path, source, service_writer):
if protobuf_parse.extends:
types |= set(["Ext_{0}".format(protobuf_parse.filename)])

genfiles = defaultdict(set)
java_files = list(self.calculate_java_genfiles(protobuf_parse.package, types))
java_files = self.calculate_java_genfiles(protobuf_parse.package, types)
logger.debug('Path {path} yielded types {types} got files {java_files}'
.format(path=path, types=types, java_files=java_files))
genfiles['java'].update(java_files)
return genfiles
return set(java_files)

def calculate_java_genfiles(self, package, types):
basepath = package.replace('.', '/')
for type_ in types:
filename = os.path.join(basepath, '{0}.java'.format(type_))
logger.debug("Expecting {filename} from type {type_}".format(filename=filename, type_=type_))
yield filename
return [os.path.join(basepath, '{0}.java'.format(t)) for t in types]
Loading

0 comments on commit d0717e9

Please sign in to comment.