forked from celery/django-celery
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanagers.py
209 lines (150 loc) · 6.55 KB
/
managers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import warnings
from itertools import count
from datetime import datetime
from celery.utils.functional import wraps
from django.db import transaction, connection
try:
from django.db import connections, router
except ImportError: # pre-Django 1.2
connections = router = None
from django.db import models
from django.db.models.query import QuerySet
from django.conf import settings
class TxIsolationWarning(UserWarning):
    """Warning issued when the database transaction isolation level
    is suspected to interfere with polling task results."""
def transaction_retry(max_retries=1):
    """Decorator retrying a database operation when it raises.

    The wrapped callable is re-invoked up to ``max_retries`` extra
    times after an exception, rolling the transaction back between
    attempts.  Callers may override the limit per call with the
    ``exception_retry_count`` keyword argument (it is popped before
    the call and never reaches the wrapped function).
    """

    def _outer(fun):

        @wraps(fun)
        def _inner(*args, **kwargs):
            retry_limit = kwargs.pop("exception_retry_count", max_retries)
            attempt = 0
            while True:
                try:
                    return fun(*args, **kwargs)
                except Exception:  # pragma: no cover
                    # Some backends (e.g. psycopg2) refuse further
                    # operations once the transaction is broken, so
                    # roll back before retrying.
                    if attempt >= retry_limit:
                        raise
                    transaction.rollback_unless_managed()
                attempt += 1

        return _inner
    return _outer
def update_model_with_dict(obj, fields):
    """Update a model instance from a mapping of attributes and save it.

    :param obj: the model instance to update.
    :param fields: mapping of attribute name to new value.

    :returns: the saved instance.

    """
    # Use a plain loop (not a comprehension) -- we only want the
    # setattr side effect, not a throwaway list of Nones.
    for attr_name, attr_value in fields.items():
        setattr(obj, attr_name, attr_value)
    obj.save()
    return obj
class ExtendedQuerySet(QuerySet):
    """QuerySet extended with an ``update_or_create`` convenience method."""

    def update_or_create(self, **kwargs):
        """Get or insert a row, updating it when it already exists.

        A fresh row is created by ``get_or_create`` using ``kwargs``
        (including any ``defaults``); for an existing row the values in
        ``defaults`` plus the lookup arguments themselves are written
        to the instance and saved.
        """
        obj, created = self.get_or_create(**kwargs)
        if created:
            return obj
        updates = dict(kwargs.pop("defaults", {}))
        updates.update(kwargs)
        return update_model_with_dict(obj, updates)
class ExtendedManager(models.Manager):
    """Manager whose default queryset is :class:`ExtendedQuerySet`."""

    def get_query_set(self):
        return ExtendedQuerySet(self.model)

    def update_or_create(self, **kwargs):
        """Proxy for :meth:`ExtendedQuerySet.update_or_create`."""
        return self.get_query_set().update_or_create(**kwargs)

    def connection_for_write(self):
        """Return the connection used for writes.

        Router-aware on Django 1.2+, otherwise the single default
        connection.
        """
        if connections is None:
            return connection
        return connections[router.db_for_write(self.model)]

    def connection_for_read(self):
        """Return the connection used for reads."""
        if connections is None:
            return connection
        return connections[self.db]
class ResultManager(ExtendedManager):
    """Base manager for the task and taskset result models."""

    def get_all_expired(self, expires):
        """Return a queryset of every result older than ``expires``."""
        cutoff = datetime.now() - expires
        return self.filter(date_done__lt=cutoff)

    def delete_expired(self, expires):
        """Delete every result older than ``expires``."""
        self.get_all_expired(expires).delete()
class PeriodicTaskManager(ExtendedManager):
    """Manager for periodic task definitions."""

    def enabled(self):
        """Return only the periodic tasks that are currently enabled."""
        return self.filter(enabled=True)
class TaskManager(ResultManager):
    """Manager for :class:`celery.models.Task` models."""

    # Task id of the previous ``get_task`` miss; a second consecutive
    # miss for the same id triggers the repeatable-read warning.
    _last_id = None

    def get_task(self, task_id):
        """Return the task meta for ``task_id``.

        A missing row yields a fresh, unsaved model instance instead
        of raising.

        :keyword exception_retry_count: How many times to retry by
            transaction rollback on exception.  This could
            theoretically happen in a race condition if another worker
            is trying to create the same task.  The default is to
            retry once.

        """
        try:
            return self.get(task_id=task_id)
        except self.model.DoesNotExist:
            if task_id == self._last_id:
                self.warn_if_repeatable_read()
            self._last_id = task_id
            return self.model(task_id=task_id)

    @transaction_retry(max_retries=2)
    def store_result(self, task_id, result, status, traceback=None):
        """Store the result and status of a task.

        :param task_id: task id
        :param result: The return value of the task, or an exception
            instance raised by the task.
        :param status: Task status.  See
            :meth:`celery.result.AsyncResult.get_status` for a list of
            possible status values.
        :keyword traceback: The traceback at the point of exception
            (if the task failed).
        :keyword exception_retry_count: How many times to retry by
            transaction rollback on exception.  This could
            theoretically happen in a race condition if another worker
            is trying to create the same task.  The default is to
            retry twice.

        """
        defaults = {"status": status,
                    "result": result,
                    "traceback": traceback}
        return self.update_or_create(task_id=task_id, defaults=defaults)

    def warn_if_repeatable_read(self):
        """Emit :class:`TxIsolationWarning` when MySQL runs with the
        ``REPEATABLE-READ`` isolation level, which can make result
        polling return stale rows."""
        if settings.DATABASE_ENGINE.lower() != "mysql":
            return
        cursor = self.connection_for_read().cursor()
        if not cursor.execute("SELECT @@tx_isolation"):
            return
        isolation = cursor.fetchone()[0]
        if isolation == 'REPEATABLE-READ':
            warnings.warn(TxIsolationWarning(
                "Polling results with transaction isolation level "
                "repeatable-read within the same transaction "
                "may give outdated results. Be sure to commit the "
                "transaction for each poll iteration."))
class TaskSetManager(ResultManager):
    """Manager for :class:`celery.models.TaskSet` models."""

    def restore_taskset(self, taskset_id):
        """Return the taskset meta for ``taskset_id``, or ``None``
        when no such taskset exists."""
        try:
            return self.get(taskset_id=taskset_id)
        except self.model.DoesNotExist:
            return None

    @transaction_retry(max_retries=2)
    def store_result(self, taskset_id, result):
        """Store the result of a taskset.

        :param taskset_id: task set id
        :param result: The return value of the taskset

        """
        return self.update_or_create(taskset_id=taskset_id,
                                     defaults={"result": result})
class TaskStateManager(ExtendedManager):
    """Manager for monitor task-state snapshots."""

    def active(self):
        """Return the states that have not been hidden."""
        return self.filter(hidden=False)

    def expired(self, states, expires):
        """Return rows in any of ``states`` older than ``expires``."""
        cutoff = datetime.now() - expires
        return self.filter(state__in=states, tstamp__lte=cutoff)

    def expire_by_states(self, states, expires):
        """Mark every expired row in ``states`` as hidden."""
        return self.expired(states, expires).update(hidden=True)

    def purge(self):
        """Physically delete all hidden rows using raw SQL."""
        cursor = self.connection_for_write().cursor()
        # Table name comes from trusted model metadata; the value is
        # bound as a query parameter.
        sql = "DELETE FROM %s WHERE hidden=%%s" % (
            self.model._meta.db_table, )
        cursor.execute(sql, (True, ))
        transaction.commit_unless_managed()