forked from TheAlgorithms/C-Plus-Plus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
persistent_seg_tree_lazy_prop.cpp
321 lines (307 loc) · 12.7 KB
/
persistent_seg_tree_lazy_prop.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
/**
* @file
* @brief [Persistent segment tree with range updates (lazy
* propagation)](https://en.wikipedia.org/wiki/Persistent_data_structure)
*
* @details
* A normal segment tree facilitates making point updates and range queries in
* logarithmic time. Lazy propagation preserves the logarithmic time with range
* updates. So, a segment tree with lazy propagation enables doing range updates
* and range queries in logarithmic time, but it doesn't save any information
* about itself before the last update. A persistent data structure always
* preserves the previous version of itself when it is modified. That is, a new
* version of the segment tree is generated after every update. It saves all
* previous versions of itself (before every update) to facilitate doing range
* queries in any version. More memory is used, but the logarithmic time is
* preserved because the new version shares with the previous version every
* node that is not affected by the update. That is, only the nodes that are
* affected by the update and their ancestors are copied; the remaining nodes
* are copied lazily, when pending values are propagated during later queries
* and updates. This preserves the logarithmic time, because the number of
* nodes copied by any single update is logarithmic.
*
* @author [Magdy Sedra](https://github.com/MSedra)
*/
#include <cassert>    /// for assert
#include <cstdint>    /// for int64_t and uint32_t
#include <iostream>   /// for IO operations
#include <memory>     /// to manage dynamic memory
#include <stdexcept>  /// for std::out_of_range
#include <vector>     /// for std::vector
/**
 * @namespace range_queries
 * @brief Range queries algorithms
 */
namespace range_queries {
/**
 * @brief Persistent segment tree with range-add updates and range-sum queries
 * on any saved version. Range query here is range sum, but the code can be
 * modified to make different queries like range max or min.
 *
 * @details Version 0 is the tree built by construct(); every call to update()
 * appends exactly one new version. Nodes that an update does not touch are
 * shared between versions. A node's pending addition (`prop`) is pushed down
 * lazily, and the children are always copied before they receive it, so a
 * node shared with another version never changes the sums it represents.
 *
 * @note query() also pushes pending values down: it rewrites (semantically
 * unchanged) shared nodes and allocates copies of their children, without any
 * synchronization. Concurrent use of one instance, even for queries only,
 * therefore needs external locking.
 */
class perSegTree {
 private:
    /// One node of the tree; leaves have no children.
    class Node {
     public:
        std::shared_ptr<Node> left = nullptr;   ///< pointer to the left node
        std::shared_ptr<Node> right = nullptr;  ///< pointer to the right node
        int64_t val = 0;   ///< sum of the node's range, not yet including
                           ///< this node's own pending `prop`
        int64_t prop = 0;  ///< value still to be added to every element of
                           ///< the node's range (pending propagation)
    };

    uint32_t n = 0;  ///< number of elements/leaf nodes in the segment tree
    /// ptrs[i] holds the root after the ith update; ptrs[0] holds the root of
    /// the tree before any updates. Empty until a non-empty construct().
    std::vector<std::shared_ptr<Node>> ptrs{};

    /**
     * @brief Creating a new node with the same children and values as curr
     * @param curr node that would be copied
     * @returns the new node
     */
    static std::shared_ptr<Node> newKid(std::shared_ptr<Node> const &curr) {
        return std::make_shared<Node>(*curr);
    }

    /**
     * @brief If curr has a pending value, adds it to curr's sum and, for an
     * internal node, replaces both children with copies that carry the
     * pending value. The original children, which may be shared with other
     * versions, are left untouched.
     * @param i the left index of the range that the passed node holds its sum
     * @param j the right index of the range that the passed node holds its sum
     * @param curr pointer to the node to be propagated
     * @returns void
     */
    static void lazy(const uint32_t &i, const uint32_t &j,
                     std::shared_ptr<Node> const &curr) {
        if (!curr->prop) {
            return;
        }
        curr->val += static_cast<int64_t>(j - i + 1) * curr->prop;
        if (i != j) {
            // Copy first: the current children may belong to other versions.
            curr->left = newKid(curr->left);
            curr->right = newKid(curr->right);
            curr->left->prop += curr->prop;
            curr->right->prop += curr->prop;
        }
        curr->prop = 0;
    }

