Skip to content

Commit

Permalink
None value can be passed as a param to hybrid methods
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jan 10, 2022
1 parent e97de4e commit 21a9137
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
56 changes: 54 additions & 2 deletions pony/orm/sqltranslation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,8 @@ def negate(monad):
def nonzero(monad):
return CmpMonad('is not', monad, NoneMonad())
def getattr(monad, attrname):
if isinstance(monad, ParamMonad):
throw(NotImplementedError, '{EXPR} for external expressions inside hybrid methods is not supported')
entity = monad.type
attr = entity._adict_.get(attrname) or entity._subclass_adict_.get(attrname)
if attr is None:
Expand Down Expand Up @@ -2229,7 +2231,7 @@ def getsql(monad, sqlquery=None):
attr = monad.attr
entity = attr.entity
pk_only = attr.pk_offset is not None
alias, parent_columns = monad.parent.tableref.make_join(pk_only)
alias, parent_columns = parent.tableref.make_join(pk_only)
if pk_only:
if entity._pk_is_composite_:
offset = attr.pk_columns_offset
Expand Down Expand Up @@ -2452,6 +2454,55 @@ class NoneMonad(ConstMonad):
def __init__(monad, value=None):
assert value is None
ConstMonad.__init__(monad, value)
def cmp(monad, op, monad2):
return CmpMonad(op, monad, monad2)
def contains(monad, item, not_in=False):
return NoneMonad()
def nonzero(monad):
return NoneMonad()
def negate(monad):
return NoneMonad()
def getattr(monad, attrname):
return NoneMonad()
def len(monad):
return NoneMonad()
def count(monad, distinct=None):
return NumericExprMonad(int, [ ['VALUE', 0] ], nullable=False)
def aggregate(monad, func_name, distinct=None, sep=None):
return NoneMonad()
def __call__(monad, *args, **kwargs):
return NoneMonad()
def __getitem__(monad, key):
return NoneMonad()
def __add__(monad, monad2):
return NoneMonad()
def __sub__(monad, monad2):
return NoneMonad()
def __mul__(monad, monad2):
return NoneMonad()
def __truediv__(monad, monad2):
return NoneMonad()
def __floordiv__(monad, monad2):
return NoneMonad()
def __pow__(monad, monad2):
return NoneMonad()
def __neg__(monad):
return NoneMonad()
def __or__(monad, monad2):
return NoneMonad()
def __and__(monad, monad2):
return NoneMonad()
def __xor__(monad, monad2):
return NoneMonad()
def abs(monad):
return NoneMonad()
def to_int(monad):
return NoneMonad()
def to_str(monad):
return NoneMonad()
def to_real(monad):
return NoneMonad()


class EllipsisMonad(ConstMonad):
pass
Expand Down Expand Up @@ -2506,7 +2557,6 @@ class CmpMonad(BoolMonad):
def __init__(monad, op, left, right):
if op == '<>': op = '!='
if left.type is NoneType:
assert right.type is not NoneType
left, right = right, left
if right.type is NoneType:
if op == '==': op = 'is'
Expand All @@ -2530,6 +2580,8 @@ def negate(monad):
return CmpMonad(cmp_negate[monad.op], monad.left, monad.right)
def getsql(monad, sqlquery=None):
op = monad.op
if monad.left.type is NoneType and monad.right.type is NoneType: # in hybrid methods
return [['EQ' if op == 'is' else 'NE', ['VALUE', 1], ['VALUE', 1]]]
left_sql = monad.left.getsql()
if op == 'is':
return [ sqland([ [ 'IS_NULL', item ] for item in left_sql ]) ]
Expand Down
24 changes: 24 additions & 0 deletions pony/orm/tests/test_hybrid_methods_and_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ class Car(db.Entity):
price = Required(int)
color = Required(str)

def owner_likes_color(self):
return self.owner is not None and self.owner.favorite_color == self.color

def person_likes_color(self, user):
return user is not None and user.favorite_color == self.color


def simple_func(person):
return person.full_name
Expand Down Expand Up @@ -302,6 +308,24 @@ def test_32(self):
q1 = select(p for p in Person if p.id < 4)
q2 = select(p.id for p in q1 if p.property_with_incorrect_attr_reference)

@db_session
def test_33(self):
q1 = select(c.id for c in Car if c.owner_likes_color())
self.assertEqual(set(q1), {2})

@db_session
def test_34(self):
p = None
q1 = select(c.id for c in Car if c.person_likes_color(None))
self.assertEqual(set(q1), set())

@db_session
@raises_exception(NotImplementedError, 'user.favorite_color for external expressions inside hybrid methods '
'is not supported (inside Car.person_likes_color)')
def test_35(self):
p = Person[1]
q1 = select(c.id for c in Car if c.person_likes_color(p))
self.assertEqual(set(q1), {2, 5, 6})


if __name__ == '__main__':
Expand Down

0 comments on commit 21a9137

Please sign in to comment.