Skip to content

Commit

Permalink
Change AsyncField to be AsyncFieldMixin (pantsbuild#11232)
Browse files Browse the repository at this point in the history
By doing this, it's now possible to combine any arbitrary field template with being an async field. Previously, we duplicated the templates, like `AsyncStringSequenceField`. This also builds off of pantsbuild#11231 to make async fields far less magical - literally, all they do is add an `address: Address` field.

To land this, we remove both `@dataclass` and `ABC` from `Field`, which were causing unnecessary confusion and messing up inheritance. The dataclass was only generating a custom `__eq__` and `__hash__`, which is easy to set explicitly. We also remove both things from `Target` for simplicity.

[ci skip-rust]
[ci skip-build-wheels]
  • Loading branch information
Eric-Arellano authored Nov 24, 2020
1 parent 8bb2c31 commit dcb41a2
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 97 deletions.
4 changes: 2 additions & 2 deletions src/python/pants/engine/internals/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,8 @@ def test_sources_normal_hydration(sources_rule_runner: RuleRunner) -> None:
assert (
sources.filespec
== {
"includes": ["src/fortran/*.f03", "src/fortran/f1.f95"],
"excludes": ["src/fortran/**/ignore*", "src/fortran/ignored.f03"],
"includes": ["src/fortran/f1.f95", "src/fortran/*.f03"],
"excludes": ["src/fortran/ignored.f03", "src/fortran/**/ignore*"],
}
== hydrated_sources.filespec
)
Expand Down
152 changes: 78 additions & 74 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import annotations

import collections.abc
import dataclasses
import itertools
Expand Down Expand Up @@ -59,8 +61,7 @@


@frozen_after_init
@dataclass(unsafe_hash=True)
class Field(ABC):
class Field:
"""A Field.
The majority of fields should use field templates like `BoolField`, `StringField`, and
Expand All @@ -73,8 +74,8 @@ class Field(ABC):
(and should be immutable) so that this Field may be used by the engine. This means, for
example, using tuples rather than lists and using `FrozenOrderedSet` rather than `set`.
If you plan to use the engine to fully hydrate the value, you can instead subclass `AsyncField`
or one of its templates. This will store an `address: Address` property on the `Field`.
If you plan to use the engine to fully hydrate the value, you can also inherit
`AsyncFieldMixin`, which will store an `address: Address` property on the `Field` instance.
Subclasses should also override the type hints for `value` and `raw_value` to be more precise
than `Any`. The type hint for `raw_value` is used to generate documentation, e.g. for
Expand All @@ -99,20 +100,19 @@ def compute_value(cls, raw_value: Optional[int], *, address: Address) -> Optiona
return value_or_default
"""

# This exists on every Field subclass.
value: Optional[ImmutableValue]
# Subclasses must define these.
# Subclasses must define this.
alias: ClassVar[str]
# Subclasses must define at least one of these two.
default: ClassVar[ImmutableValue]
# Subclasses may define these.
required: ClassVar[bool] = False
# Subclasses may define these.
deprecated_removal_version: ClassVar[Optional[str]] = None
deprecated_removal_hint: ClassVar[Optional[str]] = None

@final
def __init__(self, raw_value: Optional[Any], *, address: Address) -> None:
self._check_deprecated(raw_value, address)
self.value = self.compute_value(raw_value, address=address)
self.value: Optional[ImmutableValue] = self.compute_value(raw_value, address=address)

@classmethod
def compute_value(cls, raw_value: Optional[Any], *, address: Address) -> ImmutableValue:
Expand Down Expand Up @@ -156,34 +156,35 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return f"{self.alias}={self.value}"

def __hash__(self) -> int:
    """Hash the field on its concrete class together with its computed value."""
    identity = (self.__class__, self.value)
    return hash(identity)

@frozen_after_init
@dataclass(unsafe_hash=True)
class AsyncField(Field, metaclass=ABCMeta):
"""A field that needs the engine in order to be hydrated.
def __eq__(self, other: Union[Any, Field]) -> bool:
    """Compare two fields by concrete class and computed value.

    Any operand that is not a `Field` yields `NotImplemented`, so Python can fall back to the
    other operand's comparison.
    """
    if isinstance(other, Field):
        mine = (self.__class__, self.value)
        theirs = (other.__class__, other.value)
        return mine == theirs
    return NotImplemented


# NB: By subclassing `Field`, MyPy understands our type hints, and it means it doesn't matter which
# order you use for inheriting the field template vs. the mixin.
class AsyncFieldMixin(Field):
"""A mixin to store the field's original `Address` for use during hydration by the engine.
Typically, you should create a dataclass representing the hydrated value and another for the
request, then a rule to go from the request to the hydrated value. The request class should
store the `AsyncField` as a property.
Typically, you should also create a dataclass representing the hydrated value and another for
the request, then a rule to go from the request to the hydrated value. The request class should
store the async field as a property.
(Why use the request class as the rule input, rather than the field itself? It's a wrapper so
that subclasses of the AsyncField work properly, given that the engine uses exact type IDs.
that subclasses of the async field work properly, given that the engine uses exact type IDs.
This is like WrappedTarget.)
For example:
class Sources(AsyncField):
class Sources(StringSequenceField, AsyncFieldMixin):
alias = "sources"
value: Optional[Tuple[str, ...]]
def compute_value(
raw_value: Optional[List[str]], *, address: Address
) -> Optional[Tuple[str, ...]]:
...
# Example extension point provided by this field. Subclasses can override this to do
# whatever validation they'd like. Each AsyncField must define its own entry points
# like this to allow subclasses to change behavior.
# Often, async fields will want to define entry points like this to allow subclasses to
# change behavior.
def validate_resolved_files(self, files: Sequence[str]) -> None:
pass
Expand All @@ -206,24 +207,39 @@ def hydrate_sources(request: HydrateSourcesRequest) -> HydratedSources:
return HydratedSources(result)
Then, call sites can `await Get` if they need to hydrate the field, even if they subclassed
the original `AsyncField` to have custom behavior:
the original async field to have custom behavior:
sources1 = await Get(HydratedSources, HydrateSourcesRequest(my_tgt.get(Sources)))
sources2 = await Get(HydratedSources, HydrateSourcesRequest(custom_tgt.get(CustomSources)))
"""

address: Address

# We cheat here by ignoring the `@final` declaration and `@frozen_after_init`. We don't want to
# generalize this further because we want to avoid having `Field` subclasses start adding
# arbitrary fields.
@final # type: ignore[misc]
def __init__(self, raw_value: Optional[Any], *, address: Address) -> None:
# Run the regular `Field` initialization first, which computes and stores `self.value`.
super().__init__(raw_value, address=address)
# `Field` is decorated with `@frozen_after_init`, so we must temporarily unfreeze the instance
# to store the address, then refreeze it to keep preventing subclasses from adding arbitrary
# fields.
self._unfreeze_instance() # type: ignore[attr-defined]
self.address = address
self._freeze_instance() # type: ignore[attr-defined]

def __repr__(self) -> str:
    """Return a debug representation that includes the field's address.

    The result shows the concrete class, the alias, the address stored on this field, the
    computed value, and the class-level default.
    """
    # NB: Both fragments must be f-strings. Adjacent literals are concatenated, but a plain
    # literal would leave its `{...}` placeholders uninterpolated in the output.
    return (
        f"{self.__class__}(alias={repr(self.alias)}, address={self.address}, "
        f"value={repr(self.value)}, default={repr(self.default)})"
    )

def __hash__(self) -> int:
    """Hash on the concrete class, the computed value, and the stored address."""
    identity = (self.__class__, self.value, self.address)
    return hash(identity)

def __eq__(self, other: Union[Any, AsyncFieldMixin]) -> bool:
    """Compare two async fields by concrete class, computed value, and address.

    Any operand that is not an `AsyncFieldMixin` yields `NotImplemented`.
    """
    if isinstance(other, AsyncFieldMixin):
        mine = (self.__class__, self.value, self.address)
        theirs = (other.__class__, other.value, other.address)
        return mine == theirs
    return NotImplemented


# -----------------------------------------------------------------------------------------------
# Core Target abstractions
Expand All @@ -235,8 +251,7 @@ def __init__(self, raw_value: Optional[Any], *, address: Address) -> None:


@frozen_after_init
@dataclass(unsafe_hash=True)
class Target(ABC):
class Target:
"""A Target represents a combination of fields that are valid _together_."""

# Subclasses must define these
Expand Down Expand Up @@ -332,6 +347,18 @@ def __str__(self) -> str:
address = f"address=\"{self.address}\"{', ' if fields else ''}"
return f"{self.alias}({address}{fields})"

def __hash__(self) -> int:
    """Hash the target on its concrete class, its address, and its field values."""
    identity = (self.__class__, self.address, self.field_values)
    return hash(identity)

def __eq__(self, other: Union[Target, Any]) -> bool:
    """Compare two targets by concrete class, address, and field values.

    Any operand that is not a `Target` yields `NotImplemented`.
    """
    if isinstance(other, Target):
        mine = (self.__class__, self.address, self.field_values)
        theirs = (other.__class__, other.address, other.field_values)
        return mine == theirs
    return NotImplemented

@final
@classmethod
def _find_plugin_fields(cls, union_membership: UnionMembership) -> Tuple[Type[Field], ...]:
Expand Down Expand Up @@ -397,9 +424,9 @@ def get(self, field: Type[_F], *, default_raw_value: Optional[Any] = None) -> _F
"""Get the requested `Field` instance belonging to this target.
This will return an instance of the requested field type, e.g. an instance of
`Compatibility`, `Sources`, `EntryPoint`, etc. Usually, you will want to grab the
`Field`'s inner value, e.g. `tgt.get(Compatibility).value`. (For `AsyncField`s, you would
call `await Get(SourcesResult, SourcesRequest, tgt.get(Sources).request)`).
`InterpreterConstraints`, `Sources`, `EntryPoint`, etc. Usually, you will want to grab the
`Field`'s inner value, e.g. `tgt.get(InterpreterConstraints).value`. (For async fields
like `Sources`, you may need to hydrate the value.)
This works with subclasses of `Field`s. For example, if you subclass `Sources` to define a
custom subclass `PythonSources`, both `python_tgt.get(PythonSources)` and
Expand Down Expand Up @@ -965,7 +992,7 @@ def __init__(
T = TypeVar("T")


class ScalarField(Generic[T], Field, metaclass=ABCMeta):
class ScalarField(Generic[T], Field):
"""A field with a scalar value (vs. a compound value like a sequence or dict).
Subclasses must define the class properties `expected_type` and `expected_type_description`.
Expand Down Expand Up @@ -1002,7 +1029,7 @@ def compute_value(cls, raw_value: Optional[Any], *, address: Address) -> Optiona
return value_or_default


class BoolField(Field, metaclass=ABCMeta):
class BoolField(Field):
"""A field whose value is a boolean.
If subclasses do not set the class property `required = True` or `default`, the value will
Expand All @@ -1029,7 +1056,7 @@ def compute_value(cls, raw_value: Optional[bool], *, address: Address) -> Option
return value_or_default


class IntField(ScalarField[int], metaclass=ABCMeta):
class IntField(ScalarField[int]):
expected_type = int
expected_type_description = "an integer"

Expand All @@ -1038,7 +1065,7 @@ def compute_value(cls, raw_value: Optional[int], *, address: Address) -> Optiona
return super().compute_value(raw_value, address=address)


class FloatField(ScalarField[float], metaclass=ABCMeta):
class FloatField(ScalarField[float]):
expected_type = float
expected_type_description = "a float"

Expand All @@ -1047,7 +1074,7 @@ def compute_value(cls, raw_value: Optional[float], *, address: Address) -> Optio
return super().compute_value(raw_value, address=address)


class StringField(ScalarField[str], metaclass=ABCMeta):
class StringField(ScalarField[str]):
"""A field whose value is a string.
If you expect the string to only be one of several values, set the class property
Expand All @@ -1074,7 +1101,7 @@ def compute_value(cls, raw_value: Optional[str], *, address: Address) -> Optiona
return value_or_default


class SequenceField(Generic[T], Field, metaclass=ABCMeta):
class SequenceField(Generic[T], Field):
"""A field whose value is a homogeneous sequence.
Subclasses must define the class properties `expected_element_type` and
Expand Down Expand Up @@ -1117,7 +1144,7 @@ def compute_value(
return tuple(value_or_default)


class StringSequenceField(SequenceField[str], metaclass=ABCMeta):
class StringSequenceField(SequenceField[str]):
expected_element_type = str
expected_type_description = "an iterable of strings (e.g. a list of strings)"

Expand All @@ -1128,7 +1155,7 @@ def compute_value(
return super().compute_value(raw_value, address=address)


class StringOrStringSequenceField(SequenceField[str], metaclass=ABCMeta):
class StringOrStringSequenceField(SequenceField[str]):
"""The raw_value may either be a string or be an iterable of strings.
This is syntactic sugar that we use for certain fields to make BUILD files simpler when the user
Expand All @@ -1151,7 +1178,7 @@ def compute_value(
return super().compute_value(raw_value, address=address)


class DictStringToStringField(Field, metaclass=ABCMeta):
class DictStringToStringField(Field):
value: Optional[FrozenDict[str, str]]
default: ClassVar[Optional[FrozenDict[str, str]]] = None

Expand All @@ -1172,7 +1199,7 @@ def compute_value(
return FrozenDict(value_or_default)


class DictStringToStringSequenceField(Field, metaclass=ABCMeta):
class DictStringToStringSequenceField(Field):
value: Optional[FrozenDict[str, Tuple[str, ...]]]
default: ClassVar[Optional[FrozenDict[str, Tuple[str, ...]]]] = None

Expand Down Expand Up @@ -1202,35 +1229,12 @@ def compute_value(
return FrozenDict(result)


class AsyncStringSequenceField(AsyncField):
value: Optional[Tuple[str, ...]]
default: ClassVar[Optional[Tuple[str, ...]]] = None

@classmethod
def compute_value(
cls, raw_value: Optional[Iterable[str]], *, address: Address
) -> Optional[Tuple[str, ...]]:
value_or_default = super().compute_value(raw_value, address=address)
if value_or_default is None:
return None
try:
ensure_str_list(value_or_default)
except ValueError:
raise InvalidFieldTypeException(
address,
cls.alias,
value_or_default,
expected_type="an iterable of strings (e.g. a list of strings)",
)
return tuple(sorted(value_or_default))


# -----------------------------------------------------------------------------------------------
# Sources and codegen
# -----------------------------------------------------------------------------------------------


class Sources(AsyncStringSequenceField):
class Sources(StringSequenceField, AsyncFieldMixin):
"""A list of files and globs that belong to this target.
Paths are relative to the BUILD file's directory. You can ignore files/globs by prefixing them
Expand Down Expand Up @@ -1502,7 +1506,7 @@ def debug_hint(self) -> str:
# NB: To hydrate the dependencies, use one of:
# await Get(Addresses, DependenciesRequest(tgt[Dependencies])
# await Get(Targets, DependenciesRequest(tgt[Dependencies])
class Dependencies(AsyncStringSequenceField):
class Dependencies(StringSequenceField, AsyncFieldMixin):
"""Addresses to other targets that this target depends on, e.g. ['helloworld/subdir:lib'].
Alternatively, you may include file names. Pants will find which target owns that file, and
Expand Down Expand Up @@ -1674,7 +1678,7 @@ def __iter__(self) -> Iterator[Address]:
return iter(self.dependencies)


class SpecialCasedDependencies(AsyncStringSequenceField):
class SpecialCasedDependencies(StringSequenceField, AsyncFieldMixin):
"""Subclass this for fields that act similarly to the `dependencies` field, but are handled
differently than normal dependencies.
Expand Down
Loading

0 comments on commit dcb41a2

Please sign in to comment.