diff --git a/mypy/checker.py b/mypy/checker.py index 62acfc9e3abe..12afa4d3edf5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -12,12 +12,7 @@ import mypy.checkexpr from mypy import errorcodes as codes, join, message_registry, nodes, operators from mypy.binder import ConditionalTypeBinder, Frame, get_declaration -from mypy.checkmember import ( - MemberContext, - analyze_decorator_or_funcbase_access, - analyze_descriptor_access, - analyze_member_access, -) +from mypy.checkmember import analyze_member_access from mypy.checkpattern import PatternChecker from mypy.constraints import SUPERTYPE_OF from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values @@ -3233,7 +3228,7 @@ def check_assignment( ) else: self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=") - lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue) + lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue) # If we're assigning to __getattr__ or similar methods, check that the signature is # valid. if isinstance(lvalue, NameExpr) and lvalue.node: @@ -4339,7 +4334,9 @@ def check_multi_assignment_from_iterable( else: self.msg.type_not_iterable(rvalue_type, context) - def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, Var | None]: + def check_lvalue( + self, lvalue: Lvalue, rvalue: Expression | None = None + ) -> tuple[Type | None, IndexExpr | None, Var | None]: lvalue_type = None index_lvalue = None inferred = None @@ -4357,7 +4354,7 @@ def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, V elif isinstance(lvalue, IndexExpr): index_lvalue = lvalue elif isinstance(lvalue, MemberExpr): - lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True) + lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True, rvalue) self.store_type(lvalue, lvalue_type) elif isinstance(lvalue, NameExpr): lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True) @@ -4704,12 +4701,8 @@ def check_member_assignment( Return the inferred rvalue_type, inferred lvalue_type, and whether to use the binder for this assignment. - - Note: this method exists here and not in checkmember.py, because we need to take - care about interaction between binder and __set__(). """ instance_type = get_proper_type(instance_type) - attribute_type = get_proper_type(attribute_type) # Descriptors don't participate in class-attribute access if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance( instance_type, TypeType @@ -4721,107 +4714,16 @@ def check_member_assignment( get_lvalue_type = self.expr_checker.analyze_ordinary_member_access( lvalue, is_lvalue=False ) - use_binder = is_same_type(get_lvalue_type, attribute_type) - - if not isinstance(attribute_type, Instance): - # TODO: support __set__() for union types. - rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context) - return rvalue_type, attribute_type, use_binder - - mx = MemberContext( - is_lvalue=False, - is_super=False, - is_operator=False, - original_type=instance_type, - context=context, - self_type=None, - chk=self, - ) - get_type = analyze_descriptor_access(attribute_type, mx, assignment=True) - if not attribute_type.type.has_readable_member("__set__"): - # If there is no __set__, we type-check that the assigned value matches - # the return type of __get__. This doesn't match the python semantics, - # (which allow you to override the descriptor with any value), but preserves - # the type of accessing the attribute (even after the override). - rvalue_type, _ = self.check_simple_assignment(get_type, rvalue, context) - return rvalue_type, get_type, use_binder - - dunder_set = attribute_type.type.get_method("__set__") - if dunder_set is None: - self.fail( - message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format( - attribute_type.str_with_options(self.options) - ), - context, - ) - return AnyType(TypeOfAny.from_error), get_type, False - - bound_method = analyze_decorator_or_funcbase_access( - defn=dunder_set, - itype=attribute_type, - name="__set__", - mx=mx.copy_modified(self_type=attribute_type), - ) - typ = map_instance_to_supertype(attribute_type, dunder_set.info) - dunder_set_type = expand_type_by_instance(bound_method, typ) - - callable_name = self.expr_checker.method_fullname(attribute_type, "__set__") - dunder_set_type = self.expr_checker.transform_callee_type( - callable_name, - dunder_set_type, - [TempNode(instance_type, context=context), rvalue], - [nodes.ARG_POS, nodes.ARG_POS], - context, - object_type=attribute_type, - ) - - # For non-overloaded setters, the result should be type-checked like a regular assignment. - # Hence, we first only try to infer the type by using the rvalue as type context. - type_context = rvalue - with self.msg.filter_errors(): - _, inferred_dunder_set_type = self.expr_checker.check_call( - dunder_set_type, - [TempNode(instance_type, context=context), type_context], - [nodes.ARG_POS, nodes.ARG_POS], - context, - object_type=attribute_type, - callable_name=callable_name, - ) - - # And now we in fact type check the call, to show errors related to wrong arguments - # count, etc., replacing the type context for non-overloaded setters only. - inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type) - if isinstance(inferred_dunder_set_type, CallableType): - type_context = TempNode(AnyType(TypeOfAny.special_form), context=context) - self.expr_checker.check_call( - dunder_set_type, - [TempNode(instance_type, context=context), type_context], - [nodes.ARG_POS, nodes.ARG_POS], - context, - object_type=attribute_type, - callable_name=callable_name, - ) - - # Search for possible deprecations: - mx.chk.check_deprecated(dunder_set, mx.context) - mx.chk.warn_deprecated_overload_item( - dunder_set, mx.context, target=inferred_dunder_set_type, selftype=attribute_type - ) - # In the following cases, a message already will have been recorded in check_call. - if (not isinstance(inferred_dunder_set_type, CallableType)) or ( - len(inferred_dunder_set_type.arg_types) < 2 - ): - return AnyType(TypeOfAny.from_error), get_type, False - - set_type = inferred_dunder_set_type.arg_types[1] # Special case: if the rvalue_type is a subtype of both '__get__' and '__set__' types, # and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type # by this assignment. Technically, this is not safe, but in practice this is # what a user expects. - rvalue_type, _ = self.check_simple_assignment(set_type, rvalue, context) - infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type) - return rvalue_type if infer else set_type, get_type, infer + rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context) + infer = is_subtype(rvalue_type, get_lvalue_type) and is_subtype( + get_lvalue_type, attribute_type + ) + return rvalue_type if infer else attribute_type, attribute_type, infer def check_indexed_assignment( self, lvalue: IndexExpr, rvalue: Expression, context: Context diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 812121994fd7..0804917476a9 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3334,8 +3334,13 @@ def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type: self.chk.warn_deprecated(e.node, e) return narrowed - def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type: - """Analyse member expression or member lvalue.""" + def analyze_ordinary_member_access( + self, e: MemberExpr, is_lvalue: bool, rvalue: Expression | None = None + ) -> Type: + """Analyse member expression or member lvalue. + + An rvalue can be provided optionally to infer better setter type when is_lvalue is True. + """ if e.kind is not None: # This is a reference to a module attribute. return self.analyze_ref_expr(e) @@ -3366,6 +3371,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type in_literal_context=self.is_literal_context(), module_symbol_table=module_symbol_table, is_self=is_self, + rvalue=rvalue, ) return member_type diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 0535486bfd4a..ebc4fe8705ce 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -21,6 +21,7 @@ SYMBOL_FUNCBASE_TYPES, Context, Decorator, + Expression, FuncBase, FuncDef, IndexExpr, @@ -96,6 +97,7 @@ def __init__( module_symbol_table: SymbolTable | None = None, no_deferral: bool = False, is_self: bool = False, + rvalue: Expression | None = None, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -108,6 +110,9 @@ def __init__( self.module_symbol_table = module_symbol_table self.no_deferral = no_deferral self.is_self = is_self + if rvalue is not None: + assert is_lvalue + self.rvalue = rvalue def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -132,6 +137,7 @@ def copy_modified( self_type=self.self_type, module_symbol_table=self.module_symbol_table, no_deferral=self.no_deferral, + rvalue=self.rvalue, ) if self_type is not None: mx.self_type = self_type @@ -158,6 +164,7 @@ def analyze_member_access( module_symbol_table: SymbolTable | None = None, no_deferral: bool = False, is_self: bool = False, + rvalue: Expression | None = None, ) -> Type: """Return the type of attribute 'name' of 'typ'. @@ -176,11 +183,14 @@ def analyze_member_access( of 'original_type'. 'original_type' is always preserved as the 'typ' type used in the initial, non-recursive call. The 'self_type' is a component of 'original_type' to which generic self should be bound (a narrower type that has a fallback to instance). - Currently this is used only for union types. + Currently, this is used only for union types. - 'module_symbol_table' is passed to this function if 'typ' is actually a module + 'module_symbol_table' is passed to this function if 'typ' is actually a module, and we want to keep track of the available attributes of the module (since they are not available via the type object directly) + + 'rvalue' can be provided optionally to infer better setter type when is_lvalue is True, + most notably this helps for descriptors with overloaded __set__() method. """ mx = MemberContext( is_lvalue=is_lvalue, @@ -193,6 +203,7 @@ def analyze_member_access( module_symbol_table=module_symbol_table, no_deferral=no_deferral, is_self=is_self, + rvalue=rvalue, ) result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) @@ -619,9 +630,7 @@ def check_final_member(name: str, info: TypeInfo, msg: MessageBuilder, ctx: Cont msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx) -def analyze_descriptor_access( - descriptor_type: Type, mx: MemberContext, *, assignment: bool = False -) -> Type: +def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type: """Type check descriptor access. Arguments: @@ -629,7 +638,7 @@ def analyze_descriptor_access( (the type of ``f`` in ``a.f`` when ``f`` is a descriptor). mx: The current member access context. Return: - The return type of the appropriate ``__get__`` overload for the descriptor. + The return type of the appropriate ``__get__/__set__`` overload for the descriptor. """ instance_type = get_proper_type(mx.self_type) orig_descriptor_type = descriptor_type @@ -638,15 +647,24 @@ def analyze_descriptor_access( if isinstance(descriptor_type, UnionType): # Map the access over union types return make_simplified_union( - [ - analyze_descriptor_access(typ, mx, assignment=assignment) - for typ in descriptor_type.items - ] + [analyze_descriptor_access(typ, mx) for typ in descriptor_type.items] ) elif not isinstance(descriptor_type, Instance): return orig_descriptor_type - if not descriptor_type.type.has_readable_member("__get__"): + if not mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"): + return orig_descriptor_type + + # We do this check first to accommodate for descriptors with only __set__ method. + # If there is no __set__, we type-check that the assigned value matches + # the return type of __get__. This doesn't match the python semantics, + # (which allow you to override the descriptor with any value), but preserves + # the type of accessing the attribute (even after the override). + if mx.is_lvalue and descriptor_type.type.has_readable_member("__set__"): + return analyze_descriptor_assign(descriptor_type, mx) + + if mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"): + # This turned out to be not a descriptor after all. return orig_descriptor_type dunder_get = descriptor_type.type.get_method("__get__") @@ -703,11 +721,10 @@ def analyze_descriptor_access( callable_name=callable_name, ) - if not assignment: - mx.chk.check_deprecated(dunder_get, mx.context) - mx.chk.warn_deprecated_overload_item( - dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type - ) + mx.chk.check_deprecated(dunder_get, mx.context) + mx.chk.warn_deprecated_overload_item( + dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type + ) inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type) if isinstance(inferred_dunder_get_type, AnyType): @@ -726,6 +743,79 @@ def analyze_descriptor_access( return inferred_dunder_get_type.ret_type +def analyze_descriptor_assign(descriptor_type: Instance, mx: MemberContext) -> Type: + instance_type = get_proper_type(mx.self_type) + dunder_set = descriptor_type.type.get_method("__set__") + if dunder_set is None: + mx.chk.fail( + message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format( + descriptor_type.str_with_options(mx.msg.options) + ), + mx.context, + ) + return AnyType(TypeOfAny.from_error) + + bound_method = analyze_decorator_or_funcbase_access( + defn=dunder_set, + itype=descriptor_type, + name="__set__", + mx=mx.copy_modified(is_lvalue=False, self_type=descriptor_type), + ) + typ = map_instance_to_supertype(descriptor_type, dunder_set.info) + dunder_set_type = expand_type_by_instance(bound_method, typ) + + callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__set__") + rvalue = mx.rvalue or TempNode(AnyType(TypeOfAny.special_form), context=mx.context) + dunder_set_type = mx.chk.expr_checker.transform_callee_type( + callable_name, + dunder_set_type, + [TempNode(instance_type, context=mx.context), rvalue], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + ) + + # For non-overloaded setters, the result should be type-checked like a regular assignment. + # Hence, we first only try to infer the type by using the rvalue as type context. + type_context = rvalue + with mx.msg.filter_errors(): + _, inferred_dunder_set_type = mx.chk.expr_checker.check_call( + dunder_set_type, + [TempNode(instance_type, context=mx.context), type_context], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + callable_name=callable_name, + ) + + # And now we in fact type check the call, to show errors related to wrong arguments + # count, etc., replacing the type context for non-overloaded setters only. + inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type) + if isinstance(inferred_dunder_set_type, CallableType): + type_context = TempNode(AnyType(TypeOfAny.special_form), context=mx.context) + mx.chk.expr_checker.check_call( + dunder_set_type, + [TempNode(instance_type, context=mx.context), type_context], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + callable_name=callable_name, + ) + + # Search for possible deprecations: + mx.chk.check_deprecated(dunder_set, mx.context) + mx.chk.warn_deprecated_overload_item( + dunder_set, mx.context, target=inferred_dunder_set_type, selftype=descriptor_type + ) + + # In the following cases, a message already will have been recorded in check_call. + if (not isinstance(inferred_dunder_set_type, CallableType)) or ( + len(inferred_dunder_set_type.arg_types) < 2 + ): + return AnyType(TypeOfAny.from_error) + return inferred_dunder_set_type.arg_types[1] + + def is_instance_var(var: Var) -> bool: """Return if var is an instance variable according to PEP 526.""" return ( @@ -810,6 +900,7 @@ def analyze_var( # A property cannot have an overloaded type => the cast is fine. assert isinstance(expanded_signature, CallableType) if var.is_settable_property and mx.is_lvalue and var.setter_type is not None: + # TODO: use check_call() to infer better type, same as for __set__(). result = expanded_signature.arg_types[0] else: result = expanded_signature.ret_type @@ -822,7 +913,7 @@ def analyze_var( result = AnyType(TypeOfAny.special_form) fullname = f"{var.info.fullname}.{name}" hook = mx.chk.plugin.get_attribute_hook(fullname) - if result and not mx.is_lvalue and not implicit: + if result and not (implicit or var.info.is_protocol and is_instance_var(var)): result = analyze_descriptor_access(result, mx) if hook: result = hook( @@ -1075,6 +1166,7 @@ def analyze_class_attribute_access( result = add_class_tvars( t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars ) + # __set__ is not called on class objects. if not mx.is_lvalue: result = analyze_descriptor_access(result, mx)