Skip to content

Commit

Permalink
Fix required and optional keys inheritance for TypedDict (python#700)
Browse files Browse the repository at this point in the history
(For a complete description of the issue see python#700.)
  • Loading branch information
vemel authored Feb 12, 2020
1 parent fdc9359 commit e796957
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
34 changes: 34 additions & 0 deletions typing_extensions/src_py3/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,16 @@ class LabelPoint2D(Point2D, Label): ...
class Options(TypedDict, total=False):
log_level: int
log_path: str
class BaseAnimal(TypedDict):
name: str
class Animal(BaseAnimal, total=False):
voice: str
tail: bool
class Cat(Animal):
fur_color: str
"""

if PY36:
Expand All @@ -444,6 +454,7 @@ class Options(TypedDict, total=False):
A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object
XMeth = XRepr = HasCallProtocol = NoneAndForward = Loop = object
Point2D = Point2Dor3D = LabelPoint2D = Options = object
BaseAnimal = Animal = Cat = object

gth = get_type_hints

Expand Down Expand Up @@ -1549,6 +1560,29 @@ def test_optional_keys(self):
assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y'])
assert Point2Dor3D.__optional_keys__ == frozenset(['z'])

@skipUnless(PY36, 'Python 3.6 required')
def test_keys_inheritance(self):
assert BaseAnimal.__required_keys__ == frozenset(['name'])
assert BaseAnimal.__optional_keys__ == frozenset([])
assert BaseAnimal.__annotations__ == {'name': str}

assert Animal.__required_keys__ == frozenset(['name'])
assert Animal.__optional_keys__ == frozenset(['tail', 'voice'])
assert Animal.__annotations__ == {
'name': str,
'tail': bool,
'voice': str,
}

assert Cat.__required_keys__ == frozenset(['name', 'fur_color'])
assert Cat.__optional_keys__ == frozenset(['tail', 'voice'])
assert Cat.__annotations__ == {
'fur_color': str,
'name': str,
'tail': bool,
'voice': str,
}


@skipUnless(TYPING_3_5_3, "Python >= 3.5.3 required")
class AnnotatedTests(BaseTestCase):
Expand Down
33 changes: 20 additions & 13 deletions typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,23 +1651,30 @@ def __new__(cls, name, bases, ns, total=True):
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns)

anns = ns.get('__annotations__', {})
annotations = {}
own_annotations = ns.get('__annotations__', {})
own_annotation_keys = set(own_annotations.keys())
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
anns = {n: typing._type_check(tp, msg) for n, tp in anns.items()}
required = set(anns if total else ())
optional = set(() if total else anns)
own_annotations = {
n: typing._type_check(tp, msg) for n, tp in own_annotations.items()
}
required_keys = set()
optional_keys = set()

for base in bases:
base_anns = base.__dict__.get('__annotations__', {})
anns.update(base_anns)
if getattr(base, '__total__', True):
required.update(base_anns)
else:
optional.update(base_anns)
annotations.update(base.__dict__.get('__annotations__', {}))
required_keys.update(base.__dict__.get('__required_keys__', ()))
optional_keys.update(base.__dict__.get('__optional_keys__', ()))

annotations.update(own_annotations)
if total:
required_keys.update(own_annotation_keys)
else:
optional_keys.update(own_annotation_keys)

tp_dict.__annotations__ = anns
tp_dict.__required_keys__ = frozenset(required)
tp_dict.__optional_keys__ = frozenset(optional)
tp_dict.__annotations__ = annotations
tp_dict.__required_keys__ = frozenset(required_keys)
tp_dict.__optional_keys__ = frozenset(optional_keys)
if not hasattr(tp_dict, '__total__'):
tp_dict.__total__ = total
return tp_dict
Expand Down

0 comments on commit e796957

Please sign in to comment.