Skip to content

Commit

Permalink
[mypyc] Support various number-related dunders (python#10679)
Browse files Browse the repository at this point in the history
This adds support for these unary dunders:
* `__neg__`
* `__invert__`
* `__int__`
* `__float__`

Also add support for binary, reversible dunders, such as `__add__` and `__radd__`.

Finally, add support for in-place operator dunders such as `__iadd__`.

The semantics of the binary dunders don't always match Python semantics, but
many common use cases should work.

There is one significant difference from Python that is not easy to remove: if a
forward dunder method is called with an incompatible argument, it's treated the
same as if it returned `NotImplemented`. This is necessary since the body of
the method is never reached on incompatible argument type and there is no
way to explicitly return `NotImplemented`. However, it's still recommended that 
the body returns `NotImplemented` as expected for Python compatibility.

If a dunder returns `NotImplemented` and has a type annotation, the return
type should be annotated as `Union[T, Any]`, where `T` is the return value 
when `NotImplemented` is not returned.

Work on mypyc/mypyc#839.
  • Loading branch information
JukkaL authored Jun 22, 2021
1 parent 002722e commit f91fb1a
Show file tree
Hide file tree
Showing 19 changed files with 1,123 additions and 198 deletions.
15 changes: 8 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
is_final_node,
ARG_NAMED)
from mypy import nodes
from mypy import operators
from mypy.literals import literal, literal_hash, Key
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
from mypy.types import (
Expand Down Expand Up @@ -1026,13 +1027,13 @@ def is_forward_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__div__':
return True
else:
return method_name in nodes.reverse_op_methods
return method_name in operators.reverse_op_methods

def is_reverse_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__rdiv__':
return True
else:
return method_name in nodes.reverse_op_method_set
return method_name in operators.reverse_op_method_set

def check_for_missing_annotations(self, fdef: FuncItem) -> None:
# Check for functions with unspecified/not fully specified types.
Expand Down Expand Up @@ -1188,7 +1189,7 @@ def check_reverse_op_method(self, defn: FuncItem,
if self.options.python_version[0] == 2 and reverse_name == '__rdiv__':
forward_name = '__div__'
else:
forward_name = nodes.normal_from_reverse_op[reverse_name]
forward_name = operators.normal_from_reverse_op[reverse_name]
forward_inst = get_proper_type(reverse_type.arg_types[1])
if isinstance(forward_inst, TypeVarType):
forward_inst = get_proper_type(forward_inst.upper_bound)
Expand Down Expand Up @@ -1327,7 +1328,7 @@ def check_inplace_operator_method(self, defn: FuncBase) -> None:
They cannot arbitrarily overlap with __add__.
"""
method = defn.name
if method not in nodes.inplace_operator_methods:
if method not in operators.inplace_operator_methods:
return
typ = bind_self(self.function_type(defn))
cls = defn.info
Expand Down Expand Up @@ -1447,7 +1448,7 @@ def check_method_or_accessor_override_for_base(self, defn: Union[FuncDef,
# (__init__, __new__, __init_subclass__ are special).
if self.check_method_override_for_base_with_name(defn, name, base):
return True
if name in nodes.inplace_operator_methods:
if name in operators.inplace_operator_methods:
# Figure out the name of the corresponding operator method.
method = '__' + name[3:]
# An inplace operator method such as __iadd__ might not be
Expand Down Expand Up @@ -5529,9 +5530,9 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, st
depending on which method is supported by the type.
"""
typ = get_proper_type(typ)
method = nodes.op_methods[operator]
method = operators.op_methods[operator]
if isinstance(typ, Instance):
if operator in nodes.ops_with_inplace_method:
if operator in operators.ops_with_inplace_method:
inplace_method = '__i' + method[2:]
if typ.type.has_readable_member(inplace_method):
return True, inplace_method
Expand Down
23 changes: 12 additions & 11 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from mypy.literals import literal
from mypy import nodes
from mypy import operators
import mypy.checker
from mypy import types
from mypy.sametypes import is_same_type
Expand Down Expand Up @@ -2169,7 +2170,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
if right_radd_method is None:
return self.concat_tuples(proper_left_type, proper_right_type)

if e.op in nodes.op_methods:
if e.op in operators.op_methods:
method = self.get_operator_method(e.op)
result, method_type = self.check_op(method, left_type, e.right, e,
allow_reverse=True)
Expand Down Expand Up @@ -2234,7 +2235,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
else:
self.msg.add_errors(local_errors)
elif operator in nodes.op_methods:
elif operator in operators.op_methods:
method = self.get_operator_method(operator)
err_count = self.msg.errors.total_errors()
sub_result, method_type = self.check_op(method, left_type, right, e,
Expand Down Expand Up @@ -2362,7 +2363,7 @@ def get_operator_method(self, op: str) -> str:
# TODO also check for "from __future__ import division"
return '__div__'
else:
return nodes.op_methods[op]
return operators.op_methods[op]

def check_method_call_by_name(self,
method: str,
Expand Down Expand Up @@ -2537,7 +2538,7 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# which records tuples containing the method, base type, and the argument.

bias_right = is_proper_subtype(right_type, left_type)
if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type):
if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type):
# When we do "A() + A()", for example, Python will only call the __add__ method,
# never the __radd__ method.
#
Expand Down Expand Up @@ -2575,8 +2576,8 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# When running Python 2, we might also try calling the __cmp__ method.

is_python_2 = self.chk.options.python_version[0] == 2
if is_python_2 and op_name in nodes.ops_falling_back_to_cmp:
cmp_method = nodes.comparison_fallback_method
if is_python_2 and op_name in operators.ops_falling_back_to_cmp:
cmp_method = operators.comparison_fallback_method
left_cmp_op = lookup_operator(cmp_method, left_type)
right_cmp_op = lookup_operator(cmp_method, right_type)

Expand Down Expand Up @@ -2760,7 +2761,7 @@ def get_reverse_op_method(self, method: str) -> str:
if method == '__div__' and self.chk.options.python_version[0] == 2:
return '__rdiv__'
else:
return nodes.reverse_op_methods[method]
return operators.reverse_op_methods[method]

def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
"""Type check a boolean operation ('and' or 'or')."""
Expand Down Expand Up @@ -2867,7 +2868,7 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
if op == 'not':
result = self.bool_type() # type: Type
else:
method = nodes.unary_op_methods[op]
method = operators.unary_op_methods[op]
result, method_type = self.check_method_call_by_name(method, operand_type, [], [], e)
e.method_type = method_type
return result
Expand Down Expand Up @@ -4533,9 +4534,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
return False
short_name = fullname.split('.')[-1]
return (
short_name in nodes.op_methods.values() or
short_name in nodes.reverse_op_methods.values() or
short_name in nodes.unary_op_methods.values())
short_name in operators.op_methods.values() or
short_name in operators.reverse_op_methods.values() or
short_name in operators.unary_op_methods.values())


def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:
Expand Down
4 changes: 2 additions & 2 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
)
from mypy.typetraverser import TypeTraverserVisitor
from mypy.nodes import (
TypeInfo, Context, MypyFile, op_methods, op_methods_to_symbols,
FuncDef, reverse_builtin_aliases,
TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode
)
from mypy.operators import op_methods, op_methods_to_symbols
from mypy.subtypes import (
is_subtype, find_member, get_member_flags,
IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC,
Expand Down
94 changes: 0 additions & 94 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,100 +1634,6 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_assignment_expr(self)


# Map from binary operator id to related method name (in Python 3).
op_methods = {
'+': '__add__',
'-': '__sub__',
'*': '__mul__',
'/': '__truediv__',
'%': '__mod__',
'divmod': '__divmod__',
'//': '__floordiv__',
'**': '__pow__',
'@': '__matmul__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
'<<': '__lshift__',
'>>': '__rshift__',
'==': '__eq__',
'!=': '__ne__',
'<': '__lt__',
'>=': '__ge__',
'>': '__gt__',
'<=': '__le__',
'in': '__contains__',
} # type: Final

op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} # type: Final
op_methods_to_symbols['__div__'] = '/'

