Skip to content

Commit

Permalink
Moved all string formatter checker code into a new module
Browse files Browse the repository at this point in the history
  • Loading branch information
spkersten committed Oct 14, 2014
1 parent e329313 commit 9add187
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 237 deletions.
242 changes: 5 additions & 237 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Expression type checker. This file is conceptually part of TypeChecker."""

import re

from typing import Undefined, cast, List, Tuple, Dict, Function

from mypy.types import (
Expand Down Expand Up @@ -33,6 +31,7 @@
from mypy.checkmember import analyse_member_access, type_object_type
from mypy.semanal import self_type
from mypy.constraints import get_actual_type
from mypy.checkstrformat import StringFormatterChecker


class ExpressionChecker:
Expand All @@ -46,12 +45,15 @@ class ExpressionChecker:
# This is shared with TypeChecker, but stored also here for convenience.
msg = Undefined(MessageBuilder)

strfrm_checker = Undefined('mypy.checkstrformat.StringFormatterChecker')

def __init__(self,
chk: 'mypy.checker.TypeChecker',
msg: MessageBuilder) -> None:
"""Construct an expression type checker."""
self.chk = chk
self.msg = msg
self.strfrm_checker = mypy.checkexpr.StringFormatterChecker(self, self.chk, self.msg)

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down Expand Up @@ -751,7 +753,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
# Expressions of form [...] * e get special type inference.
return self.check_list_multiply(e)
if e.op == '%' and isinstance(e.left, StrExpr):
return self.check_str_interpolation(cast(StrExpr, e.left), e.right)
return self.strfrm_checker.check_str_interpolation(cast(StrExpr, e.left), e.right)
left_type = self.accept(e.left)

if e.op in nodes.op_methods:
Expand All @@ -763,240 +765,6 @@ def visit_op_expr(self, e: OpExpr) -> Type:
else:
raise RuntimeError('Unknown operator {}'.format(e.op))

def check_str_interpolation(self, str: StrExpr, replacements: Node) -> Type:
replacements = self.strip_parens(replacements)
specifiers = self.parse_conversion_specifiers(str.value)
has_mapping_keys = self.analyse_conversion_specifiers(specifiers, str)
if has_mapping_keys == None:
pass # Error was reported
elif has_mapping_keys:
self.check_mapping_str_interpolation(specifiers, replacements)
else:
self.check_simple_str_interpolation(specifiers, replacements)
return self.named_type('builtins.str')

class ConversionSpecifier:
def __init__(self, key: str, flags: str, width: str, precision: str, type: str) -> None:
self.key = key
self.flags = flags
self.width = width
self.precision = precision
self.type = type

def has_key(self):
return self.key != None

def has_star(self):
return self.width == '*' or self.precision == '*'

def parse_conversion_specifiers(self, format: str) -> List[ConversionSpecifier]:
key_regex = r'(\((\w*)\))?' # (optional) parenthesised sequence of characters
flags_regex = r'([#0\-+ ]*)' # (optional) sequence of flags
width_regex = r'(\*|[1-9][0-9]*)?' # (optional) minimum field width (* or numbers)
precision_regex = r'(?:\.(\*|[0-9]+))?' # (optional) . followed by * of numbers
length_mod_regex = r'[hlL]?' # (optional) length modifier (unused)
type_regex = r'(.)?' # conversion type
regex = ('%' + key_regex + flags_regex + width_regex +
precision_regex + length_mod_regex + type_regex)
specifiers = [] # type: List[ExpressionChecker.ConversionSpecifier]
for parens_key, key, flags, width, precision, type in re.findall(regex, format):
if parens_key == '':
key = None
specifiers.append(ExpressionChecker.ConversionSpecifier(key, flags, width, precision, type))
return specifiers

def analyse_conversion_specifiers(self, specifiers: List[ConversionSpecifier],
context: Context) -> bool:
has_star = any(specifier.has_star() for specifier in specifiers)
has_key = any(specifier.has_key() for specifier in specifiers)
all_have_keys = all(specifier.has_key() for specifier in specifiers)

if has_key and has_star:
self.msg.string_interpolation_with_star_and_key(context)
return None
if has_key and not all_have_keys:
self.msg.string_interpolation_mixing_key_and_non_keys(context)
return None
return has_key

def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier],
replacements: Node) -> None:
checkers = self.build_replacement_checkers(specifiers, replacements)
if checkers == None:
return

rhs_type = self.accept(replacements)
rep_types = [] # type: List[Type]
if isinstance(rhs_type, TupleType):
rep_types = rhs_type.items
elif isinstance(rhs_type, AnyType):
return
else:
rep_types = [rhs_type]

if len(checkers) > len(rep_types):
self.msg.too_few_string_formatting_arguments(replacements)
elif len(checkers) < len(rep_types):
self.msg.too_many_string_formatting_arguments(replacements)
else:
if len(checkers) == 1:
check_node, check_type = checkers[0]
check_node(replacements)
elif isinstance(replacements, TupleExpr):
for checks, rep_node in zip(checkers, replacements.items):
check_node, check_type = checks
check_node(rep_node)
else:
for checks, rep_type in zip(checkers, rep_types):
check_node, check_type = checks
check_type(rep_type)

def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier],
replacements: Node) -> None:
dict_with_only_str_literal_keys = (isinstance(replacements, DictExpr) and
all(isinstance(self.strip_parens(k), StrExpr)
for k, v in cast(DictExpr, replacements).items))
if dict_with_only_str_literal_keys:
mapping = {} # type: Dict[str, Type]
for k, v in cast(DictExpr, replacements).items:
key_str = cast(StrExpr, k).value
mapping[key_str] = self.accept(v)

for specifier in specifiers:
if specifier.key not in mapping:
self.msg.key_not_in_mapping(specifier.key, replacements)
return
rep_type = mapping[specifier.key]
expected_type = self.conversion_type(specifier.type, replacements)
if expected_type == None:
return
self.chk.check_subtype(rep_type, expected_type, replacements,
messages.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION,
'expression has type', 'placeholder with key \'%s\' has type' % specifier.key)
else:
rep_type = self.accept(replacements)
dict_type = self.chk.named_generic_type('builtins.dict',
[AnyType(), AnyType()])
self.chk.check_subtype(rep_type, dict_type, replacements, messages.FORMAT_REQUIRES_MAPPING,
'expression has type', 'expected type for mapping is')

def build_replacement_checkers(self, specifiers: List[ConversionSpecifier],
context: Context) -> List[ Tuple[ Function[[Node], None],
Function[[Type], None] ] ]:
checkers = [] # type: List[ Tuple[ Function[[Node], None], Function[[Type], None] ] ]
for specifier in specifiers:
checker = self.replacement_checkers(specifier, context)
if checker == None:
return None
checkers.extend(checker)
return checkers

def replacement_checkers(self, specifier: ConversionSpecifier,
context: Context) -> List[ Tuple[ Function[[Node], None],
Function[[Type], None] ] ]:
"""Returns a list of tuples of two functions that check whether a replacement is
of the right type for the specifier. The first functions take a node and checks
its type in the right type context. The second function just checks a type.
"""
checkers = [] # type: List[ Tuple[ Function[[Node], None], Function[[Type], None] ] ]

if specifier.width == '*':
checkers.append(self.checkers_for_star(context))
if specifier.precision == '*':
checkers.append(self.checkers_for_star(context))
if specifier.type == 'c':
c = self.checkers_for_c_type(specifier.type, context)
if c == None:
return None
checkers.append(c)
elif specifier.type != '%':
c = self.checkers_for_regular_type(specifier.type, context)
if c == None:
return None
checkers.append(c)
return checkers

def checkers_for_star(self, context: Context) -> Tuple[ Function[[Node], None],
Function[[Type], None] ]:
"""Returns a tuple of check functions that check whether, respectively,
a node or a type is compatible with a star in a conversion specifier
"""
expected = self.named_type('builtins.int')

def check_type(type: Type = None) -> None:
expected = self.named_type('builtins.int')
self.chk.check_subtype(type, expected, context, '* wants int')

def check_node(node: Node) -> None:
type = self.accept(node, expected)
check_type(type)

return check_node, check_type

def checkers_for_regular_type(self, type: str, context: Context) -> Tuple[ Function[[Node], None],
Function[[Type], None] ]:
"""Returns a tuple of check functions that check whether, respectively,
a node or a type is compatible with 'type'. Return None in case of an
"""
expected_type = self.conversion_type(type, context)
if expected_type == None:
return None

def check_type(type: Type = None) -> None:
self.chk.check_subtype(type, expected_type, context,
messages.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION,
'expression has type', 'placeholder has type')

def check_node(node: Node) -> None:
type = self.accept(node, expected_type)
check_type(type)

return check_node, check_type

def checkers_for_c_type(self, type: str, context: Context) -> Tuple[ Function[[Node], None],
Function[[Type], None] ]:
"""Returns a tuple of check functions that check whether, respectively,
a node or a type is compatible with 'type' that is a character type
"""
expected_type = self.conversion_type(type, context)
if expected_type == None:
return None

def check_type(type: Type = None) -> None:
self.chk.check_subtype(type, expected_type, context,
messages.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION,
'expression has type', 'placeholder has type')

def check_node(node: Node) -> None:
"""int, or str with length 1"""
type = self.accept(node, expected_type)
if isinstance(node, StrExpr) and len(cast(StrExpr, node).value) != 1:
self.msg.requires_int_or_char(context)
check_type(type)

return check_node, check_type

def conversion_type(self, p: str, context: Context) -> Type:
"""Return the type that is accepted for a string interpolation
conversion specifier type.
Note that both Python's float (e.g. %f) and integer (e.g. %d)
specifier types accept both float and integers.
"""
if p in ['s', 'r']:
return AnyType()
elif p in ['d', 'i', 'o', 'u', 'x', 'X',
'e', 'E', 'f', 'F', 'g', 'G']:
return UnionType([self.named_type('builtins.int'),
self.named_type('builtins.float')])
elif p in ['c']:
return UnionType([self.named_type('builtins.int'),
self.named_type('builtins.float'),
self.named_type('builtins.str')])
else:
self.msg.unsupported_placeholder(p, context)
return None

def strip_parens(self, node: Node) -> Node:
if isinstance(node, ParenExpr):
return self.strip_parens(node.expr)
Expand Down
Loading

0 comments on commit 9add187

Please sign in to comment.