    /**
     * @brief Recursively builds the subtree holding values[i..j]. Every call
     * creates a node whose value is the sum of the given range.
     * @param values the elements the tree is built from (0 indexed)
     * @param i the left index of the range that the created node holds its sum
     * @param j the right index of the range that the created node holds its sum
     * @returns pointer to the newly created node
     */
    static std::shared_ptr<Node> build(const std::vector<int64_t> &values,
                                       const uint32_t &i, const uint32_t &j) {
        auto newNode = std::make_shared<Node>();
        if (i == j) {
            newNode->val = values[i];
            return newNode;
        }
        const uint32_t mid = i + (j - i) / 2;
        newNode->left = build(values, i, mid);
        newNode->right = build(values, mid + 1, j);
        newNode->val = newNode->left->val + newNode->right->val;
        return newNode;
    }

    /**
     * @brief Doing range update, checking at every node if it has some value
     * to be propagated. All nodes affected by the update are copied, so the
     * version rooted at curr keeps its values.
     * @param i the left index of the range that the passed node holds its sum
     * @param j the right index of the range that the passed node holds its sum
     * @param l the left index of the range to be updated
     * @param r the right index of the range to be updated
     * @param value the value to be added to every element whose index x
     * satisfies l<=x<=r
     * @param curr pointer to the current node, which has value = the sum of
     * elements whose index x satisfies i<=x<=j
     * @returns pointer to the node representing [i, j] in the new version
     * (curr itself when [i, j] does not intersect [l, r])
     */
    std::shared_ptr<Node> update(const uint32_t &i, const uint32_t &j,
                                 const uint32_t &l, const uint32_t &r,
                                 const int64_t &value,
                                 std::shared_ptr<Node> const &curr) {
        lazy(i, j, curr);
        if (i >= l && j <= r) {
            // Fully covered: a copy with the addition applied to its sum;
            // the children receive it lazily later.
            std::shared_ptr<Node> newNode = newKid(curr);
            newNode->prop += value;
            lazy(i, j, newNode);
            return newNode;
        }
        if (i > r || j < l) {
            return curr;  // untouched subtree is shared with the new version
        }
        auto newNode = std::make_shared<Node>();
        const uint32_t mid = i + (j - i) / 2;
        newNode->left = update(i, mid, l, r, value, curr->left);
        newNode->right = update(mid + 1, j, l, r, value, curr->right);
        newNode->val = newNode->left->val + newNode->right->val;
        return newNode;
    }

    /**
     * @brief Querying the range from index l to index r, checking at every
     * node if it has some value to be propagated. A node's value is returned
     * if its range is completely inside the wanted range, 0 if it is disjoint
     * from it; otherwise both children are queried.
     * @param i the left index of the range that the passed node holds its sum
     * @param j the right index of the range that the passed node holds its sum
     * @param l the left index of the range whose sum should be returned
     * @param r the right index of the range whose sum should be returned
     * @param curr pointer to the current node, which has value = the sum of
     * elements whose index x satisfies i<=x<=j
     * @returns sum of elements whose index x satisfies l<=x<=r and i<=x<=j
     */
    int64_t query(const uint32_t &i, const uint32_t &j, const uint32_t &l,
                  const uint32_t &r, std::shared_ptr<Node> const &curr) {
        lazy(i, j, curr);
        if (j < l || r < i) {
            return 0;
        }
        if (i >= l && j <= r) {
            return curr->val;
        }
        const uint32_t mid = i + (j - i) / 2;
        return query(i, mid, l, r, curr->left) +
               query(mid + 1, j, l, r, curr->right);
    }

 public:
    /**
     * @brief Constructing the segment tree with the values in the passed
     * vector. Any previously built tree and all of its versions are
     * discarded; the new root becomes version 0. An empty vector leaves the
     * tree empty (size() == 0).
     * @param vec vector whose values will be used to build the segment tree,
     * "vec" is 0 indexed
     * @returns void
     */
    void construct(const std::vector<int64_t> &vec) {
        // Versions of an earlier tree would be queried with the new `n`,
        // which does not match their shape, so start over.
        ptrs.clear();
        n = 0;
        if (vec.empty()) {
            return;
        }
        n = static_cast<uint32_t>(vec.size());
        ptrs.push_back(build(vec, 0, n - 1));
    }

