From 29889c882c5f350a23fac6c298a0d005cf5f80cd Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sat, 16 Jun 2018 13:44:27 -0700 Subject: [PATCH] Make overloads support classmethod and staticmethod (#5224) * Move 'is_class' and 'is_static' into FuncBase This commit moves the `is_class` and `is_static` fields into FuncBase. It also cleans up the list of flags so they don't repeat the 'is_property' entry, which is now present in `FUNCBASE_FLAGS`. The high-level plan is to modify the `is_class` and `is_static` fields in OverloadedFuncDef for use later in mypy. * Make semantic analysis phase record class/static methods with overloads This commit adjusts the semantic analysis phase to detect and record when an overload appears to be a classmethod or staticmethod. * Broaden class/static method checks to catch overloads This commit modifies mypy to use the `is_static` and `is_class` fields of OverloadedFuncDef as appropriate. I found the code snippets to modify by asking PyCharm for all instances of code using those two fields and modified the surrounding code as appropriate. * Add support for overloaded classmethods in attrs/dataclasses Both the attrs and dataclasses plugins manually patch classmethods -- we do the same for overloads. * Respond to code review This commit: 1. Updates astdiff.py and adds a case to one of the fine-grained dependency test files. 2. Adds some helper methods to FunctionLike. 3. Performs a few misc cleanups. * Respond to code review; add tests for self types --- mypy/checker.py | 16 +- mypy/checkmember.py | 3 +- mypy/messages.py | 8 +- mypy/nodes.py | 27 +- mypy/plugins/attrs.py | 10 + mypy/plugins/dataclasses.py | 14 +- mypy/semanal.py | 31 ++ mypy/server/astdiff.py | 6 +- mypy/strconv.py | 4 + mypy/treetransform.py | 3 + mypy/types.py | 20 +- test-data/unit/check-attr.test | 32 ++ test-data/unit/check-dataclasses.test | 33 ++ test-data/unit/check-overloading.test | 472 +++++++++++++++++++++++ test-data/unit/fine-grained.test | 27 ++ test-data/unit/fixtures/staticmethod.pyi | 1 + 16 files changed, 677 insertions(+), 30 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index feedc5a820a4..7e94d50d58f1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1289,12 +1289,10 @@ def check_override(self, override: FunctionLike, original: FunctionLike, # this could be unsafe with reverse operator methods. fail = True - if isinstance(original, CallableType) and isinstance(override, CallableType): - if (isinstance(original.definition, FuncItem) and - isinstance(override.definition, FuncItem)): - if ((original.definition.is_static or original.definition.is_class) and - not (override.definition.is_static or override.definition.is_class)): - fail = True + if isinstance(original, FunctionLike) and isinstance(override, FunctionLike): + if ((original.is_classmethod() or original.is_staticmethod()) and + not (override.is_classmethod() or override.is_staticmethod())): + fail = True if fail: emitted_msg = False @@ -3911,8 +3909,6 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool: def is_static(func: Union[FuncBase, Decorator]) -> bool: if isinstance(func, Decorator): return is_static(func.func) - elif isinstance(func, OverloadedFuncDef): - return any(is_static(item) for item in func.items) - elif isinstance(func, FuncItem): + elif isinstance(func, FuncBase): return func.is_static - return False + assert False, "Unexpected func type: {}".format(type(func)) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 99298c04fd02..0c0a3c5af5d5 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -448,7 +448,8 @@ def analyze_class_attribute_access(itype: Instance, return handle_partial_attribute_type(t, is_lvalue, msg, symnode) if not is_method and (isinstance(t, TypeVarType) or get_type_vars(t)): msg.fail(messages.GENERIC_INSTANCE_VAR_CLASS_ACCESS, context) - is_classmethod = is_decorated and cast(Decorator, node.node).func.is_class + is_classmethod = ((is_decorated and cast(Decorator, node.node).func.is_class) + or (isinstance(node.node, FuncBase) and node.node.is_class)) return add_class_tvars(t, itype, is_classmethod, builtin_type, original_type) elif isinstance(node.node, Var): not_ready_callback(name, context) diff --git a/mypy/messages.py b/mypy/messages.py index a5bc2e1aa9f6..c272ce75d997 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -27,7 +27,7 @@ TypeInfo, Context, MypyFile, op_methods, FuncDef, reverse_type_aliases, ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode, - CallExpr, Expression + CallExpr, Expression, OverloadedFuncDef, ) # Constants that represent simple type checker error message, i.e. messages @@ -942,6 +942,12 @@ def incompatible_typevar_value(self, self.format(typ)), context) + def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None: + self.fail( + 'Overload does not consistently use the "@{}" '.format(decorator) + + 'decorator on all function signatures.', + context) + def overloaded_signatures_overlap(self, index1: int, index2: int, context: Context) -> None: self.fail('Overloaded function signatures {} and {} overlap with ' 'incompatible return types'.format(index1, index2), context) diff --git a/mypy/nodes.py b/mypy/nodes.py index 2ec8859a8a46..4f6ef124e311 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -370,6 +370,11 @@ def __str__(self) -> str: return 'ImportedName(%s)' % self.target_fullname +FUNCBASE_FLAGS = [ + 'is_property', 'is_class', 'is_static', +] + + class FuncBase(Node): """Abstract base class for function-like nodes""" @@ -377,6 +382,8 @@ class FuncBase(Node): 'unanalyzed_type', 'info', 'is_property', + 'is_class', # Uses "@classmethod" + 'is_static', # USes "@staticmethod" '_fullname', ) @@ -391,6 +398,8 @@ def __init__(self) -> None: # TODO: Type should be Optional[TypeInfo] self.info = cast(TypeInfo, None) self.is_property = False + self.is_class = False + self.is_static = False # Name with module prefix # TODO: Type should be Optional[str] self._fullname = cast(str, None) @@ -436,8 +445,8 @@ def serialize(self) -> JsonDict: 'items': [i.serialize() for i in self.items], 'type': None if self.type is None else self.type.serialize(), 'fullname': self._fullname, - 'is_property': self.is_property, - 'impl': None if self.impl is None else self.impl.serialize() + 'impl': None if self.impl is None else self.impl.serialize(), + 'flags': get_flags(self, FUNCBASE_FLAGS), } @classmethod @@ -451,7 +460,7 @@ def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef': if data.get('type') is not None: res.type = mypy.types.deserialize_type(data['type']) res._fullname = data['fullname'] - res.is_property = data['is_property'] + set_flags(res, data['flags']) # NOTE: res.info will be set in the fixup phase. return res @@ -481,9 +490,9 @@ def set_line(self, target: Union[Context, int], column: Optional[int] = None) -> self.variable.set_line(self.line, self.column) -FUNCITEM_FLAGS = [ +FUNCITEM_FLAGS = FUNCBASE_FLAGS + [ 'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator', - 'is_awaitable_coroutine', 'is_static', 'is_class', + 'is_awaitable_coroutine', ] @@ -503,8 +512,6 @@ class FuncItem(FuncBase): 'is_coroutine', # Defined using 'async def' syntax? 'is_async_generator', # Is an async def generator? 'is_awaitable_coroutine', # Decorated with '@{typing,asyncio}.coroutine'? - 'is_static', # Uses @staticmethod? - 'is_class', # Uses @classmethod? 'expanded', # Variants of function with type variables with values expanded ) @@ -525,8 +532,6 @@ def __init__(self, self.is_coroutine = False self.is_async_generator = False self.is_awaitable_coroutine = False - self.is_static = False - self.is_class = False self.expanded = [] # type: List[FuncItem] self.min_args = 0 @@ -547,7 +552,7 @@ def is_dynamic(self) -> bool: FUNCDEF_FLAGS = FUNCITEM_FLAGS + [ - 'is_decorated', 'is_conditional', 'is_abstract', 'is_property', + 'is_decorated', 'is_conditional', 'is_abstract', ] @@ -561,7 +566,6 @@ class FuncDef(FuncItem, SymbolNode, Statement): 'is_decorated', 'is_conditional', 'is_abstract', - 'is_property', 'original_def', ) @@ -575,7 +579,6 @@ def __init__(self, self.is_decorated = False self.is_conditional = False # Defined conditionally (within block)? self.is_abstract = False - self.is_property = False # Original conditional definition self.original_def = None # type: Union[None, FuncDef, Var, Decorator] diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 028b86ca8550..e1dbc1221151 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -470,6 +470,16 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute], func_type = stmt.func.type if isinstance(func_type, CallableType): func_type.arg_types[0] = ctx.api.class_type(ctx.cls.info) + if isinstance(stmt, OverloadedFuncDef) and stmt.is_class: + func_type = stmt.type + if isinstance(func_type, Overloaded): + class_type = ctx.api.class_type(ctx.cls.info) + for item in func_type.items(): + item.arg_types[0] = class_type + if stmt.impl is not None: + assert isinstance(stmt.impl, Decorator) + if isinstance(stmt.impl.func.type, CallableType): + stmt.impl.func.type.arg_types[0] = class_type class MethodAdder: diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 9dba62558b2d..d545b39e7f19 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -4,12 +4,12 @@ from mypy.nodes import ( ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr, Context, Decorator, Expression, FuncDef, JsonDict, NameExpr, - SymbolTableNode, TempNode, TypeInfo, Var, + OverloadedFuncDef, SymbolTableNode, TempNode, TypeInfo, Var, ) from mypy.plugin import ClassDefContext from mypy.plugins.common import _add_method, _get_decorator_bool_argument from mypy.types import ( - CallableType, Instance, NoneTyp, TypeVarDef, TypeVarType, + CallableType, Instance, NoneTyp, Overloaded, TypeVarDef, TypeVarType, ) # The set of decorators that generate dataclasses. @@ -95,6 +95,16 @@ def transform(self) -> None: func_type = stmt.func.type if isinstance(func_type, CallableType): func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info) + if isinstance(stmt, OverloadedFuncDef) and stmt.is_class: + func_type = stmt.type + if isinstance(func_type, Overloaded): + class_type = ctx.api.class_type(ctx.cls.info) + for item in func_type.items(): + item.arg_types[0] = class_type + if stmt.impl is not None: + assert isinstance(stmt.impl, Decorator) + if isinstance(stmt.impl.func.type, CallableType): + stmt.impl.func.type.arg_types[0] = class_type # Add an eq method, but only if the class doesn't already have one. if decorator_arguments['eq'] and info.get('__eq__') is None: diff --git a/mypy/semanal.py b/mypy/semanal.py index e7678d0e6799..9249547d3593 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -587,6 +587,37 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: # redefinitions already. return + # We know this is an overload def -- let's handle classmethod and staticmethod + class_status = [] + static_status = [] + for item in defn.items: + if isinstance(item, Decorator): + inner = item.func + elif isinstance(item, FuncDef): + inner = item + else: + assert False, "The 'item' variable is an unexpected type: {}".format(type(item)) + class_status.append(inner.is_class) + static_status.append(inner.is_static) + + if defn.impl is not None: + if isinstance(defn.impl, Decorator): + inner = defn.impl.func + elif isinstance(defn.impl, FuncDef): + inner = defn.impl + else: + assert False, "Unexpected impl type: {}".format(type(defn.impl)) + class_status.append(inner.is_class) + static_status.append(inner.is_static) + + if len(set(class_status)) != 1: + self.msg.overload_inconsistently_applies_decorator('classmethod', defn) + elif len(set(static_status)) != 1: + self.msg.overload_inconsistently_applies_decorator('staticmethod', defn) + else: + defn.is_class = class_status[0] + defn.is_static = static_status[0] + if self.type and not self.is_func_scope(): self.type.names[defn.name()] = SymbolTableNode(MDEF, defn, typ=defn.type) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index e525496163ef..e24001cbffcc 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -54,7 +54,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' from mypy.nodes import ( SymbolTable, TypeInfo, Var, SymbolNode, Decorator, TypeVarExpr, - OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR + FuncBase, OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR ) from mypy.types import ( Type, TypeVisitor, UnboundType, AnyType, NoneTyp, UninhabitedType, @@ -167,13 +167,13 @@ def snapshot_definition(node: Optional[SymbolNode], The representation is nested tuples and dicts. Only externally visible attributes are included. """ - if isinstance(node, (OverloadedFuncDef, FuncItem)): + if isinstance(node, FuncBase): # TODO: info if node.type: signature = snapshot_type(node.type) else: signature = snapshot_untyped_signature(node) - return ('Func', common, node.is_property, signature) + return ('Func', common, node.is_property, node.is_class, node.is_static, signature) elif isinstance(node, Var): return ('Var', common, snapshot_optional_type(node.type)) elif isinstance(node, Decorator): diff --git a/mypy/strconv.py b/mypy/strconv.py index 74982ca8eb0a..1a2b5911aa5d 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -146,6 +146,10 @@ def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> str: a.insert(0, o.type) if o.impl: a.insert(0, o.impl) + if o.is_static: + a.insert(-1, 'Static') + if o.is_class: + a.insert(-1, 'Class') return self.dump(a, o) def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> str: diff --git a/mypy/treetransform.py b/mypy/treetransform.py index c9c83f98c22d..7068b4c06a8d 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -154,6 +154,9 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDe new._fullname = node._fullname new.type = self.optional_type(node.type) new.info = node.info + new.is_static = node.is_static + new.is_class = node.is_class + new.is_property = node.is_property if node.impl: new.impl = cast(OverloadPart, node.impl.accept(self)) return new diff --git a/mypy/types.py b/mypy/types.py index 040ac50e816c..9c5edb3af29c 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -14,7 +14,7 @@ from mypy import experiments from mypy.nodes import ( INVARIANT, SymbolNode, ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, - FuncDef + FuncBase, FuncDef, ) from mypy.sharedparse import argument_elide_name from mypy.util import IdMapper @@ -645,6 +645,12 @@ def with_name(self, name: str) -> 'FunctionLike': pass @abstractmethod def get_name(self) -> Optional[str]: pass + @abstractmethod + def is_classmethod(self) -> bool: pass + + @abstractmethod + def is_staticmethod(self) -> bool: pass + FormalArgument = NamedTuple('FormalArgument', [ ('name', Optional[str]), @@ -828,6 +834,12 @@ def with_name(self, name: str) -> 'CallableType': def get_name(self) -> Optional[str]: return self.name + def is_classmethod(self) -> bool: + return isinstance(self.definition, FuncBase) and self.definition.is_class + + def is_staticmethod(self) -> bool: + return isinstance(self.definition, FuncBase) and self.definition.is_static + def max_fixed_args(self) -> int: n = len(self.arg_types) if self.is_var_arg: @@ -1046,6 +1058,12 @@ def with_name(self, name: str) -> 'Overloaded': def get_name(self) -> Optional[str]: return self._items[0].name + def is_classmethod(self) -> bool: + return self._items[0].is_classmethod() + + def is_staticmethod(self) -> bool: + return self._items[0].is_staticmethod() + def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_overloaded(self) diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 9851509dad15..433d0fd319ac 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -428,6 +428,38 @@ a = A.new() reveal_type(a.foo) # E: Revealed type is 'def () -> builtins.int' [builtins fixtures/classmethod.pyi] +[case testAttrsOtherOverloads] +import attr +from typing import overload, Union + +@attr.s +class A: + a = attr.ib() + b = attr.ib(default=3) + + @classmethod + def other(cls) -> str: + return "..." + + @overload + @classmethod + def foo(cls, x: int) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x: Union[int, str]) -> Union[int, str]: + reveal_type(cls) # E: Revealed type is 'def (a: Any, b: Any =) -> __main__.A' + reveal_type(cls.other()) # E: Revealed type is 'builtins.str' + return x + +reveal_type(A.foo(3)) # E: Revealed type is 'builtins.int' +reveal_type(A.foo("foo")) # E: Revealed type is 'builtins.str' + +[builtins fixtures/classmethod.pyi] + [case testAttrsDefaultDecorator] import attr @attr.s diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index aa8bad16f505..e79202aafcd1 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -222,6 +222,39 @@ app = Application.parse('') [builtins fixtures/list.pyi] [builtins fixtures/classmethod.pyi] +[case testDataclassesOverloadsAndClassmethods] +# flags: --python-version 3.6 +from dataclasses import dataclass +from typing import overload, Union + +@dataclass +class A: + a: int + b: str + + @classmethod + def other(cls) -> str: + return "..." + + @overload + @classmethod + def foo(cls, x: int) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x: Union[int, str]) -> Union[int, str]: + reveal_type(cls) # E: Revealed type is 'def (a: builtins.int, b: builtins.str) -> __main__.A' + reveal_type(cls.other()) # E: Revealed type is 'builtins.str' + return x + +reveal_type(A.foo(3)) # E: Revealed type is 'builtins.int' +reveal_type(A.foo("foo")) # E: Revealed type is 'builtins.str' + +[builtins fixtures/classmethod.pyi] + [case testDataclassesClassVars] # flags: --python-version 3.6 from dataclasses import dataclass diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index bf05a19fba82..0b1eedf16e64 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -3198,3 +3198,475 @@ def add_proxy(x, y): tup = (1, '2') reveal_type(foo(lambda (x, y): add_proxy(x, y), tup)) # E: Revealed type is 'builtins.str*' [builtins fixtures/primitives.pyi] + +[case testOverloadWithClassMethods] +from typing import overload + +class Wrapper: + @overload + @classmethod + def foo(cls, x: int) -> int: ... + @overload + @classmethod + def foo(cls, x: str) -> str: ... + @classmethod + def foo(cls, x): pass + +reveal_type(Wrapper.foo(3)) # E: Revealed type is 'builtins.int' +reveal_type(Wrapper.foo("foo")) # E: Revealed type is 'builtins.str' + +[builtins fixtures/classmethod.pyi] + +[case testOverloadWithInconsistentClassMethods] +from typing import overload + +class Wrapper1: + @overload # E: Overload does not consistently use the "@classmethod" decorator on all function signatures. + @classmethod + def foo(cls, x: int) -> int: ... + @overload + @classmethod + def foo(cls, x: str) -> str: ... + def foo(cls, x): pass + +class Wrapper2: + @overload # E: Overload does not consistently use the "@classmethod" decorator on all function signatures. + @classmethod + def foo(cls, x: int) -> int: ... + @overload + def foo(cls, x: str) -> str: ... + @classmethod + def foo(cls, x): pass + +class Wrapper3: + @overload # E: Overload does not consistently use the "@classmethod" decorator on all function signatures. + def foo(cls, x: int) -> int: ... + @overload + def foo(cls, x: str) -> str: ... + @classmethod + def foo(cls, x): pass + +[builtins fixtures/classmethod.pyi] + +[case testOverloadWithSwappedDecorators] +from typing import overload + +class Wrapper1: + @classmethod + @overload + def foo(cls, x: int) -> int: ... + + @classmethod + @overload + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +class Wrapper2: + @classmethod + @overload + def foo(cls, x: int) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +class Wrapper3: + @classmethod # E: Overload does not consistently use the "@classmethod" decorator on all function signatures. + @overload + def foo(cls, x: int) -> int: ... + + @overload + def foo(cls, x: str) -> str: ... + + def foo(cls, x): pass + +reveal_type(Wrapper1.foo(3)) # E: Revealed type is 'builtins.int' +reveal_type(Wrapper2.foo(3)) # E: Revealed type is 'builtins.int' + +[builtins fixtures/classmethod.pyi] + +[case testOverloadFaultyClassMethodInheritance] +from typing import overload + +class A: pass +class B(A): pass +class C(B): pass + +class Parent: + @overload + @classmethod + def foo(cls, x: B) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +class BadChild(Parent): + @overload # E: Signature of "foo" incompatible with supertype "Parent" + @classmethod + def foo(cls, x: C) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +class GoodChild(Parent): + @overload + @classmethod + def foo(cls, x: A) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +[builtins fixtures/classmethod.pyi] + +[case testOverloadClassMethodMixingInheritance] +from typing import overload + +class BadParent: + @overload + @classmethod + def foo(cls, x: int) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +class BadChild(BadParent): + @overload # E: Signature of "foo" incompatible with supertype "BadParent" + def foo(cls, x: int) -> int: ... + + @overload + def foo(cls, x: str) -> str: ... + + def foo(cls, x): pass + +class GoodParent: + @overload + def foo(cls, x: int) -> int: ... + + @overload + def foo(cls, x: str) -> str: ... + + def foo(cls, x): pass + +class GoodChild(GoodParent): + @overload + @classmethod + def foo(cls, x: int) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod + def foo(cls, x): pass + +[builtins fixtures/classmethod.pyi] + +[case testOverloadClassMethodImplementation] +from typing import overload, Union + +class Wrapper: + @classmethod + def other(cls) -> str: + return "..." + + @overload + @classmethod + def foo(cls, x: int) -> int: ... + + @overload + @classmethod + def foo(cls, x: str) -> str: ... + + @classmethod # E: Overloaded function implementation cannot produce return type of signature 1 + def foo(cls, x: Union[int, str]) -> str: + reveal_type(cls) # E: Revealed type is 'def () -> __main__.Wrapper' + reveal_type(cls.other()) # E: Revealed type is 'builtins.str' + return "..." + +[builtins fixtures/classmethod.pyi] + +[case testOverloadWithStaticMethods] +from typing import overload + +class Wrapper: + @overload + @staticmethod + def foo(x: int) -> int: ... + @overload + @staticmethod + def foo(x: str) -> str: ... + @staticmethod + def foo(x): pass + +reveal_type(Wrapper.foo(3)) # E: Revealed type is 'builtins.int' +reveal_type(Wrapper.foo("foo")) # E: Revealed type is 'builtins.str' + +[builtins fixtures/staticmethod.pyi] + +[case testOverloadWithInconsistentStaticMethods] +from typing import overload, Union + +class Wrapper1: + @overload # E: Overload does not consistently use the "@staticmethod" decorator on all function signatures. + @staticmethod + def foo(x: int) -> int: ... + @overload + @staticmethod + def foo(x: str) -> str: ... + def foo(x): pass + +class Wrapper2: + @overload # E: Overload does not consistently use the "@staticmethod" decorator on all function signatures. + @staticmethod + def foo(x: int) -> int: ... + @overload + def foo(x: str) -> str: ... # E: Self argument missing for a non-static method (or an invalid type for self) + @staticmethod + def foo(x): pass + +class Wrapper3: + @overload # E: Overload does not consistently use the "@staticmethod" decorator on all function signatures. + @staticmethod + def foo(x: int) -> int: ... + @overload + @staticmethod + def foo(x: str) -> str: ... + def foo(x: Union[int, str]): pass # E: Self argument missing for a non-static method (or an invalid type for self) +[builtins fixtures/staticmethod.pyi] + +[case testOverloadWithSwappedDecorators] +from typing import overload + +class Wrapper1: + @staticmethod + @overload + def foo(x: int) -> int: ... + + @staticmethod + @overload + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +class Wrapper2: + @staticmethod + @overload + def foo(x: int) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +class Wrapper3: + @staticmethod # E: Overload does not consistently use the "@staticmethod" decorator on all function signatures. + @overload + def foo(x: int) -> int: ... + + @overload + def foo(x: str) -> str: ... # E: Self argument missing for a non-static method (or an invalid type for self) + + @staticmethod + def foo(x): pass + +reveal_type(Wrapper1.foo(3)) # E: Revealed type is 'builtins.int' +reveal_type(Wrapper2.foo(3)) # E: Revealed type is 'builtins.int' + +[builtins fixtures/staticmethod.pyi] + +[case testOverloadFaultyStaticMethodInheritance] +from typing import overload + +class A: pass +class B(A): pass +class C(B): pass + +class Parent: + @overload + @staticmethod + def foo(x: B) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +class BadChild(Parent): + @overload # E: Signature of "foo" incompatible with supertype "Parent" + @staticmethod + def foo(x: C) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +class GoodChild(Parent): + @overload + @staticmethod + def foo(x: A) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +[builtins fixtures/staticmethod.pyi] + +[case testOverloadStaticMethodMixingInheritance] +from typing import overload + +class BadParent: + @overload + @staticmethod + def foo(x: int) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +class BadChild(BadParent): + @overload # E: Signature of "foo" incompatible with supertype "BadParent" + def foo(self, x: int) -> int: ... + + @overload + def foo(self, x: str) -> str: ... + + def foo(self, x): pass + +class GoodParent: + @overload + def foo(self, x: int) -> int: ... + + @overload + def foo(self, x: str) -> str: ... + + def foo(self, x): pass + +class GoodChild(GoodParent): + @overload + @staticmethod + def foo(x: int) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod + def foo(x): pass + +[builtins fixtures/staticmethod.pyi] + +[case testOverloadStaticMethodImplementation] +from typing import overload, Union + +class Wrapper: + @staticmethod + def other() -> str: + return "..." + + @overload + @staticmethod + def foo(x: int) -> int: ... + + @overload + @staticmethod + def foo(x: str) -> str: ... + + @staticmethod # E: Overloaded function implementation cannot produce return type of signature 1 + def foo(x: Union[int, str]) -> str: + return 3 # E: Incompatible return value type (got "int", expected "str") + +[builtins fixtures/staticmethod.pyi] + +[case testOverloadAndSelfTypes] +from typing import overload, Union, TypeVar, Type + +T = TypeVar('T', bound='Parent') +class Parent: + @overload + def foo(self: T, x: int) -> T: pass + + @overload + def foo(self, x: str) -> str: pass + + def foo(self: T, x: Union[int, str]) -> Union[T, str]: + reveal_type(self.bar()) # E: Revealed type is 'builtins.str' + return self + + def bar(self) -> str: pass + +class Child(Parent): + def child_only(self) -> int: pass + +x: Union[int, str] +reveal_type(Parent().foo(3)) # E: Revealed type is '__main__.Parent*' +reveal_type(Child().foo(3)) # E: Revealed type is '__main__.Child*' +reveal_type(Child().foo("...")) # E: Revealed type is 'builtins.str' +reveal_type(Child().foo(x)) # E: Revealed type is 'Union[__main__.Child*, builtins.str]' +reveal_type(Child().foo(3).child_only()) # E: Revealed type is 'builtins.int' + +[case testOverloadAndClassTypes] +from typing import overload, Union, TypeVar, Type + +T = TypeVar('T', bound='Parent') +class Parent: + @overload + @classmethod + def foo(cls: Type[T], x: int) -> Type[T]: pass + + @overload + @classmethod + def foo(cls, x: str) -> str: pass + + @classmethod + def foo(cls: Type[T], x: Union[int, str]) -> Union[Type[T], str]: + reveal_type(cls.bar()) # E: Revealed type is 'builtins.str' + return cls + + @classmethod + def bar(cls) -> str: pass + +class Child(Parent): + def child_only(self) -> int: pass + +x: Union[int, str] +reveal_type(Parent.foo(3)) # E: Revealed type is 'Type[__main__.Parent*]' +reveal_type(Child.foo(3)) # E: Revealed type is 'Type[__main__.Child*]' +reveal_type(Child.foo("...")) # E: Revealed type is 'builtins.str' +reveal_type(Child.foo(x)) # E: Revealed type is 'Union[Type[__main__.Child*], builtins.str]' +reveal_type(Child.foo(3)().child_only()) # E: Revealed type is 'builtins.int' +[builtins fixtures/classmethod.pyi] diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 037c71c221a5..26d82cd459c0 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -2087,6 +2087,33 @@ main:2: note: (Perhaps setting MYPYPATH or using the "--ignore-missing-imports" main:12: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader main:14: error: Cannot find module named 'n' +[case testOverloadClassmethodDisappears] +from typing import overload +from m import Wrapper +reveal_type(Wrapper.foo(3)) +[file m.pyi] +from typing import overload +class Wrapper: + @overload + @classmethod + def foo(self, x: int) -> int: ... + @overload + @classmethod + def foo(self, x: str) -> str: ... +[file m.pyi.2] +from typing import overload +class Wrapper: + @overload + def foo(cls, x: int) -> int: ... + @overload + def foo(cls, x: str) -> str: ... +[builtins fixtures/classmethod.pyi] +[out] +main:3: error: Revealed type is 'builtins.int' +== +main:3: error: Revealed type is 'Any' +main:3: error: No overload variant of "foo" of "Wrapper" matches argument type "int" + [case testRefreshGenericClass] from typing import TypeVar, Generic from a import A diff --git a/test-data/unit/fixtures/staticmethod.pyi b/test-data/unit/fixtures/staticmethod.pyi index 5f1013f18213..14254e64dcb1 100644 --- a/test-data/unit/fixtures/staticmethod.pyi +++ b/test-data/unit/fixtures/staticmethod.pyi @@ -17,3 +17,4 @@ class int: class str: pass class unicode: pass class bytes: pass +class ellipsis: pass