Skip to content

Commit 151a1df

Browse files
authored
Merge pull request ipython#415 from minrk/mro-defaults
follow mro for trait default generators
2 parents d64bba2 + 644a649 commit 151a1df

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

traitlets/tests/test_traitlets.py

+33
Original file line numberDiff line numberDiff line change
@@ -2496,6 +2496,39 @@ class SuperHasTraits(HasTraits):
24962496
assert not hasattr(obj, 'b')
24972497

24982498

2499+
def test_default_mro():
2500+
"""Verify that default values follow mro"""
2501+
class Base(HasTraits):
2502+
trait = Unicode('base')
2503+
attr = 'base'
2504+
2505+
class A(Base):
2506+
pass
2507+
2508+
class B(Base):
2509+
trait = Unicode('B')
2510+
attr = 'B'
2511+
2512+
class AB(A, B):
2513+
pass
2514+
2515+
class BA(B, A):
2516+
pass
2517+
2518+
assert 'trait' in Base._trait_default_generators
2519+
assert 'trait' not in A._trait_default_generators
2520+
assert 'trait' in B._trait_default_generators
2521+
assert 'trait' not in AB._trait_default_generators
2522+
assert 'trait' not in BA._trait_default_generators
2523+
2524+
assert A().trait == 'base'
2525+
assert A().attr == 'base'
2526+
assert BA().trait == 'B'
2527+
assert BA().attr == 'B'
2528+
assert AB().trait == 'B'
2529+
assert AB().attr == 'B'
2530+
2531+
24992532
def test_cls_self_argument():
25002533
class X(HasTraits):
25012534
def __init__(__self, cls, self):

traitlets/traitlets.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def get(self, obj, cls=None):
490490
value = obj._trait_values[self.name]
491491
except KeyError:
492492
# Check for a dynamic initializer.
493-
default = cls._trait_default_generators[self.name](obj)
493+
default = obj.trait_defaults(self.name)
494494
if default is Undefined:
495495
raise TraitError("No default value found for "
496496
"the '%s' trait named '%s' of %r" % (
@@ -776,11 +776,6 @@ class MetaHasTraits(MetaHasDescriptors):
776776
def setup_class(cls, classdict):
777777
cls._trait_default_generators = {}
778778
super(MetaHasTraits, cls).setup_class(classdict)
779-
new = {}
780-
for c in reversed(cls.mro()):
781-
if hasattr(c, "_trait_default_generators"):
782-
new.update(c._trait_default_generators)
783-
cls._trait_default_generators = new
784779

785780

786781

@@ -1460,6 +1455,17 @@ def trait_values(self, **metadata):
14601455
"""
14611456
return {name: getattr(self, name) for name in self.trait_names(**metadata)}
14621457

1458+
@classmethod
1459+
def _get_trait_default_generator(cls, name):
1460+
"""Return default generator for a given trait
1461+
1462+
Walk the MRO to resolve the correct default generator according to inheritance.
1463+
"""
1464+
for c in cls.mro():
1465+
if name in c.__dict__.get('_trait_default_generators', {}):
1466+
return c._trait_default_generators[name]
1467+
raise KeyError("No default generator for trait %r found in %r" % (name, cls.mro()))
1468+
14631469
def trait_defaults(self, *names, **metadata):
14641470
"""Return a trait's default value or a dictionary of them
14651471
@@ -1468,7 +1474,7 @@ def trait_defaults(self, *names, **metadata):
14681474
Dynamically generated default values may
14691475
depend on the current state of the object."""
14701476
if len(names) == 1 and len(metadata) == 0:
1471-
return self._trait_default_generators[names[0]](self)
1477+
return self._get_trait_default_generator(names[0])(self)
14721478

14731479
for n in names:
14741480
if not has_trait(self, n):
@@ -1479,7 +1485,7 @@ def trait_defaults(self, *names, **metadata):
14791485

14801486
defaults = {}
14811487
for n in trait_names:
1482-
defaults[n] = self._trait_default_generators[n](self)
1488+
defaults[n] = self._get_trait_default_generator(n)(self)
14831489
return defaults
14841490

14851491
def trait_names(self, **metadata):

0 commit comments

Comments
 (0)