forked from AmbitionEng/django-pgtrigger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcontrib.py
220 lines (162 loc) · 7.4 KB
/
contrib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Additional goodies"""
import functools
import operator
from pgtrigger import core
from pgtrigger import utils
# A sentinel value to determine if a kwarg is unset
_unset = object()
class Protect(core.Trigger):
    """Trigger that always raises an exception when it fires.

    The exception message names the trigger's operation (lower-cased) and
    the table the trigger runs on, supplied at runtime via ``TG_TABLE_NAME``.
    """

    when = core.Before

    def get_func(self, model):
        verb = str(self.operation).lower()
        statement_lines = [
            "",
            "RAISE EXCEPTION",
            f"'pgtrigger: Cannot {verb} rows from % table',",
            "TG_TABLE_NAME;",
            "",
        ]
        return self.format_sql("\n".join(statement_lines))
class ReadOnly(Protect):
    """A trigger that prevents edits to fields.

    If ``fields`` are provided, will protect edits to only those fields.
    If ``exclude`` is provided, will protect all fields except the ones
    excluded.
    If none of these arguments are provided, all fields cannot be edited.

    Entries of ``fields`` and ``exclude`` are resolved through
    ``model._meta.get_field``, so unknown names raise when the condition
    is rendered.
    """

    fields = None
    exclude = None
    operation = core.Update

    def __init__(self, *, fields=None, exclude=None, **kwargs):
        self.fields = fields or self.fields
        self.exclude = exclude or self.exclude

        if self.fields and self.exclude:
            raise ValueError('Must provide only one of "fields" or "exclude" to ReadOnly trigger')

        super().__init__(**kwargs)

    def get_condition(self, model):
        """Return the condition under which an update is rejected.

        Without ``fields``/``exclude`` the whole OLD and NEW rows are compared.
        Otherwise the condition is an OR of per-field ``__df`` comparisons
        between the OLD and NEW values of every protected field.

        Raises:
            ValueError: if ``exclude`` removes every field of the model,
                leaving nothing to protect.
        """
        if not self.fields and not self.exclude:
            return core.Condition("OLD.* IS DISTINCT FROM NEW.*")

        if self.exclude:
            # get_field() validates each entry and also accepts aliases such as
            # a foreign key's attname ("author_id"); resolve them to canonical
            # names so the exclusion actually matches f.name below.
            excluded = {model._meta.get_field(field).name for field in self.exclude}
            fields = [f.name for f in model._meta.fields if f.name not in excluded]
        else:
            fields = [model._meta.get_field(field).name for field in self.fields]

        if not fields:
            # functools.reduce() over an empty list would fail with an opaque TypeError
            raise ValueError(
                f'ReadOnly trigger on "{model.__name__}" excludes every field; nothing to protect'
            )

        return functools.reduce(
            operator.or_,
            [core.Q(**{f"old__{field}__df": core.F(f"new__{field}")}) for field in fields],
        )
class FSM(core.Trigger):
    """Enforces a finite state machine on a field.

    Supply the trigger with the "field" that transitions and then
    a list of tuples of valid transitions to the "transitions" argument.

    Each transition is encoded as ``"<old><separator><new>"`` inside a
    Postgres text-array literal. The separator is ``":"`` by default and can
    be changed with the ``separator`` argument.

    .. note::

        Only non-null ``CharField`` fields are currently supported.
        Transition values may not contain quotes, backslashes, commas, braces
        or the separator character, because they are embedded verbatim in the
        generated SQL. If your values contain a colon, pass a different
        single-character ``separator``.
    """

    when = core.Before
    operation = core.Update
    field = None
    transitions = None
    separator = ":"

    # Characters that would break the SQL string literal or the Postgres
    # array literal the transitions are embedded in.
    _reserved_chars = "'\"\\,{}"

    def __init__(
        self, *, name=None, condition=None, field=None, transitions=None, separator=None
    ):
        self.field = field or self.field
        self.transitions = transitions or self.transitions
        self.separator = separator or self.separator

        if not self.field:  # pragma: no cover
            raise ValueError('Must provide "field" for FSM')

        if not self.transitions:  # pragma: no cover
            raise ValueError('Must provide "transitions" for FSM')

        if not isinstance(self.separator, str) or len(self.separator) != 1:
            raise ValueError(f'FSM separator "{self.separator}" must be a single character')

        if self.separator in self._reserved_chars:
            raise ValueError(f'FSM separator "{self.separator}" is a reserved character')

        for transition in self.transitions:
            for value in transition:
                text = str(value)
                bad = "".join(ch for ch in self._reserved_chars if ch in text)
                if bad:
                    raise ValueError(
                        f"FSM transition value {text!r} contains reserved character(s) {bad!r}"
                    )
                if self.separator in text:
                    # "a:b" -> "c" would be indistinguishable from "a" -> "b:c"
                    raise ValueError(
                        f'FSM transition value "{text}" contains separator "{self.separator}".'
                        ' Configure the trigger with a different "separator".'
                    )

        super().__init__(name=name, condition=condition)

    def get_declare(self, model):
        return [("_is_valid_transition", "BOOLEAN")]

    def get_func(self, model):
        col = utils.quote(model._meta.get_field(self.field).column)
        sep = self.separator
        transition_uris = "{" + ",".join(f"{old}{sep}{new}" for old, new in self.transitions) + "}"

        # Raise only when the value actually changed and the (old, new) pair
        # is not one of the allowed transitions; otherwise let the row through.
        sql = f"""
            SELECT CONCAT(OLD.{col}, '{sep}', NEW.{col}) = ANY('{transition_uris}'::text[])
                INTO _is_valid_transition;

            IF (_is_valid_transition IS FALSE AND OLD.{col} IS DISTINCT FROM NEW.{col}) THEN
                RAISE EXCEPTION
                    'pgtrigger: Invalid transition of field "{self.field}" from "%" to "%" on table %',
                    OLD.{col},
                    NEW.{col},
                    TG_TABLE_NAME;
            ELSE
                RETURN NEW;
            END IF;
        """  # noqa
        return self.format_sql(sql)
class SoftDelete(core.Trigger):
    """Sets a field to a value when a delete happens.

    Supply the trigger with the "field" that will be set
    upon deletion and the "value" to which it should be set.
    The "value" defaults to ``False``.

    Instead of deleting the row, the trigger issues an ``UPDATE`` of that
    row (matched by primary key) and returns ``NULL`` to skip the delete.

    .. note::

        This trigger currently only supports nullable ``BooleanField``,
        ``CharField``, and ``IntegerField`` fields.
    """

    when = core.Before
    operation = core.Delete
    field = None
    value = False

    def __init__(self, *, name=None, condition=None, field=None, value=_unset):
        self.field = field or self.field
        # _unset distinguishes "not passed" from an explicit None/False value
        self.value = value if value is not _unset else self.value

        if not self.field:  # pragma: no cover
            raise ValueError('Must provide "field" for soft delete')

        super().__init__(name=name, condition=condition)

    def get_func(self, model):
        soft_col = model._meta.get_field(self.field).column
        pk_col = model._meta.pk.column

        def _render_value():
            """Render ``self.value`` as a SQL literal."""
            if self.value is None:
                return "NULL"
            elif isinstance(self.value, str):
                # Double embedded single quotes so the value stays one SQL literal
                escaped = self.value.replace("'", "''")
                return f"'{escaped}'"
            else:
                return str(self.value)

        # Quote the soft-delete column like the table and pk column so that
        # reserved or mixed-case column names work.
        sql = f"""
            UPDATE {utils.quote(model._meta.db_table)}
            SET {utils.quote(soft_col)} = {_render_value()}
            WHERE {utils.quote(pk_col)} = OLD.{utils.quote(pk_col)};
            RETURN NULL;
        """
        return self.format_sql(sql)
class UpdateSearchVector(core.Trigger):
    """Updates a ``django.contrib.postgres.search.SearchVectorField`` from document fields.

    Supply the trigger with the ``vector_field`` that will be updated with
    changes to the ``document_fields``. Optionally provide a ``config_name``, which
    defaults to ``pg_catalog.english``. Any of these may instead be declared as
    class attributes on a subclass.

    The trigger fires on inserts and on updates of the document fields.

    This trigger uses ``tsvector_update_trigger`` to update the vector field.
    See `the Postgres docs <https://www.postgresql.org/docs/current/textsearch-features.html#TEXTSEARCH-UPDATE-TRIGGERS>`__
    for more information.

    .. note::

        ``UpdateSearchVector`` triggers are not compatible with `pgtrigger.ignore` since
        it references a built-in trigger. Trying to ignore this trigger results in a
        `RuntimeError`.
    """  # noqa

    when = core.Before
    vector_field = None
    document_fields = None
    config_name = "pg_catalog.english"

    def __init__(self, *, name=None, vector_field=None, document_fields=None, config_name=None):
        self.vector_field = vector_field or self.vector_field
        self.document_fields = document_fields or self.document_fields
        self.config_name = config_name or self.config_name

        if not self.vector_field:
            raise ValueError('Must provide "vector_field" to update search vector')

        if not self.document_fields:
            raise ValueError('Must provide "document_fields" to update search vector')

        if not self.config_name:  # pragma: no cover
            raise ValueError('Must provide "config_name" to update search vector')

        # Use the resolved attribute rather than the raw argument so subclasses
        # that declare ``document_fields`` as a class attribute work too.
        super().__init__(
            name=name, operation=core.Insert | core.UpdateOf(*self.document_fields)
        )

    def ignore(self, model):
        raise RuntimeError(f"Cannot ignore {self.__class__.__name__} triggers")

    def get_func(self, model):
        # No custom body: execution is delegated to the built-in function in render_execute
        return ""

    def render_execute(self, model):
        document_cols = [model._meta.get_field(field).column for field in self.document_fields]
        rendered_document_cols = ", ".join(utils.quote(col) for col in document_cols)
        vector_col = model._meta.get_field(self.vector_field).column
        return (
            f"tsvector_update_trigger({utils.quote(vector_col)},"
            f" {utils.quote(self.config_name)}, {rendered_document_cols})"
        )