forked from keon/algorithms
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add iterative segment tree (keon#587)
* Added an iterative version of segment tree * Added test for the iterative segment tree * Added readme index to iterative segment tree * Add an additional example and moves examples to the top
- Loading branch information
1 parent
66d7f97
commit 665b169
Showing
3 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
SegmentTree creates a segment tree with a given array and a "commutative" function, | ||
this non-recursive version uses less memory than the recursive version and include: | ||
1. range queries in log(N) time | ||
2. update an element in log(N) time | ||
the function should be commutative and takes 2 values and returns the same type value | ||
Examples - | ||
mytree = SegmentTree([2, 4, 5, 3, 4],max) | ||
print(mytree.query(2, 4)) | ||
mytree.update(3, 6) | ||
print(mytree.query(0, 3)) ... | ||
mytree = SegmentTree([4, 5, 2, 3, 4, 43, 3], lambda a, b: a + b) | ||
print(mytree.query(0, 6)) | ||
mytree.update(2, -10) | ||
print(mytree.query(0, 6)) ... | ||
mytree = SegmentTree([(1, 2), (4, 6), (4, 5)], lambda a, b: (a[0] + b[0], a[1] + b[1])) | ||
print(mytree.query(0, 2)) | ||
mytree.update(2, (-1, 2)) | ||
print(mytree.query(0, 2)) ... | ||
""" | ||
|
||
|
||
class SegmentTree: | ||
def __init__(self, arr, function): | ||
self.tree = [None for _ in range(len(arr))] + arr | ||
self.size = len(arr) | ||
self.fn = function | ||
self.build_tree() | ||
|
||
def build_tree(self): | ||
for i in range(self.size - 1, 0, -1): | ||
self.tree[i] = self.fn(self.tree[i * 2], self.tree[i * 2 + 1]) | ||
|
||
def update(self, p, v): | ||
p += self.size | ||
self.tree[p] = v | ||
while p > 1: | ||
p = p // 2 | ||
self.tree[p] = self.fn(self.tree[p * 2], self.tree[p * 2 + 1]) | ||
|
||
def query(self, l, r): | ||
l, r = l + self.size, r + self.size | ||
res = None | ||
while l <= r: | ||
if l % 2 == 1: | ||
res = self.tree[l] if res is None else self.fn(res, self.tree[l]) | ||
if r % 2 == 0: | ||
res = self.tree[r] if res is None else self.fn(res, self.tree[r]) | ||
l, r = (l + 1) // 2, (r - 1) // 2 | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from algorithms.tree.segment_tree.iterative_segment_tree import SegmentTree | ||
from functools import reduce | ||
|
||
import unittest | ||
|
||
|
||
def gcd(a, b): | ||
if b == 0: | ||
return a | ||
return gcd(b, a % b) | ||
|
||
|
||
class TestSegmentTree(unittest.TestCase): | ||
""" | ||
Test for the Iterative Segment Tree data structure | ||
""" | ||
|
||
def test_segment_tree_creation(self): | ||
arr = [2, 4, 3, 6, 8, 9, 3] | ||
max_segment_tree = SegmentTree(arr, max) | ||
min_segment_tree = SegmentTree(arr, min) | ||
sum_segment_tree = SegmentTree(arr, lambda a, b: a + b) | ||
gcd_segment_tree = SegmentTree(arr, gcd) | ||
self.assertEqual(max_segment_tree.tree, [None, 9, 8, 9, 4, 8, 9, 2, 4, 3, 6, 8, 9, 3]) | ||
self.assertEqual(min_segment_tree.tree, [None, 2, 3, 2, 3, 6, 3, 2, 4, 3, 6, 8, 9, 3]) | ||
self.assertEqual(sum_segment_tree.tree, [None, 35, 21, 14, 7, 14, 12, 2, 4, 3, 6, 8, 9, 3]) | ||
self.assertEqual(gcd_segment_tree.tree, [None, 1, 1, 1, 1, 2, 3, 2, 4, 3, 6, 8, 9, 3]) | ||
|
||
def test_max_segment_tree(self): | ||
arr = [-1, 1, 10, 2, 9, -3, 8, 4, 7, 5, 6, 0] | ||
self.__test_all_segments(arr, max) | ||
|
||
def test_min_segment_tree(self): | ||
arr = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12] | ||
self.__test_all_segments(arr, min) | ||
|
||
def test_sum_segment_tree(self): | ||
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, -11, -12] | ||
self.__test_all_segments(arr, lambda a, b: a + b) | ||
|
||
def test_gcd_segment_tree(self): | ||
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, 11, 12, 14] | ||
self.__test_all_segments(arr, gcd) | ||
|
||
def test_max_segment_tree_with_updates(self): | ||
arr = [-1, 1, 10, 2, 9, -3, 8, 4, 7, 5, 6, 0] | ||
updates = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12} | ||
self.__test_all_segments_with_updates(arr, max, updates) | ||
|
||
def test_min_segment_tree_with_updates(self): | ||
arr = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12] | ||
updates = {0: 7, 1: 2, 2: 6, 3: -14, 4: 5, 5: 4, 6: 7, 7: -10, 8: 9, 9: 10, 10: 12, 11: 1} | ||
self.__test_all_segments_with_updates(arr, min, updates) | ||
|
||
def test_sum_segment_tree_with_updates(self): | ||
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, -11, -12] | ||
updates = {0: 12, 1: 11, 2: 10, 3: 9, 4: 8, 5: 7, 6: 6, 7: 5, 8: 4, 9: 3, 10: 2, 11: 1} | ||
self.__test_all_segments_with_updates(arr, lambda a, b: a + b, updates) | ||
|
||
def test_gcd_segment_tree_with_updates(self): | ||
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, 11, 12, 14] | ||
updates = {0: 4, 1: 2, 2: 3, 3: 9, 4: 21, 5: 7, 6: 4, 7: 4, 8: 2, 9: 5, 10: 17, 11: 12, 12: 3} | ||
self.__test_all_segments_with_updates(arr, gcd, updates) | ||
|
||
def __test_all_segments(self, arr, fnc): | ||
""" | ||
Test all possible segments in the tree | ||
:param arr: array to test | ||
:param fnc: function of the segment tree | ||
""" | ||
segment_tree = SegmentTree(arr, fnc) | ||
self.__test_segments_helper(segment_tree, fnc, arr) | ||
|
||
def __test_all_segments_with_updates(self, arr, fnc, upd): | ||
""" | ||
Test all possible segments in the tree with updates | ||
:param arr: array to test | ||
:param fnc: function of the segment tree | ||
:param upd: updates to test | ||
""" | ||
segment_tree = SegmentTree(arr, fnc) | ||
for index, value in upd.items(): | ||
arr[index] = value | ||
segment_tree.update(index, value) | ||
self.__test_segments_helper(segment_tree, fnc, arr) | ||
|
||
def __test_segments_helper(self, seg_tree, fnc, arr): | ||
for i in range(0, len(arr)): | ||
for j in range(i, len(arr)): | ||
range_value = reduce(fnc, arr[i:j + 1]) | ||
self.assertEqual(seg_tree.query(i, j), range_value) |