Skip to content

Commit

Permalink
Remove PrimitiveField and align AsyncField with Field (pantsbui…
Browse files Browse the repository at this point in the history
…ld#11231)

Broken out of pantsbuild#10932.

The difference between `AsyncField` and `PrimitiveField` is fairly artificial and it creates more complexity than we need. Really, the only difference is a) storing the `Address` and b) an expression of intent.

This makes `Field` act like `PrimitiveField`, which is now deleted. `AsyncField` now uses the same naming as `PrimitiveField` used to, i.e.`value` instead of `sanitized_raw_value` and `compute_value()` instead of `sanitize_raw_value()`.

This also cleans up some of the relevant `target_test.py`. We were testing something that's no longer worth it, as it's outside the scope of the code and we have multiple places we test this pattern. We add a new test that `__eq__` and `__hash__` are correct.

[ci skip-rust]
[ci skip-build-wheels]
  • Loading branch information
Eric-Arellano authored Nov 24, 2020
1 parent 85f43eb commit 82fc6ca
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 255 deletions.
4 changes: 2 additions & 2 deletions src/python/pants/backend/python/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
BoolField,
Dependencies,
DictStringToStringSequenceField,
Field,
InjectDependenciesRequest,
InjectedDependencies,
IntField,
InvalidFieldException,
InvalidFieldTypeException,
PrimitiveField,
ProvidesField,
ScalarField,
Sources,
Expand Down Expand Up @@ -491,7 +491,7 @@ def format_invalid_requirement_string_error(
)


