-
Notifications
You must be signed in to change notification settings - Fork 151
/
Copy pathtest_bakery.py
162 lines (132 loc) · 4.99 KB
/
test_bakery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import pytest
import sqlalchemy
from gino import UninitializedError, create_engine, InitializedError
from gino.bakery import Bakery, BakedQuery
from .models import db, User, MYSQL_URL
# Module-level mark: apply ``pytest.mark.asyncio`` to every test in this
# module so the ``async def`` tests below are run as coroutines.
pytestmark = pytest.mark.asyncio
@pytest.mark.parametrize(
    "query",
    [
        User.query.where(User.id == db.bindparam("uid")),
        sqlalchemy.text("SELECT * FROM gino_users WHERE id = :uid"),
        "SELECT * FROM gino_users WHERE id = :uid",
        lambda: User.query.where(User.id == db.bindparam("uid")),
        lambda: sqlalchemy.text("SELECT * FROM gino_users WHERE id = :uid"),
        lambda: "SELECT * FROM gino_users WHERE id = :uid",
    ],
)
@pytest.mark.parametrize("options", [dict(return_model=False), dict(loader=User)])
@pytest.mark.parametrize("api", [True, False])
@pytest.mark.parametrize("timeout", [None, 1])
async def test(query, options, sa_engine, api, timeout):
    """Exercise baked queries end to end.

    With ``api=True`` queries are baked through ``db.bake`` and the global
    ``db`` bind. With ``api=False`` they are baked on a standalone ``Bakery``
    that is handed to ``create_engine``.

    Checked stages:

    * before an engine exists, queries are registered but carry no SQL, and
      executing them raises ``UninitializedError``;
    * once an engine has been created from the bakery, baking new queries or
      creating another engine from it raises ``InitializedError``;
    * the baked query returns the inserted row (as a ``User`` or a raw row,
      depending on ``options``);
    * ``execution_options`` derives a new baked query that does not affect
      the original.
    """
    uid = sa_engine.execute(User.insert()).lastrowid

    # pytest passes the same parametrize object to every generated test, so
    # work on a copy. Mutating ``options`` in place would leak ``timeout``
    # into the other parametrizations that reuse this dict.
    options = dict(options)
    if timeout is not None:
        options["timeout"] = timeout

    if api:
        b = db._bakery
        qs = [db.bake(query, **options)]
        if callable(query):
            # Also cover the decorator form: db.bake(**options)(func).
            qs.append(db.bake(**options)(query))
    else:
        b = Bakery()
        qs = [b.bake(query, **options)]
        if callable(query):
            qs.append(b.bake(**options)(query))

    # Before any engine exists the queries are registered but not compiled,
    # and running them fails.
    for q in qs:
        assert isinstance(q, BakedQuery)
        assert q in list(b)
        assert q.sql is None
        assert q.compiled_sql is None
        with pytest.raises(UninitializedError):
            q.bind.first()
        with pytest.raises(UninitializedError):
            await q.first()
        for k, v in options.items():
            assert q.query.get_execution_options()[k] == v

    if api:
        e = await db.set_bind(MYSQL_URL, minsize=1)
    else:
        e = await create_engine(MYSQL_URL, bakery=b, minsize=1)
    # Everything after the engine is created runs inside ``try``. A failing
    # assertion then still closes the engine (or pops the global bind)
    # instead of leaking it into later tests.
    try:
        # After an engine has been created from it, the bakery refuses new
        # queries and cannot be used to create another engine.
        with pytest.raises(InitializedError):
            b.bake("SELECT now()")
        with pytest.raises(InitializedError):
            await create_engine(MYSQL_URL, bakery=b, minsize=0)

        for q in qs:
            # Engine creation compiled the baked queries.
            assert q.sql is not None
            assert q.compiled_sql is not None
            if api:
                assert q.bind is e
            else:
                # Standalone-bakery queries have no bind of their own; they
                # must be executed through the engine explicitly.
                with pytest.raises(UninitializedError):
                    q.bind.first()
                with pytest.raises(UninitializedError):
                    await q.first()
            if api:
                rv = await q.first(uid=uid)
            else:
                rv = await e.first(q, uid=uid)
            if options.get("return_model", True):
                assert isinstance(rv, User)
                assert rv.id == uid
            else:
                assert rv[0] == rv[User.id] == rv["id"] == uid

            # execution_options() derives a new baked query (a BakedQuery
            # subclass). The new query is registered in the same bakery and
            # shares the SQL text, but has its own compiled statement.
            eq = q.execution_options(return_model=True, loader=User)
            assert eq is not q
            assert isinstance(eq, BakedQuery)
            assert type(eq) is not BakedQuery
            assert eq in list(b)
            assert eq.sql == q.sql
            assert eq.compiled_sql is not q.compiled_sql
            if api:
                # Check the derived query's bind (previously this re-checked
                # ``q`` by mistake).
                assert eq.bind is e
            else:
                with pytest.raises(UninitializedError):
                    eq.bind.first()
                with pytest.raises(UninitializedError):
                    await eq.first()
            assert eq.query.get_execution_options()["return_model"]
            assert eq.query.get_execution_options()["loader"] is User
            if api:
                rv = await eq.first(uid=uid)
                non = await eq.first(uid=uid + 1)
                rvl = await eq.all(uid=uid)
            else:
                rv = await e.first(eq, uid=uid)
                non = await e.first(eq, uid=uid + 1)
                rvl = await e.all(eq, uid=uid)
            assert isinstance(rv, User)
            assert rv.id == uid
            assert non is None
            assert len(rvl) == 1
            assert rvl[0].id == uid

            # The original query is not affected by the derived one.
            if api:
                rv = await q.first(uid=uid)
            else:
                rv = await e.first(q, uid=uid)
            if options.get("return_model", True):
                assert isinstance(rv, User)
                assert rv.id == uid
            else:
                assert rv[0] == rv[User.id] == rv["id"] == uid
    finally:
        if api:
            await db.pop_bind().close()
        else:
            await e.close()
async def test_class_level_bake():
    """Bake a query from a ``db.bake``-decorated method on a model class.

    The baked query is used through the class attribute
    (``BakeOnClass.getter``) to fetch an existing row and to confirm that a
    missing row yields ``None``.
    """

    class BakeOnClass(db.Model):
        __tablename__ = "bake_on_class_test"

        name = db.Column(db.String(255), primary_key=True)

        @db.bake
        def getter(cls):
            # ``name`` is a bind parameter supplied at execution time.
            return cls.query.where(cls.name == db.bindparam("name"))

    async with db.with_bind(MYSQL_URL, prebake=False):
        await db.gino.create_all()
        try:
            await BakeOnClass.create(name="exist")

            found = await BakeOnClass.getter.one(name="exist")
            assert found.name == "exist"

            missing = await BakeOnClass.getter.one_or_none(name="nonexist")
            assert missing is None
        finally:
            # Tables created above are removed even if an assertion fails.
            await db.gino.drop_all()