comparison_fallback_method = '__cmp__' # type: Final
ops_falling_back_to_cmp = {'__ne__', '__eq__',
'__lt__', '__le__',
'__gt__', '__ge__'} # type: Final


ops_with_inplace_method = {
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'} # type: Final

inplace_operator_methods = set(
'__i' + op_methods[op][2:] for op in ops_with_inplace_method) # type: Final

reverse_op_methods = {
'__add__': '__radd__',
'__sub__': '__rsub__',
'__mul__': '__rmul__',
'__truediv__': '__rtruediv__',
'__mod__': '__rmod__',
'__divmod__': '__rdivmod__',
'__floordiv__': '__rfloordiv__',
'__pow__': '__rpow__',
'__matmul__': '__rmatmul__',
'__and__': '__rand__',
'__or__': '__ror__',
'__xor__': '__rxor__',
'__lshift__': '__rlshift__',
'__rshift__': '__rrshift__',
'__eq__': '__eq__',
'__ne__': '__ne__',
'__lt__': '__gt__',
'__ge__': '__le__',
'__gt__': '__lt__',
'__le__': '__ge__',
} # type: Final

# Suppose we have some class A. When we do A() + A(), Python will only check
# the output of A().__add__(A()) and skip calling the __radd__ method entirely.
# This shortcut is used only for the following methods:
op_methods_that_shortcut = {
'__add__',
'__sub__',
'__mul__',
'__div__',
'__truediv__',
'__mod__',
'__divmod__',
'__floordiv__',
'__pow__',
'__matmul__',
'__and__',
'__or__',
'__xor__',
'__lshift__',
'__rshift__',
} # type: Final

normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) # type: Final
reverse_op_method_set = set(reverse_op_methods.values()) # type: Final