    /**
     * @brief Doing range update on the latest version, saving the result as
     * a new version. All elements from index "l" to index "r" (0 indexed,
     * inclusive) are increased by "value"; indexes beyond the last element
     * are ignored.
     * @param l the left index of the range to be updated
     * @param r the right index of the range to be updated
     * @param value the value to be added to every element whose index x
     * satisfies l<=x<=r
     * @returns void
     * @throws std::out_of_range if the tree has not been constructed (or was
     * constructed from an empty vector)
     */
    void update(const uint32_t &l, const uint32_t &r, const int64_t &value) {
        if (ptrs.empty()) {
            throw std::out_of_range(
                "perSegTree::update: the segment tree is empty");
        }
        // saving the root pointer to the new segment tree
        ptrs.push_back(update(0, n - 1, l, r, value, ptrs.back()));
    }

    /**
     * @brief Querying the range from index l to index r (0 indexed,
     * inclusive) in the segment tree after "version" updates, getting the sum
     * of the elements whose index x satisfies l<=x<=r
     * @param l the left index of the range whose sum should be returned
     * @param r the right index of the range whose sum should be returned
     * @param version the version to query on. If equals to 0, the original
     * segment tree will be queried
     * @returns sum of elements whose index x satisfies l<=x<=r
     * @throws std::out_of_range if version >= size()
     */
    int64_t query(const uint32_t &l, const uint32_t &r,
                  const uint32_t &version) {
        return query(0, n - 1, l, r, ptrs.at(version));
    }

    /**
     * @brief Getting the number of versions (segment trees) so far. The
     * number of updates done so far is the returned value - 1, because one of
     * the trees is the original segment tree (0 if nothing was constructed).
     * @returns the number of versions
     */
    uint32_t size() const { return static_cast<uint32_t>(ptrs.size()); }
};
}  // namespace range_queries
/**
* @brief Test implementations
* @returns void
*/
static void test() {
std::vector<int64_t> arr = {-5, 2, 3, 11, -2, 7, 0, 1};
range_queries::perSegTree tree;
std::cout << "Elements before any updates are {";
for (uint32_t i = 0; i < arr.size(); ++i) {
std::cout << arr[i];
if (i != arr.size() - 1) {
std::cout << ",";
}
}
std::cout << "}\n";
tree.construct(
arr); // constructing the original segment tree (version = 0)
std::cout << "Querying range sum on version 0 from index 2 to 4 = 3+11-2 = "
<< tree.query(2, 4, 0) << '\n';
std::cout
<< "Subtract 7 from all elements from index 1 to index 5 inclusive\n";
tree.update(1, 5, -7); // subtracting 7 from index 1 to index 5
std::cout << "Elements of the segment tree whose version = 1 (after 1 "
"update) are {";
for (uint32_t i = 0; i < arr.size(); ++i) {
std::cout << tree.query(i, i, 1);
if (i != arr.size() - 1) {
std::cout << ",";
}
}
std::cout << "}\n";
std::cout << "Add 10 to all elements from index 0 to index 7 inclusive\n";
tree.update(0, 7, 10); // adding 10 to all elements
std::cout << "Elements of the segment tree whose version = 2 (after 2 "
"updates) are {";
for (uint32_t i = 0; i < arr.size(); ++i) {
std::cout << tree.query(i, i, 2);
if (i != arr.size() - 1) {
std::cout << ",";
}
}
std::cout << "}\n";
std::cout << "Number of segment trees (versions) now = " << tree.size()
<< '\n';
std::cout << "Querying range sum on version 0 from index 3 to 5 = 11-2+7 = "
<< tree.query(3, 5, 0) << '\n';
std::cout << "Querying range sum on version 1 from index 3 to 5 = 4-9+0 = "
<< tree.query(3, 5, 1) << '\n';
}
/**
 * @brief Program entry point; runs the self-test, which writes its results
 * to standard output
 * @returns 0 (falling off the end of main returns 0)
 */
int main() {
    test();
}