Skip to content

Commit

Permalink
Fix type hints for UnionMembership.get() (pantsbuild#12092)
Browse files Browse the repository at this point in the history
MyPy wasn't correctly inferring the type of `UnionMembership[MyType]` and `UnionMembership.get(MyType)`.

[ci skip-rust]
[ci skip-build-wheels]
  • Loading branch information
Eric-Arellano authored May 20, 2021
1 parent c325cdc commit 5b111c8
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 49 deletions.
9 changes: 4 additions & 5 deletions src/python/pants/backend/go/lint/fmt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright 2021 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from dataclasses import dataclass
from typing import Iterable, List, Type
from typing import Iterable

from pants.backend.go.target_types import GoSources
from pants.core.goals.fmt import EnrichedFmtResult, LanguageFmtResults, LanguageFmtTargets
Expand Down Expand Up @@ -35,10 +36,8 @@ async def format_golang_targets(
)
prior_formatter_result = original_sources.snapshot

results: List[EnrichedFmtResult] = []
fmt_request_types: Iterable[Type[GoLangFmtRequest]] = union_membership.union_rules[
GoLangFmtRequest
]
results = []
fmt_request_types: Iterable[type[StyleRequest]] = union_membership[GoLangFmtRequest]
for fmt_request_type in fmt_request_types:
result = await Get(
EnrichedFmtResult,
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/backend/python/goals/pytest_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ async def run_all_setup_plugins(
wrapped_tgt = await Get(WrappedTarget, Address, request.address)
applicable_setup_request_types = tuple(
request
for request in union_membership.get(PytestPluginSetupRequest) # type: ignore[misc]
for request in union_membership.get(PytestPluginSetupRequest)
if request.is_applicable(wrapped_tgt.target)
)
setups = await MultiGet(
Get(PytestPluginSetup, PytestPluginSetupRequest, request(wrapped_tgt.target)) # type: ignore[misc]
Get(PytestPluginSetup, PytestPluginSetupRequest, request(wrapped_tgt.target)) # type: ignore[misc, abstract]
for request in applicable_setup_request_types
)
return AllPytestPluginSetups(setups)
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/backend/python/goals/setup_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ async def determine_setup_kwargs(
exported_target: ExportedTarget, union_membership: UnionMembership
) -> SetupKwargs:
target = exported_target.target
setup_kwargs_requests = union_membership.get(SetupKwargsRequest) # type: ignore[misc]
setup_kwargs_requests = union_membership.get(SetupKwargsRequest)
applicable_setup_kwargs_requests = tuple(
request for request in setup_kwargs_requests if request.is_applicable(target)
)
Expand All @@ -459,7 +459,7 @@ async def determine_setup_kwargs(
"precise so that only one implementation is applicable for this target."
)
setup_kwargs_request = tuple(applicable_setup_kwargs_requests)[0]
return await Get(SetupKwargs, SetupKwargsRequest, setup_kwargs_request(target))
return await Get(SetupKwargs, SetupKwargsRequest, setup_kwargs_request(target)) # type: ignore[abstract]


@rule
Expand Down
3 changes: 2 additions & 1 deletion src/python/pants/backend/python/lint/python_fmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from dataclasses import dataclass
from typing import Iterable

from pants.backend.python.target_types import PythonSources
from pants.core.goals.fmt import EnrichedFmtResult, LanguageFmtResults, LanguageFmtTargets
Expand Down Expand Up @@ -33,7 +34,7 @@ async def format_python_target(
prior_formatter_result = original_sources.snapshot

results = []
fmt_request_types = union_membership.union_rules[PythonFmtRequest]
fmt_request_types: Iterable[type[StyleRequest]] = union_membership[PythonFmtRequest]
for fmt_request_type in fmt_request_types:
result = await Get(
EnrichedFmtResult,
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/backend/python/util_rules/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ python_tests(
# We shell out to pex in tests; so we have a dependency, but not an explicit import.
"3rdparty/python:pex",
],
timeout=90,
timeout=120,
)
3 changes: 2 additions & 1 deletion src/python/pants/backend/shell/lint/shell_fmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable

from pants.backend.shell.target_types import ShellSources
from pants.core.goals.fmt import EnrichedFmtResult, LanguageFmtResults, LanguageFmtTargets
Expand Down Expand Up @@ -35,7 +36,7 @@ async def format_shell_targets(
prior_formatter_result = original_sources.snapshot

results = []
fmt_request_types = union_membership.union_rules[ShellFmtRequest]
fmt_request_types: Iterable[type[StyleRequest]] = union_membership[ShellFmtRequest]
for fmt_request_type in fmt_request_types:
result = await Get(
EnrichedFmtResult,
Expand Down
10 changes: 5 additions & 5 deletions src/python/pants/core/goals/fmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import ClassVar, Iterable, List, Optional, Tuple, Type, TypeVar, cast
from typing import ClassVar, Optional, Tuple, Type, TypeVar, cast

from pants.core.util_rules.filter_empty_sources import TargetsWithSources, TargetsWithSourcesRequest
from pants.engine.console import Console
Expand Down Expand Up @@ -196,7 +196,7 @@ async def fmt(
union_membership: UnionMembership,
) -> Fmt:
language_target_collection_types = union_membership[LanguageFmtTargets]
language_target_collections: Iterable[LanguageFmtTargets] = tuple(
language_target_collections = tuple(
language_target_collection_type(
Targets(
target
Expand All @@ -206,7 +206,7 @@ async def fmt(
)
for language_target_collection_type in language_target_collection_types
)
targets_with_sources: Iterable[TargetsWithSources] = await MultiGet(
targets_with_sources = await MultiGet(
Get(
TargetsWithSources,
TargetsWithSourcesRequest(language_target_collection.targets),
Expand All @@ -216,7 +216,7 @@ async def fmt(
# NB: We must convert back the generic TargetsWithSources objects back into their
# corresponding LanguageFmtTargets, e.g. back to PythonFmtTargets, in order for the union
# rule to work.
valid_language_target_collections: Iterable[LanguageFmtTargets] = tuple(
valid_language_target_collections = tuple(
language_target_collection_cls(
Targets(
target
Expand Down Expand Up @@ -246,7 +246,7 @@ async def fmt(
for language_target_collection in valid_language_target_collections
)

individual_results: List[FmtResult] = list(
individual_results = list(
itertools.chain.from_iterable(
language_result.results for language_result in per_language_results
)
Expand Down
8 changes: 4 additions & 4 deletions src/python/pants/core/goals/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,20 +211,20 @@ async def lint(
lint_subsystem: LintSubsystem,
union_membership: UnionMembership,
) -> Lint:
request_types = union_membership[LintRequest]
requests: Iterable[StyleRequest] = tuple(
request_types = cast("Iterable[type[StyleRequest]]", union_membership[LintRequest])
requests = tuple(
request_type(
request_type.field_set_type.create(target)
for target in targets
if request_type.field_set_type.is_applicable(target)
)
for request_type in request_types
)
field_sets_with_sources: Iterable[FieldSetsWithSources] = await MultiGet(
field_sets_with_sources = await MultiGet(
Get(FieldSetsWithSources, FieldSetsWithSourcesRequest(request.field_sets))
for request in requests
)
valid_requests: Iterable[StyleRequest] = tuple(
valid_requests = tuple(
request_cls(request)
for request_cls, request in zip(request_types, field_sets_with_sources)
if request
Expand Down
7 changes: 2 additions & 5 deletions src/python/pants/core/goals/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC
from dataclasses import dataclass
from pathlib import PurePath
from typing import ClassVar, Dict, Iterable, Mapping, Optional, Tuple, Type, cast
from typing import ClassVar, Iterable, Mapping, Optional, Tuple, cast

from pants.base.build_root import BuildRoot
from pants.engine.addresses import Addresses
Expand Down Expand Up @@ -101,10 +101,7 @@ async def run_repl(
# TODO: When we support multiple languages, detect the default repl to use based
# on the targets. For now we default to the python repl.
repl_shell_name = repl_subsystem.shell or "python"

implementations: Dict[str, Type[ReplImplementation]] = {
impl.name: impl for impl in union_membership[ReplImplementation]
}
implementations = {impl.name: impl for impl in union_membership[ReplImplementation]}
repl_implementation_cls = implementations.get(repl_shell_name)
if repl_implementation_cls is None:
available = sorted(implementations.keys())
Expand Down
14 changes: 6 additions & 8 deletions src/python/pants/core/goals/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypeVar, Union, cast

from pants.core.goals.package import BuiltPackage, PackageFieldSet
from pants.core.util_rules.distdir import DistDir
Expand Down Expand Up @@ -202,7 +202,7 @@ class CoverageData(ABC):

@union
class CoverageDataCollection(Collection[_CD]):
element_type: Type[_CD]
element_type: ClassVar[type[_CD]]


class CoverageReport(ABC):
Expand Down Expand Up @@ -450,13 +450,11 @@ async def run_tests(
key=lambda cov_data: str(type(cov_data)),
)

coverage_types_to_collection_types: Dict[
Type[CoverageData], Type[CoverageDataCollection]
] = {
collection_cls.element_type: collection_cls
coverage_types_to_collection_types = {
collection_cls.element_type: collection_cls # type: ignore[misc]
for collection_cls in union_membership.get(CoverageDataCollection)
}
coverage_collections: List[CoverageDataCollection] = []
coverage_collections = []
for data_cls, data in itertools.groupby(all_coverage_data, lambda data: type(data)):
collection_cls = coverage_types_to_collection_types[data_cls]
coverage_collections.append(collection_cls(data))
Expand All @@ -466,7 +464,7 @@ async def run_tests(
for coverage_collection in coverage_collections
)

coverage_report_files: List[PurePath] = []
coverage_report_files: list[PurePath] = []
for coverage_reports in coverage_reports_collections:
report_files = coverage_reports.materialize(console, workspace)
coverage_report_files.extend(report_files)
Expand Down
20 changes: 11 additions & 9 deletions src/python/pants/core/goals/typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple, cast

from pants.core.goals.style_request import StyleRequest
from pants.core.util_rules.filter_empty_sources import (
Expand Down Expand Up @@ -146,20 +146,22 @@ class Typecheck(Goal):
async def typecheck(
console: Console, targets: Targets, union_membership: UnionMembership
) -> Typecheck:
typecheck_request_types = union_membership[TypecheckRequest]
requests: Iterable[StyleRequest] = tuple(
lint_request_type(
lint_request_type.field_set_type.create(target)
typecheck_request_types = cast(
"Iterable[type[StyleRequest]]", union_membership[TypecheckRequest]
)
requests = tuple(
typecheck_request_type(
typecheck_request_type.field_set_type.create(target)
for target in targets
if lint_request_type.field_set_type.is_applicable(target)
if typecheck_request_type.field_set_type.is_applicable(target)
)
for lint_request_type in typecheck_request_types
for typecheck_request_type in typecheck_request_types
)
field_sets_with_sources: Iterable[FieldSetsWithSources] = await MultiGet(
field_sets_with_sources = await MultiGet(
Get(FieldSetsWithSources, FieldSetsWithSourcesRequest(request.field_sets))
for request in requests
)
valid_requests: Iterable[StyleRequest] = tuple(
valid_requests = tuple(
request_cls(request)
for request_cls, request in zip(typecheck_request_types, field_sets_with_sources)
if request
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/engine/internals/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ async def find_valid_field_sets_for_target_roots(
specs,
union_membership,
registered_target_types,
field_set_types=union_membership.union_rules[request.field_set_superclass],
field_set_types=union_membership[request.field_set_superclass],
goal_description=request.goal_description,
)
if request.no_applicable_targets_behavior == NoApplicableTargetsBehavior.error:
Expand Down
10 changes: 5 additions & 5 deletions src/python/pants/engine/unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __post_init__(self) -> None:
raise ValueError(msg)


_T = TypeVar("_T")
_T = TypeVar("_T", bound=type)


@frozen_after_init
Expand All @@ -79,7 +79,7 @@ def __init__(self, union_rules: Mapping[Type, Iterable[Type]]) -> None:
{base: FrozenOrderedSet(members) for base, members in union_rules.items()}
)

def __getitem__(self, union_type: Type[_T]) -> FrozenOrderedSet[Type[_T]]:
def __getitem__(self, union_type: _T) -> FrozenOrderedSet[_T]:
"""Get all members of this union type.
If the union type does not exist because it has no members registered, this will raise an
Expand All @@ -89,9 +89,9 @@ def __getitem__(self, union_type: Type[_T]) -> FrozenOrderedSet[Type[_T]]:
- this is only a convention and is not actually enforced. So, you may have inaccurate type
hints.
"""
return self.union_rules[union_type]
return self.union_rules[union_type] # type: ignore[return-value]

def get(self, union_type: Type[_T]) -> FrozenOrderedSet[Type[_T]]:
def get(self, union_type: _T) -> FrozenOrderedSet[_T]:
"""Get all members of this union type.
If the union type does not exist because it has no members registered, return an empty
Expand All @@ -101,7 +101,7 @@ def get(self, union_type: Type[_T]) -> FrozenOrderedSet[Type[_T]]:
- this is only a convention and is not actually enforced. So, you may have inaccurate type
hints.
"""
return self.union_rules.get(union_type, FrozenOrderedSet())
return self.union_rules.get(union_type, FrozenOrderedSet()) # type: ignore[return-value]

def is_member(self, union_type: Type, putative_member: Type) -> bool:
members = self.union_rules.get(union_type)
Expand Down

0 comments on commit 5b111c8

Please sign in to comment.