class PythonRequirementsField(PrimitiveField):
class PythonRequirementsField(Field):
"""A sequence of pip-style requirement strings, e.g. ['foo==1.8', 'bar<=3 ;
python_version<'3']."""

Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/engine/internals/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def parse_dependencies_field(

addresses: List[AddressInput] = []
ignored_addresses: List[AddressInput] = []
for v in field.sanitized_raw_value or ():
for v in field.value or ():
is_ignore = v.startswith("!")
if is_ignore:
# Check if it's a transitive exclude, rather than a direct exclude.
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/engine/internals/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ class DefaultSources(Sources):
# than the normal `all` conjunction.
sources_rule_runner.create_files("src/fortran", files=["default.f95", "f1.f08", "ignored.f08"])
sources = DefaultSources(None, address=addr)
assert set(sources.sanitized_raw_value or ()) == set(DefaultSources.default)
assert set(sources.value or ()) == set(DefaultSources.default)

hydrated_sources = sources_rule_runner.request(
HydratedSources, [HydrateSourcesRequest(sources)]
Expand Down
180 changes: 72 additions & 108 deletions src/python/pants/engine/target.py

Large diffs are not rendered by default.

180 changes: 53 additions & 127 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,23 @@

from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple

import pytest
from typing_extensions import final

from pants.engine.addresses import Address
from pants.engine.fs import EMPTY_DIGEST, PathGlobs, Snapshot
from pants.engine.rules import Get, rule
from pants.engine.target import (
AsyncField,
AsyncStringSequenceField,
BoolField,
Dependencies,
DictStringToStringField,
DictStringToStringSequenceField,
Field,
FieldSet,
InvalidFieldChoiceException,
InvalidFieldException,
InvalidFieldTypeException,
PrimitiveField,
RequiredFieldMissingException,
ScalarField,
SequenceField,
Expand All @@ -37,17 +33,16 @@
generate_subtarget_address,
)
from pants.engine.unions import UnionMembership
from pants.testutil.rule_runner import MockGet, run_rule_with_mocks
from pants.util.collections import ensure_str_list
from pants.util.frozendict import FrozenDict
from pants.util.meta import FrozenInstanceError
from pants.util.ordered_set import OrderedSet

# -----------------------------------------------------------------------------------------------
# Test core Field and Target abstractions
# -----------------------------------------------------------------------------------------------


class FortranExtensions(PrimitiveField):
class FortranExtensions(Field):
alias = "fortran_extensions"
value: Tuple[str, ...]
default = ()
Expand All @@ -69,55 +64,18 @@ def compute_value(
return tuple(value_or_default)


class FortranVersion(StringField):
alias = "version"


class UnrelatedField(BoolField):
alias = "unrelated"
default = False


class FortranSources(AsyncField):
alias = "sources"
sanitized_raw_value: Optional[Tuple[str, ...]]
default = None

@classmethod
def sanitize_raw_value(
cls, raw_value: Optional[Iterable[str]], address: Address
) -> Optional[Tuple[str, ...]]:
value_or_default = super().sanitize_raw_value(raw_value, address=address)
if value_or_default is None:
return None
return tuple(ensure_str_list(value_or_default))


@dataclass(frozen=True)
class FortranSourcesRequest:
field: FortranSources


@dataclass(frozen=True)
class FortranSourcesResult:
snapshot: Snapshot


@rule
async def hydrate_fortran_sources(request: FortranSourcesRequest) -> FortranSourcesResult:
sources_field = request.field
result = await Get(Snapshot, PathGlobs(sources_field.sanitized_raw_value or ()))
# Validate after hydration
non_fortran_sources = [
fp for fp in result.files if PurePath(fp).suffix not in (".f95", ".f03", ".f08")
]
if non_fortran_sources:
raise ValueError(
f"Received non-Fortran sources in {sources_field.alias} for target "
f"{sources_field.address}: {non_fortran_sources}."
)
return FortranSourcesResult(result)


class FortranTarget(Target):
alias = "fortran"
core_fields = (FortranExtensions, FortranSources)
core_fields = (FortranExtensions, FortranVersion)


def test_invalid_fields_rejected() -> None:
Expand All @@ -139,11 +97,6 @@ def test_get_field() -> None:

# Default field value. This happens when the field is registered on the target type, but the
# user does not explicitly set the field in the BUILD file.
#
# NB: `default_raw_value` is not used in this case - that parameter is solely used when
# the field is not registered on the target type. To override the default field value, either
# subclass the Field and create a new target, or, in your call site, interpret the result and
# and apply your default.
default_field_tgt = FortranTarget({}, address=Address("", target_name="default"))
assert default_field_tgt[FortranExtensions].value == ()
assert default_field_tgt.get(FortranExtensions).value == ()
Expand Down Expand Up @@ -171,7 +124,7 @@ def test_get_field() -> None:
).value == (not UnrelatedField.default)


def test_primitive_field_hydration_is_eager() -> None:
def test_field_hydration_is_eager() -> None:
with pytest.raises(InvalidFieldException) as exc:
FortranTarget(
{FortranExtensions.alias: ["FortranExt1", "DoesNotStartWithFortran"]},
Expand All @@ -185,10 +138,10 @@ def test_has_fields() -> None:
empty_union_membership = UnionMembership({})
tgt = FortranTarget({}, address=Address("", target_name="lib"))

assert tgt.field_types == (FortranExtensions, FortranSources)
assert tgt.field_types == (FortranExtensions, FortranVersion)
assert FortranTarget.class_field_types(union_membership=empty_union_membership) == (
FortranExtensions,
FortranSources,
FortranVersion,
)

assert tgt.has_fields([]) is True
Expand Down Expand Up @@ -225,47 +178,6 @@ def test_has_fields() -> None:
)


def test_async_field() -> None:
def hydrate_field(
*, raw_source_files: List[str], hydrated_source_files: Tuple[str, ...]
) -> FortranSourcesResult:
sources_field = FortranTarget(
{FortranSources.alias: raw_source_files}, address=Address("", target_name="lib")
)[FortranSources]
result: FortranSourcesResult = run_rule_with_mocks(
hydrate_fortran_sources,
rule_args=[FortranSourcesRequest(sources_field)],
mock_gets=[
MockGet(
output_type=Snapshot,
input_type=PathGlobs,
mock=lambda _: Snapshot(EMPTY_DIGEST, files=hydrated_source_files, dirs=()),
)
],
)
return result

# Normal field
expected_files = ("important.f95", "big_banks.f08", "big_loans.f08")
assert (
hydrate_field(
raw_source_files=["important.f95", "big_*.f08"], hydrated_source_files=expected_files
).snapshot.files
== expected_files
)

# Test that `raw_value` gets sanitized/validated eagerly.
with pytest.raises(ValueError) as exc:
FortranTarget({FortranSources.alias: [0, 1, 2]}, address=Address("", target_name="lib"))
assert "Not all elements of the iterable have type" in str(exc)

# Test post-hydration validation.
with pytest.raises(ValueError) as exc:
hydrate_field(raw_source_files=["*.js"], hydrated_source_files=("not_fortran.js",))
assert "Received non-Fortran sources" in str(exc)
assert "//:lib" in str(exc)


def test_add_custom_fields() -> None:
class CustomField(BoolField):
alias = "custom_field"
Expand All @@ -279,14 +191,14 @@ class CustomField(BoolField):
tgt_values, address=Address("", target_name="lib"), union_membership=union_membership
)

assert tgt.field_types == (FortranExtensions, FortranSources, CustomField)
assert tgt.core_fields == (FortranExtensions, FortranSources)
assert tgt.field_types == (FortranExtensions, FortranVersion, CustomField)
assert tgt.core_fields == (FortranExtensions, FortranVersion)
assert tgt.plugin_fields == (CustomField,)
assert tgt.has_field(CustomField) is True

assert FortranTarget.class_field_types(union_membership=union_membership) == (
FortranExtensions,
FortranSources,
FortranVersion,
CustomField,
)
assert FortranTarget.class_has_field(CustomField, union_membership=union_membership) is True
Expand Down Expand Up @@ -392,37 +304,51 @@ class CustomFortranTarget(Target):


def test_required_field() -> None:
class RequiredPrimitiveField(StringField):
alias = "primitive"
required = True

class RequiredAsyncField(AsyncField):
alias = "async"
class RequiredField(StringField):
alias = "field"
required = True

@final
@property
def request(self):
raise NotImplementedError

class RequiredTarget(Target):
alias = "required_target"
core_fields = (RequiredPrimitiveField, RequiredAsyncField)
core_fields = (RequiredField,)

address = Address("", target_name="lib")

# No errors when all defined
RequiredTarget({"primitive": "present", "async": 0}, address=address)
# No errors when defined
RequiredTarget({"field": "present"}, address=address)

with pytest.raises(RequiredFieldMissingException) as exc:
RequiredTarget({"primitive": "present"}, address=address)
RequiredTarget({}, address=address)
assert str(address) in str(exc.value)
assert "async" in str(exc.value)
assert "field" in str(exc.value)

with pytest.raises(RequiredFieldMissingException) as exc:
RequiredTarget({"async": 0}, address=address)
assert str(address) in str(exc.value)
assert "primitive" in str(exc.value)

def test_async_field() -> None:
class ExampleField(AsyncField):
alias = "field"
default = 10

addr = Address("", target_name="tgt")
field = ExampleField(None, address=addr)
assert field.value == 10
assert field.address == addr

# Ensure equality and __hash__ work correctly.
other = ExampleField(None, address=addr)
assert field == other
assert hash(field) == hash(other)

other = ExampleField(25, address=addr)
assert field != other
assert hash(field) != hash(other)

other = ExampleField(None, address=Address("", target_name="other"))
assert field != other
assert hash(field) != hash(other)

# Ensure it's still frozen.
with pytest.raises(FrozenInstanceError):
field.y = "foo" # type: ignore[attr-defined]


# -----------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -507,9 +433,9 @@ class NoFieldsTarget(Target):

@dataclass(frozen=True)
class FortranFieldSet(FieldSet):
required_fields = (FortranSources,)
required_fields = (FortranVersion,)

sources: FortranSources
version: FortranVersion
unrelated_field: UnrelatedField

@dataclass(frozen=True)
Expand Down Expand Up @@ -727,8 +653,8 @@ class Example(AsyncStringSequenceField):
alias = "example"

addr = Address("", target_name="example")
assert Example(["hello", "world"], address=addr).sanitized_raw_value == ("hello", "world")
assert Example(None, address=addr).sanitized_raw_value is None
assert Example(["hello", "world"], address=addr).value == ("hello", "world")
assert Example(None, address=addr).value is None
with pytest.raises(InvalidFieldTypeException):
Example("strings are technically iterable...", address=addr)
with pytest.raises(InvalidFieldTypeException):
Expand Down
Loading

0 comments on commit 82fc6ca

Please sign in to comment.