Skip to content

Commit

Permalink
Disambiguate Python module providers by level of ancestry. (pantsbuil…
Browse files Browse the repository at this point in the history
…d#17489)

I.e., if we're looking for a provider for foo.bar, and we have one for foo.bar 
and one for foo, take the former over the latter.

Previously we would see this as an ambiguity, and not infer a dep. This 
happened to a real user, in a situation where a namespace package was split 
between first-party and third-party code: 
pantsbuild#17286

This change involves a distinction between ModuleProvider and a new 
PossibleModuleProvider class, which also tracks the ancestry. Therefore {First,Third}PartyPythonModuleMapping now encapsulate the underlying dict 
instead of extending the dict. This clarifies that their role is to use the dict
 to produce other data, and they are not intended to be queried like dicts.

Co-authored-by: Tal Amuyal <[email protected]>
  • Loading branch information
benjyw and TalAmuyal authored Nov 8, 2022
1 parent 546bffc commit 077ea71
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 171 deletions.
6 changes: 3 additions & 3 deletions src/python/pants/backend/codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def find_python_runtime_library_or_raise_error(
disable_inference_option: str,
) -> Address:
addresses = [
module_provider.addr
for module_provider in module_mapping.providers_for_module(
possible_module_provider.provider.addr
for possible_module_provider in module_mapping.providers_for_module(
runtime_library_module, resolve=resolve
)
if module_provider.typ == ModuleProviderType.IMPL
if possible_module_provider.provider.typ == ModuleProviderType.IMPL
]
if len(addresses) == 1:
return addresses[0]
Expand Down
142 changes: 94 additions & 48 deletions src/python/pants/backend/python/dependency_inference/module_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ class ModuleProvider:
typ: ModuleProviderType


@dataclass(frozen=True, order=True)
class PossibleModuleProvider:
provider: ModuleProvider
# 0 = The provider mapped to the module itself, 1 = the provider mapped to its parent, etc.
ancestry: int


def module_from_stripped_path(path: PurePath) -> str:
module_name_with_slashes = (
path.parent if path.name in ("__init__.py", "__init__.pyi") else path.with_suffix("")
Expand Down Expand Up @@ -127,38 +134,48 @@ class FirstPartyPythonMappingImplMarker:
"""


class FirstPartyPythonModuleMapping(
FrozenDict[ResolveName, FrozenDict[str, Tuple[ModuleProvider, ...]]]
):
@dataclass(frozen=True)
class FirstPartyPythonModuleMapping:
resolves_to_modules_to_providers: FrozenDict[
ResolveName, FrozenDict[str, Tuple[ModuleProvider, ...]]
]

"""A merged mapping of each resolve name to the first-party module names contained and their
owning addresses.
This mapping may have been constructed from multiple distinct implementations, e.g.
implementations for each codegen backends.
"""

def _providers_for_resolve(self, module: str, resolve: str) -> tuple[ModuleProvider, ...]:
mapping = self.get(resolve)
def _providers_for_resolve(
self, module: str, resolve: str
) -> tuple[PossibleModuleProvider, ...]:
mapping = self.resolves_to_modules_to_providers.get(resolve)
if not mapping:
return ()

result = mapping.get(module, ())
if result:
return result
return tuple(PossibleModuleProvider(provider, 0) for provider in result)

# If the module is not found, try the parent, if any. This is to accommodate `from`
# imports, where we don't care about the specific symbol, but only the module. For example,
# with `from my_project.app import App`, we only care about the `my_project.app` part.
# If the module is not found, try the parent, if any. This is to handle `from` imports
# where the "module" we were handed was actually a symbol inside the module.
# E.g., with `from my_project.app import App`, we would be passed "my_project.app.App".
#
# We do not look past the direct parent, as this could cause multiple ambiguous owners to
# be resolved. This contrasts with the third-party module mapping, which will try every
# ancestor.
# TODO: Now that we capture the ancestry, we could look past the direct parent.
# One reason to do so would be to unify more of the FirstParty and ThirdParty impls.
if "." not in module:
return ()
parent_module = module.rsplit(".", maxsplit=1)[0]
return mapping.get(parent_module, ())
parent_providers = mapping.get(parent_module, ())
return tuple(PossibleModuleProvider(mp, 1) for mp in parent_providers)

def providers_for_module(self, module: str, resolve: str | None) -> tuple[ModuleProvider, ...]:
def providers_for_module(
self, module: str, resolve: str | None
) -> tuple[PossibleModuleProvider, ...]:
"""Find all providers for the module.
If `resolve` is None, will not consider resolves, i.e. any `python_source` et al can be
Expand All @@ -168,7 +185,8 @@ def providers_for_module(self, module: str, resolve: str | None) -> tuple[Module
return self._providers_for_resolve(module, resolve)
return tuple(
itertools.chain.from_iterable(
self._providers_for_resolve(module, resolve) for resolve in list(self.keys())
self._providers_for_resolve(module, resolve)
for resolve in list(self.resolves_to_modules_to_providers.keys())
)
)

Expand All @@ -193,13 +211,15 @@ async def merge_first_party_module_mappings(
for module, providers in modules_to_providers.items():
resolves_to_modules_to_providers[resolve][module].extend(providers)
return FirstPartyPythonModuleMapping(
(
resolve,
FrozenDict(
(mod, tuple(sorted(providers))) for mod, providers in sorted(mapping.items())
),
FrozenDict(
(
resolve,
FrozenDict(
(mod, tuple(sorted(providers))) for mod, providers in sorted(mapping.items())
),
)
for resolve, mapping in sorted(resolves_to_modules_to_providers.items())
)
for resolve, mapping in sorted(resolves_to_modules_to_providers.items())
)


Expand Down Expand Up @@ -242,29 +262,36 @@ async def map_first_party_python_targets_to_modules(
# -----------------------------------------------------------------------------------------------


class ThirdPartyPythonModuleMapping(
FrozenDict[ResolveName, FrozenDict[str, Tuple[ModuleProvider, ...]]]
):
@dataclass(frozen=True)
class ThirdPartyPythonModuleMapping:
"""A mapping of each resolve to the modules they contain and the addresses providing those
modules."""

def _providers_for_resolve(self, module: str, resolve: str) -> tuple[ModuleProvider, ...]:
mapping = self.get(resolve)
resolves_to_modules_to_providers: FrozenDict[
ResolveName, FrozenDict[str, Tuple[ModuleProvider, ...]]
]

def _providers_for_resolve(
self, module: str, resolve: str, ancestry: int = 0
) -> tuple[PossibleModuleProvider, ...]:
mapping = self.resolves_to_modules_to_providers.get(resolve)
if not mapping:
return ()

result = mapping.get(module, ())
if result:
return result
return tuple(PossibleModuleProvider(mp, ancestry) for mp in result)

# If the module is not found, recursively try the ancestor modules, if any. For example,
# pants.task.task.Task -> pants.task.task -> pants.task -> pants
if "." not in module:
return ()
parent_module = module.rsplit(".", maxsplit=1)[0]
return self._providers_for_resolve(parent_module, resolve)
return self._providers_for_resolve(parent_module, resolve, ancestry + 1)

def providers_for_module(self, module: str, resolve: str | None) -> tuple[ModuleProvider, ...]:
def providers_for_module(
self, module: str, resolve: str | None
) -> tuple[PossibleModuleProvider, ...]:
"""Find all providers for the module.
If `resolve` is None, will not consider resolves, i.e. any `python_requirement` can be
Expand All @@ -274,21 +301,22 @@ def providers_for_module(self, module: str, resolve: str | None) -> tuple[Module
return self._providers_for_resolve(module, resolve)
return tuple(
itertools.chain.from_iterable(
self._providers_for_resolve(module, resolve) for resolve in list(self.keys())
self._providers_for_resolve(module, resolve)
for resolve in list(self.resolves_to_modules_to_providers.keys())
)
)


@rule(desc="Creating map of third party targets to Python modules", level=LogLevel.DEBUG)
async def map_third_party_modules_to_addresses(
all_python_tgts: AllPythonTargets,
all_python_targets: AllPythonTargets,
python_setup: PythonSetup,
) -> ThirdPartyPythonModuleMapping:
resolves_to_modules_to_providers: DefaultDict[
ResolveName, DefaultDict[str, list[ModuleProvider]]
] = defaultdict(lambda: defaultdict(list))

for tgt in all_python_tgts.third_party:
for tgt in all_python_targets.third_party:
resolve = tgt[PythonRequirementResolveField].normalized_value(python_setup)

def add_modules(modules: Iterable[str], *, type_stub: bool = False) -> None:
Expand Down Expand Up @@ -335,13 +363,15 @@ def add_modules(modules: Iterable[str], *, type_stub: bool = False) -> None:
add_modules(DEFAULT_MODULE_MAPPING.get(proj_name, (fallback_value,)))

return ThirdPartyPythonModuleMapping(
(
resolve,
FrozenDict(
(mod, tuple(sorted(providers))) for mod, providers in sorted(mapping.items())
),
FrozenDict(
(
resolve,
FrozenDict(
(mod, tuple(sorted(providers))) for mod, providers in sorted(mapping.items())
),
)
for resolve, mapping in sorted(resolves_to_modules_to_providers.items())
)
for resolve, mapping in sorted(resolves_to_modules_to_providers.items())
)


Expand Down Expand Up @@ -386,24 +416,40 @@ async def map_module_to_address(
first_party_mapping: FirstPartyPythonModuleMapping,
third_party_mapping: ThirdPartyPythonModuleMapping,
) -> PythonModuleOwners:
providers = [
possible_providers: tuple[PossibleModuleProvider, ...] = (
*third_party_mapping.providers_for_module(request.module, resolve=request.resolve),
*first_party_mapping.providers_for_module(request.module, resolve=request.resolve),
]
addresses = tuple(provider.addr for provider in providers)
)

# There's no ambiguity if there are only 0-1 providers.
if len(providers) < 2:
return PythonModuleOwners(addresses)
# We attempt to disambiguate conflicting providers by taking - for each provider type -
# the providers for the closest ancestors to the requested modules. This prevents
# issues with namespace packages that are split between first-party and third-party
# (e.g., https://github.com/pantsbuild/pants/discussions/17286).

# Map from provider type to mutable pair of
# [closest ancestry, list of provider of that type at that ancestry level].
type_to_closest_providers: dict[ModuleProviderType, list] = defaultdict(lambda: [999, []])
for possible_provider in possible_providers:
val = type_to_closest_providers[possible_provider.provider.typ]
if possible_provider.ancestry < val[0]:
val[0] = possible_provider.ancestry
val[1] = []
# NB This must come after the < check above, so we handle the possible_provider
# that caused that check to pass.
if possible_provider.ancestry == val[0]:
val[1].append(possible_provider.provider)

closest_providers: list[ModuleProvider] = list(
itertools.chain(*[val[1] for val in type_to_closest_providers.values()])
)
addresses = tuple(provider.addr for provider in closest_providers)

# Else, it's ambiguous unless there are exactly two providers and one is a type stub and the
# other an implementation.
if len(providers) == 2 and (providers[0].typ == ModuleProviderType.TYPE_STUB) ^ (
providers[1].typ == ModuleProviderType.TYPE_STUB
):
return PythonModuleOwners(addresses)
# Check that we have at most one closest provider for each provider type.
# If we have more than one, signal ambiguity.
if any(len(val[1]) > 1 for val in type_to_closest_providers.values()):
return PythonModuleOwners((), ambiguous=addresses)

return PythonModuleOwners((), ambiguous=addresses)
return PythonModuleOwners(addresses)


def rules():
Expand Down
Loading

0 comments on commit 077ea71

Please sign in to comment.