Skip to content

Commit

Permalink
Fix KeyError on flush after obj.set(**kwargs)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jan 10, 2022
1 parent 895a8b1 commit e97de4e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pony/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5094,9 +5094,12 @@ def set(obj, **kwargs):
get_val = obj._vals_.get
objects_to_save = cache.objects_to_save
if avdict:
if any(attr not in obj._vals_ and attr.reverse and obj._bits_[attr] for attr in avdict):
obj._load_()

for attr in avdict:
if attr not in obj._vals_ and attr.reverse and not attr.reverse.is_collection:
attr.load(obj) # loading of one-to-one relations
if attr not in obj._vals_ and attr.reverse:
attr.load(obj) # load one-to-one and lazy relations

if wbits is not None:
new_wbits = wbits
Expand All @@ -5115,8 +5118,8 @@ def set(obj, **kwargs):
obj._vals_.update(avdict)
return

for attr, value in items_list(avdict):
if value == get_val(attr):
for attr, new_val in items_list(avdict):
if new_val == get_val(attr, NOT_LOADED):
avdict.pop(attr)

undo_funcs = []
Expand Down Expand Up @@ -5148,7 +5151,7 @@ def undo_func():
cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo)
for attr, new_val in iteritems(avdict):
if not attr.reverse: continue
old_val = get_val(attr)
old_val = get_val(attr, NOT_LOADED)
attr.update_reverse(obj, old_val, new_val, undo_funcs)
for attr, new_val in iteritems(collection_avdict):
attr.__set__(obj, new_val, undo_funcs)
Expand Down
24 changes: 24 additions & 0 deletions pony/orm/tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,22 @@ class Group(db.Entity):
students = Set('Student')

class Student(db.Entity):
id = PrimaryKey(int)
name = Required(unicode)
age = Optional(int)
passport = Optional("Passport")
scholarship = Required(Decimal, default=0)
picture = Optional(buffer, lazy=True)
email = Required(unicode, unique=True)
phone = Optional(unicode, unique=True)
courses = Set('Course')
group = Optional('Group')

class Passport(db.Entity):
id = PrimaryKey(int)
number = Required(str, unique=True)
person = Required(Student)

class Course(db.Entity):
id = PrimaryKey(int)
name = Required(unicode)
Expand All @@ -43,6 +50,7 @@ def setUpClass(cls):
s1 = Student(id=1, name='S1', age=19, email='[email protected]', group=g1)
s2 = Student(id=2, name='S2', age=21, email='[email protected]', group=g1)
s3 = Student(id=3, name='S3', email='[email protected]', group=g2)
p1 = Passport(id=1, number='111', person=1)
c1 = Course(id=1, name='Math', semester=1)
c2 = Course(id=2, name='Math', semester=2)
c3 = Course(id=3, name='Physics', semester=1)
Expand Down Expand Up @@ -114,6 +122,22 @@ def test_set4(self):
s1 = Student[1]
s1.set(name='New name', email='[email protected]')

def test_set5(self):
g2 = Group[1]
s2 = Student._get_by_raw_pkval_((1,))
s2.set(age=20, group=None)
db.flush()

def test_set6(self):
s2 = Student._get_by_raw_pkval_((1,))
s2.set(age=20, group=None, picture=None)
db.flush()

def test_set7(self):
s2 = Student._get_by_raw_pkval_((2,))
s2.set(age=22, passport=None)
db.flush()

def test_validate_1(self):
s4 = Student(id=3, name='S4', email='[email protected]', group=1)

Expand Down

0 comments on commit e97de4e

Please sign in to comment.