Skip to content

Commit

Permalink
Merge branch master to orm
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jan 21, 2019
2 parents 66f5602 + 61841d3 commit efc1368
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pony/orm/dbapiprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def validate(converter, val, obj=None):
return TrackedArray(obj, converter.attr, items)

def dbval2val(converter, dbval, obj=None):
if obj is None:
if obj is None or dbval is None:
return dbval
return TrackedArray(obj, converter.attr, dbval)

Expand Down
32 changes: 22 additions & 10 deletions pony/orm/sqltranslation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,15 @@ def dispatch_external(translator, node):
else:
is_array = True

for i, item_type in enumerate(t):
if item_type is NoneType:
throw(TypeError, 'Expression `%s` should not contain None values' % node.src)
param = ParamMonad.new(item_type, (varkey, i, None))
params.append(param)
monad = ListMonad(params)
if is_array:
array_type = array_types.get(item_type, None)
monad = ArrayParamMonad(array_type, (varkey, None, None))
else:
for i, item_type in enumerate(t):
if item_type is NoneType:
throw(TypeError, 'Expression `%s` should not contain None values' % node.src)
param = ParamMonad.new(item_type, (varkey, i, None))
params.append(param)
monad = ListMonad(params)
monad = ArrayParamMonad(array_type, (varkey, None, None), list_monad=monad)
elif isinstance(t, RawSQLType):
monad = RawSQLMonad(t, varkey)
else:
Expand Down Expand Up @@ -2049,6 +2048,11 @@ def contains(monad, key, not_in=False):
sql = 'ARRAY_CONTAINS', key.getsql()[0], not_in, monad.getsql()[0]
return BoolExprMonad(sql)
if isinstance(key, ListMonad):
if not key.items:
if not_in:
return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False)
else:
return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], nullable=False)
sql = [ 'MAKE_ARRAY' ]
sql.extend(item.getsql()[0] for item in key.items)
sql = 'ARRAY_SUBSET', sql, not_in, monad.getsql()[0]
Expand Down Expand Up @@ -2231,7 +2235,7 @@ def new(t, paramkey):
result = cls(t, paramkey)
result.aggregated = False
return result
def __new__(cls, *args):
def __new__(cls, *args, **kwargs):
if cls is ParamMonad: assert False, 'Abstract class' # pragma: no cover
return Monad.__new__(cls)
def __init__(monad, t, paramkey):
Expand Down Expand Up @@ -2268,7 +2272,15 @@ class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass
class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass
class BufferParamMonad(BufferMixin, ParamMonad): pass
class UuidParamMonad(UuidMixin, ParamMonad): pass
class ArrayParamMonad(ArrayMixin, ParamMonad): pass

class ArrayParamMonad(ArrayMixin, ParamMonad):
def __init__(monad, t, paramkey, list_monad=None):
ParamMonad.__init__(monad, t, paramkey)
monad.list_monad = list_monad
def contains(monad, key, not_in=False):
if key.type is monad.type.item_type:
return monad.list_monad.contains(key, not_in)
return ArrayMixin.contains(monad, key, not_in)

class JsonParamMonad(JsonMixin, ParamMonad):
def getsql(monad, sqlquery=None):
Expand Down
31 changes: 19 additions & 12 deletions pony/orm/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def test_2(self):
foo = select(f for f in Foo if [10, 20, 50] in f.array1)[:]
self.assertEqual([Foo[1]], foo)

@db_session
def test_2a(self):
foo = select(f for f in Foo if [] in f.array1)[:]
self.assertEqual([Foo[1]], foo)

@db_session
def test_3(self):
x = [10, 20, 50]
Expand Down Expand Up @@ -218,33 +223,35 @@ def test_35(self):
self.assertTrue([10, 20] in foo.array1)
self.assertTrue([20, 10] in foo.array1)
self.assertTrue([10, 1000] not in foo.array1)
self.assertTrue([] in foo.array1)
self.assertTrue('bar' in foo.array3)
self.assertTrue('baz' not in foo.array3)
self.assertTrue(['foo', 'bar'] in foo.array3)
self.assertTrue(['bar', 'foo'] in foo.array3)
self.assertTrue(['baz', 'bar'] not in foo.array3)
self.assertTrue([] in foo.array3)

@db_session(sql_debug=True)
@db_session
def test_36(self):
items = []
result = select(foo for foo in Foo if foo.id in items)[:]
result = select(foo for foo in Foo if foo in items)[:]
self.assertEqual(result, [])

@db_session(sql_debug=True)
@db_session
def test_37(self):
items = [1]
result = select(foo.id for foo in Foo if foo.id in items)[:]
self.assertEqual(result, [1])

@db_session(sql_debug=True)
def test_38(self):
f1 = Foo[1]
items = [f1]
result = select(foo for foo in Foo if foo in items)[:]
self.assertEqual(result, [f1])

@db_session(sql_debug=True)
def test_39(self):
@db_session
def test_38(self):
items = []
result = select(foo for foo in Foo if foo in items)[:]
result = select(foo for foo in Foo if foo.id in items)[:]
self.assertEqual(result, [])

@db_session
def test_39(self):
items = [1]
result = select(foo.id for foo in Foo if foo.id in items)[:]
self.assertEqual(result, [1])

0 comments on commit efc1368

Please sign in to comment.