Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate descriptor handling in checkmember.py #18831

Merged
merged 4 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 11 additions & 109 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
import mypy.checkexpr
from mypy import errorcodes as codes, join, message_registry, nodes, operators
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
from mypy.checkmember import (
MemberContext,
analyze_decorator_or_funcbase_access,
analyze_descriptor_access,
analyze_member_access,
)
from mypy.checkmember import analyze_member_access
from mypy.checkpattern import PatternChecker
from mypy.constraints import SUPERTYPE_OF
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
Expand Down Expand Up @@ -3233,7 +3228,7 @@ def check_assignment(
)
else:
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=")
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue)
# If we're assigning to __getattr__ or similar methods, check that the signature is
# valid.
if isinstance(lvalue, NameExpr) and lvalue.node:
Expand Down Expand Up @@ -4339,7 +4334,9 @@ def check_multi_assignment_from_iterable(
else:
self.msg.type_not_iterable(rvalue_type, context)

def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, Var | None]:
def check_lvalue(
self, lvalue: Lvalue, rvalue: Expression | None = None
) -> tuple[Type | None, IndexExpr | None, Var | None]:
lvalue_type = None
index_lvalue = None
inferred = None
Expand All @@ -4357,7 +4354,7 @@ def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, V
elif isinstance(lvalue, IndexExpr):
index_lvalue = lvalue
elif isinstance(lvalue, MemberExpr):
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True)
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True, rvalue)
self.store_type(lvalue, lvalue_type)
elif isinstance(lvalue, NameExpr):
lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
Expand Down Expand Up @@ -4704,12 +4701,8 @@ def check_member_assignment(

Return the inferred rvalue_type, inferred lvalue_type, and whether to use the binder
for this assignment.

Note: this method exists here and not in checkmember.py, because we need to take
care about interaction between binder and __set__().
"""
instance_type = get_proper_type(instance_type)
attribute_type = get_proper_type(attribute_type)
# Descriptors don't participate in class-attribute access
if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance(
instance_type, TypeType
Expand All @@ -4721,107 +4714,16 @@ def check_member_assignment(
get_lvalue_type = self.expr_checker.analyze_ordinary_member_access(
lvalue, is_lvalue=False
)
use_binder = is_same_type(get_lvalue_type, attribute_type)

if not isinstance(attribute_type, Instance):
# TODO: support __set__() for union types.
rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context)
return rvalue_type, attribute_type, use_binder

mx = MemberContext(
is_lvalue=False,
is_super=False,
is_operator=False,
original_type=instance_type,
context=context,
self_type=None,
chk=self,
)
get_type = analyze_descriptor_access(attribute_type, mx, assignment=True)
if not attribute_type.type.has_readable_member("__set__"):
# If there is no __set__, we type-check that the assigned value matches
# the return type of __get__. This doesn't match the python semantics,
# (which allow you to override the descriptor with any value), but preserves
# the type of accessing the attribute (even after the override).
rvalue_type, _ = self.check_simple_assignment(get_type, rvalue, context)
return rvalue_type, get_type, use_binder

dunder_set = attribute_type.type.get_method("__set__")
if dunder_set is None:
self.fail(
message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(
attribute_type.str_with_options(self.options)
),
context,
)
return AnyType(TypeOfAny.from_error), get_type, False

bound_method = analyze_decorator_or_funcbase_access(
defn=dunder_set,
itype=attribute_type,
name="__set__",
mx=mx.copy_modified(self_type=attribute_type),
)
typ = map_instance_to_supertype(attribute_type, dunder_set.info)
dunder_set_type = expand_type_by_instance(bound_method, typ)

callable_name = self.expr_checker.method_fullname(attribute_type, "__set__")
dunder_set_type = self.expr_checker.transform_callee_type(
callable_name,
dunder_set_type,
[TempNode(instance_type, context=context), rvalue],
[nodes.ARG_POS, nodes.ARG_POS],
context,
object_type=attribute_type,
)

# For non-overloaded setters, the result should be type-checked like a regular assignment.
# Hence, we first only try to infer the type by using the rvalue as type context.
type_context = rvalue
with self.msg.filter_errors():
_, inferred_dunder_set_type = self.expr_checker.check_call(
dunder_set_type,
[TempNode(instance_type, context=context), type_context],
[nodes.ARG_POS, nodes.ARG_POS],
context,
object_type=attribute_type,
callable_name=callable_name,
)

# And now we in fact type check the call, to show errors related to wrong arguments
# count, etc., replacing the type context for non-overloaded setters only.
inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type)
if isinstance(inferred_dunder_set_type, CallableType):
type_context = TempNode(AnyType(TypeOfAny.special_form), context=context)
self.expr_checker.check_call(
dunder_set_type,
[TempNode(instance_type, context=context), type_context],
[nodes.ARG_POS, nodes.ARG_POS],
context,
object_type=attribute_type,
callable_name=callable_name,
)

# Search for possible deprecations:
mx.chk.check_deprecated(dunder_set, mx.context)
mx.chk.warn_deprecated_overload_item(
dunder_set, mx.context, target=inferred_dunder_set_type, selftype=attribute_type
)

# In the following cases, a message already will have been recorded in check_call.
if (not isinstance(inferred_dunder_set_type, CallableType)) or (
len(inferred_dunder_set_type.arg_types) < 2
):
return AnyType(TypeOfAny.from_error), get_type, False

set_type = inferred_dunder_set_type.arg_types[1]
# Special case: if the rvalue_type is a subtype of both '__get__' and '__set__' types,
# and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type
# by this assignment. Technically, this is not safe, but in practice this is
# what a user expects.
rvalue_type, _ = self.check_simple_assignment(set_type, rvalue, context)
infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type)
return rvalue_type if infer else set_type, get_type, infer
rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context)
infer = is_subtype(rvalue_type, get_lvalue_type) and is_subtype(
get_lvalue_type, attribute_type
)
return rvalue_type if infer else attribute_type, attribute_type, infer

def check_indexed_assignment(
self, lvalue: IndexExpr, rvalue: Expression, context: Context
Expand Down
10 changes: 8 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3334,8 +3334,13 @@ def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type:
self.chk.warn_deprecated(e.node, e)
return narrowed

def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type:
"""Analyse member expression or member lvalue."""
def analyze_ordinary_member_access(
self, e: MemberExpr, is_lvalue: bool, rvalue: Expression | None = None
) -> Type:
"""Analyse member expression or member lvalue.

An rvalue can be provided optionally to infer better setter type when is_lvalue is True.
"""
if e.kind is not None:
# This is a reference to a module attribute.
return self.analyze_ref_expr(e)
Expand Down Expand Up @@ -3366,6 +3371,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
in_literal_context=self.is_literal_context(),
module_symbol_table=module_symbol_table,
is_self=is_self,
rvalue=rvalue,
)

return member_type
Expand Down
Loading