Skip to content

Commit

Permalink
Add sugar to getting values from UnionMembership (pantsbuild#9856)
Browse files Browse the repository at this point in the history
A common idiom is for us to get all union members for some type, e.g. getting all `ReplImplementations`.

This gives some sugar for less typing that also works with MyPy automatically, rather than needing casts.

[ci skip-rust-tests]
[ci skip-jvm-tests]
  • Loading branch information
Eric-Arellano authored May 22, 2020
1 parent 71bf50e commit ef9e054
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 35 deletions.
5 changes: 1 addition & 4 deletions src/python/pants/core/goals/fmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,7 @@ async def fmt(
workspace: Workspace,
union_membership: UnionMembership,
) -> Fmt:
language_target_collection_types: Iterable[Type[LanguageFmtTargets]] = (
union_membership.union_rules[LanguageFmtTargets]
)

language_target_collection_types = union_membership[LanguageFmtTargets]
language_target_collections: Iterable[LanguageFmtTargets] = tuple(
language_target_collection_type(
TargetsWithOrigins(
Expand Down
7 changes: 2 additions & 5 deletions src/python/pants/core/goals/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,14 @@ async def lint(
options: LintOptions,
union_membership: UnionMembership,
) -> Lint:
field_set_collection_types: Iterable[Type[LinterFieldSets]] = union_membership.union_rules[
LinterFieldSets
]

field_set_collection_types = union_membership[LinterFieldSets]
field_set_collections: Iterable[LinterFieldSets] = tuple(
field_set_collection_type(
field_set_collection_type.field_set_type.create(target_with_origin)
for target_with_origin in targets_with_origins
if field_set_collection_type.field_set_type.is_valid(target_with_origin.target)
)
for field_set_collection_type in field_set_collection_types
for field_set_collection_type in union_membership[LinterFieldSets]
)
field_set_collections_with_sources: Iterable[FieldSetsWithSources] = await MultiGet(
Get[FieldSetsWithSources](FieldSetsWithSourcesRequest(field_set_collection))
Expand Down
18 changes: 6 additions & 12 deletions src/python/pants/core/goals/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC
from dataclasses import dataclass
from pathlib import PurePath
from typing import ClassVar, Iterable, Tuple, Type, cast
from typing import ClassVar, Dict, Tuple, Type, cast

from pants.base.build_root import BuildRoot
from pants.engine.console import Console
Expand Down Expand Up @@ -74,21 +74,15 @@ async def run_repl(
union_membership: UnionMembership,
global_options: GlobalOptions,
) -> Repl:

# We can guarantee that we will only even enter this `goal_rule` if there exists an implementer
# of the `ReplImplementation` union because `LegacyGraphSession.run_goal_rules()` will not
# execute this rule's body if there are no implementations registered.
membership: Iterable[Type[ReplImplementation]] = union_membership.union_rules[
ReplImplementation
]
implementations = {impl.name: impl for impl in membership}

default_repl = "python"
repl_shell_name = cast(str, options.values.shell or default_repl)
repl_shell_name = cast(str, options.values.shell) or default_repl

implementations: Dict[str, Type[ReplImplementation]] = {
impl.name: impl for impl in union_membership[ReplImplementation]
}
repl_implementation_cls = implementations.get(repl_shell_name)
if repl_implementation_cls is None:
available = sorted(set(implementations.keys()))
available = sorted(implementations.keys())
console.print_stderr(
f"{repr(repl_shell_name)} is not a registered REPL. Available REPLs (which may "
f"be specified through the option `--repl-shell`): {available}"
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/core/goals/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def run_tests(
Type[CoverageData], Type[CoverageDataCollection]
] = {
collection_cls.element_type: collection_cls
for collection_cls in union_membership.union_rules[CoverageDataCollection]
for collection_cls in union_membership.get(CoverageDataCollection)
}
coverage_collections: List[CoverageDataCollection] = []
for data_cls, data in itertools.groupby(all_coverage_data, lambda data: type(data)):
Expand Down
13 changes: 3 additions & 10 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,9 +1337,7 @@ def can_generate(cls, output_type: Type["Sources"], union_membership: UnionMembe
This method is useful when you need to filter targets before hydrating them, such as how
you may filter targets via `tgt.has_field(MyField)`.
"""
generate_request_types: Iterable[
Type[GenerateSourcesRequest]
] = union_membership.union_rules.get(GenerateSourcesRequest, ())
generate_request_types = union_membership.get(GenerateSourcesRequest)
return any(
issubclass(cls, generate_request_type.input)
and issubclass(generate_request_type.output, output_type)
Expand Down Expand Up @@ -1460,9 +1458,7 @@ async def hydrate_sources(
# to determine if the sources_field is valid or not.
# We could alternatively use `sources_field.can_generate()`, but we want to error if there are
# 2+ generators due to ambiguity.
generate_request_types: Iterable[
Type[GenerateSourcesRequest]
] = union_membership.union_rules.get(GenerateSourcesRequest, ())
generate_request_types = union_membership.get(GenerateSourcesRequest)
relevant_generate_request_types = [
generate_request_type
for generate_request_type in generate_request_types
Expand Down Expand Up @@ -1618,10 +1614,7 @@ async def resolve_dependencies(
# Inject any dependencies. This is determined by the `request.field` class. For example, if
# there is a rule to inject for FortranDependencies, then FortranDependencies and any subclass
# of FortranDependencies will use that rule.
inject_request_types = cast(
Iterable[Type[InjectDependenciesRequest]],
union_membership.union_rules.get(InjectDependenciesRequest, ()),
)
inject_request_types = union_membership.get(InjectDependenciesRequest)
injected = await MultiGet(
Get[InjectedDependencies](InjectDependenciesRequest, inject_request_type(request.field))
for inject_request_type in inject_request_types
Expand Down
32 changes: 29 additions & 3 deletions src/python/pants/engine/unions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import typing
from dataclasses import dataclass
from typing import Iterable, Mapping, Type
from typing import Iterable, Mapping, Type, TypeVar

from pants.util.frozendict import FrozenDict
from pants.util.meta import decorated_type_checkable, frozen_after_init
Expand Down Expand Up @@ -49,6 +48,9 @@ def non_member_error_message(subject):
)


_T = TypeVar("_T")


@frozen_after_init
@dataclass(unsafe_hash=True)
class UnionMembership:
Expand All @@ -59,6 +61,30 @@ def __init__(self, union_rules: Mapping[Type, Iterable[Type]]) -> None:
{base: FrozenOrderedSet(members) for base, members in union_rules.items()}
)

def __getitem__(self, union_type: Type[_T]) -> FrozenOrderedSet[Type[_T]]:
"""Get all members of this union type.
If the union type does not exist because it has no members registered, this will raise an
IndexError.
Note that the type hint assumes that all union members will have subclassed the union type
- this is only a convention and is not actually enforced. So, you may have inaccurate type
hints.
"""
return self.union_rules[union_type]

def get(self, union_type: Type[_T]) -> FrozenOrderedSet[Type[_T]]:
"""Get all members of this union type.
If the union type does not exist because it has no members registered, return an empty
FrozenOrderedSet.
Note that the type hint assumes that all union members will have subclassed the union type
- this is only a convention and is not actually enforced. So, you may have inaccurate type
hints.
"""
return self.union_rules.get(union_type, FrozenOrderedSet()) # type: ignore[arg-type]

def is_member(self, union_type: Type, putative_member: Type) -> bool:
members = self.union_rules.get(union_type)
if members is None:
Expand All @@ -69,7 +95,7 @@ def has_members(self, union_type: Type) -> bool:
"""Check whether the union has an implementation or not."""
return bool(self.union_rules.get(union_type))

def has_members_for_all(self, union_types: typing.Iterable[Type]) -> bool:
def has_members_for_all(self, union_types: Iterable[Type]) -> bool:
"""Check whether every union given has an implementation or not."""
return all(self.has_members(union_type) for union_type in union_types)

Expand Down

0 comments on commit ef9e054

Please sign in to comment.