Skip to content

Commit

Permalink
Apply --strict-equality special-casing for bytes also to bytearray (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
ilevkivskyi authored Sep 5, 2019
1 parent 6dce58f commit 8e960b3
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
10 changes: 6 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,8 +2103,9 @@ def dangerous_comparison(self, left: Type, right: Type,
right = remove_optional(right)
if (original_container and has_bytes_component(original_container) and
has_bytes_component(left)):
# We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc'
# return True (and we want to show the error only if the check can _never_ be True).
# We need to special case bytes and bytearray, because 97 in b'abc', b'a' in b'abc',
# b'a' in bytearray(b'abc') etc. all return True (and we want to show the error only
# if the check can _never_ be True).
return False
if isinstance(left, Instance) and isinstance(right, Instance):
# Special case some builtin implementations of AbstractSet.
Expand Down Expand Up @@ -4136,11 +4137,12 @@ def custom_equality_method(typ: Type) -> bool:


def has_bytes_component(typ: Type) -> bool:
"""Is this the builtin bytes type, or a union that contains it?"""
"""Is this one of builtin byte types, or a union that contains it?"""
typ = get_proper_type(typ)
if isinstance(typ, UnionType):
return any(has_bytes_component(t) for t in typ.items)
if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes':
if isinstance(typ, Instance) and typ.type.fullname() in {'builtins.bytes',
'builtins.bytearray'}:
return True
return False

Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,13 @@ x in b'abc'
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityByteArraySpecial]
# flags: --strict-equality
b'abc' in bytearray(b'abcde')
bytearray(b'abc') in b'abcde' # OK on Python 3
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromotePy3]
# flags: --strict-equality
'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']")
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-type-promotion.test
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ f(1)

[case testPromoteBytearrayToByte]
def f(x: bytes) -> None: pass
f(bytearray())
f(bytearray(b''))
[builtins fixtures/primitives.pyi]

[case testNarrowingDownFromPromoteTargetType]
Expand Down
6 changes: 5 additions & 1 deletion test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ class bytes(Sequence[int]):
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
class bytearray: pass
class bytearray(Sequence[int]):
def __init__(self, x: bytes) -> None: pass
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
class tuple(Generic[T]): pass
class function: pass
class ellipsis: pass

0 comments on commit 8e960b3

Please sign in to comment.