unary_op_methods = {
'-': '__neg__',
'+': '__pos__',
'~': '__invert__',
} # type: Final


class OpExpr(Expression):
"""Binary operation (other than . or [] or comparison operators,
which have specific nodes)."""
Expand Down
99 changes: 99 additions & 0 deletions mypy/operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Information about Python operators"""

from typing_extensions import Final


# Map from binary operator id to related method name (in Python 3).
op_methods = {
'+': '__add__',
'-': '__sub__',
'*': '__mul__',
'/': '__truediv__',
'%': '__mod__',
'divmod': '__divmod__',
'//': '__floordiv__',
'**': '__pow__',
'@': '__matmul__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
'<<': '__lshift__',
'>>': '__rshift__',
'==': '__eq__',
'!=': '__ne__',
'<': '__lt__',
'>=': '__ge__',
'>': '__gt__',
'<=': '__le__',
'in': '__contains__',
} # type: Final

op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} # type: Final
op_methods_to_symbols['__div__'] = '/'

comparison_fallback_method = '__cmp__' # type: Final
ops_falling_back_to_cmp = {'__ne__', '__eq__',
'__lt__', '__le__',
'__gt__', '__ge__'} # type: Final


ops_with_inplace_method = {
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'} # type: Final

inplace_operator_methods = set(
'__i' + op_methods[op][2:] for op in ops_with_inplace_method) # type: Final

reverse_op_methods = {
'__add__': '__radd__',
'__sub__': '__rsub__',
'__mul__': '__rmul__',
'__truediv__': '__rtruediv__',
'__mod__': '__rmod__',
'__divmod__': '__rdivmod__',
'__floordiv__': '__rfloordiv__',
'__pow__': '__rpow__',
'__matmul__': '__rmatmul__',
'__and__': '__rand__',
'__or__': '__ror__',
'__xor__': '__rxor__',
'__lshift__': '__rlshift__',
'__rshift__': '__rrshift__',
'__eq__': '__eq__',
'__ne__': '__ne__',
'__lt__': '__gt__',
'__ge__': '__le__',
'__gt__': '__lt__',
'__le__': '__ge__',
} # type: Final

reverse_op_method_names = set(reverse_op_methods.values()) # type: Final

# Suppose we have some class A. When we do A() + A(), Python will only check
# the output of A().__add__(A()) and skip calling the __radd__ method entirely.
# This shortcut is used only for the following methods:
op_methods_that_shortcut = {
'__add__',
'__sub__',
'__mul__',
'__div__',
'__truediv__',
'__mod__',
'__divmod__',
'__floordiv__',
'__pow__',
'__matmul__',
'__and__',
'__or__',
'__xor__',
'__lshift__',
'__rshift__',
} # type: Final

normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) # type: Final
reverse_op_method_set = set(reverse_op_methods.values()) # type: Final

unary_op_methods = {
'-': '__neg__',
'+': '__pos__',
'~': '__invert__',
} # type: Final
4 changes: 3 additions & 1 deletion mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr
)
from mypy.operators import (
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
)
from mypy.traverser import TraverserVisitor
Expand Down
Loading

0 comments on commit f91fb1a

Please sign in to comment.