Skip to content

Commit

Permalink
Add iterative segment tree (keon#587)
Browse files Browse the repository at this point in the history
* Added an iterative version of segment tree

* Added test for the iterative segment tree

* Added readme index to iterative segment tree

* Add an additional example and moves examples to the top
  • Loading branch information
qf-jonathan authored and keon committed Nov 9, 2019
1 parent 66d7f97 commit 665b169
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ If you want to uninstall algorithms, it is as simple as:
- [red_black_tree](algorithms/tree/red_black_tree/red_black_tree.py)
- [segment_tree](algorithms/tree/segment_tree)
- [segment_tree](algorithms/tree/segment_tree/segment_tree.py)
- [iterative_segment_tree](algorithms/tree/segment_tree/iterative_segment_tree.py)
- [traversal](algorithms/tree/traversal)
- [inorder](algorithms/tree/traversal/inorder.py)
- [level_order](algorithms/tree/traversal/level_order.py)
Expand Down
53 changes: 53 additions & 0 deletions algorithms/tree/segment_tree/iterative_segment_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
SegmentTree creates a segment tree with a given array and a "commutative" function,
this non-recursive version uses less memory than the recursive version and include:
1. range queries in log(N) time
2. update an element in log(N) time
the function should be commutative and takes 2 values and returns the same type value
Examples -
mytree = SegmentTree([2, 4, 5, 3, 4],max)
print(mytree.query(2, 4))
mytree.update(3, 6)
print(mytree.query(0, 3)) ...
mytree = SegmentTree([4, 5, 2, 3, 4, 43, 3], lambda a, b: a + b)
print(mytree.query(0, 6))
mytree.update(2, -10)
print(mytree.query(0, 6)) ...
mytree = SegmentTree([(1, 2), (4, 6), (4, 5)], lambda a, b: (a[0] + b[0], a[1] + b[1]))
print(mytree.query(0, 2))
mytree.update(2, (-1, 2))
print(mytree.query(0, 2)) ...
"""


class SegmentTree:
def __init__(self, arr, function):
self.tree = [None for _ in range(len(arr))] + arr
self.size = len(arr)
self.fn = function
self.build_tree()

def build_tree(self):
for i in range(self.size - 1, 0, -1):
self.tree[i] = self.fn(self.tree[i * 2], self.tree[i * 2 + 1])

def update(self, p, v):
p += self.size
self.tree[p] = v
while p > 1:
p = p // 2
self.tree[p] = self.fn(self.tree[p * 2], self.tree[p * 2 + 1])

def query(self, l, r):
l, r = l + self.size, r + self.size
res = None
while l <= r:
if l % 2 == 1:
res = self.tree[l] if res is None else self.fn(res, self.tree[l])
if r % 2 == 0:
res = self.tree[r] if res is None else self.fn(res, self.tree[r])
l, r = (l + 1) // 2, (r - 1) // 2
return res
91 changes: 91 additions & 0 deletions tests/test_iterative_segment_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from algorithms.tree.segment_tree.iterative_segment_tree import SegmentTree
from functools import reduce

import unittest


def gcd(a, b):
if b == 0:
return a
return gcd(b, a % b)


class TestSegmentTree(unittest.TestCase):
"""
Test for the Iterative Segment Tree data structure
"""

def test_segment_tree_creation(self):
arr = [2, 4, 3, 6, 8, 9, 3]
max_segment_tree = SegmentTree(arr, max)
min_segment_tree = SegmentTree(arr, min)
sum_segment_tree = SegmentTree(arr, lambda a, b: a + b)
gcd_segment_tree = SegmentTree(arr, gcd)
self.assertEqual(max_segment_tree.tree, [None, 9, 8, 9, 4, 8, 9, 2, 4, 3, 6, 8, 9, 3])
self.assertEqual(min_segment_tree.tree, [None, 2, 3, 2, 3, 6, 3, 2, 4, 3, 6, 8, 9, 3])
self.assertEqual(sum_segment_tree.tree, [None, 35, 21, 14, 7, 14, 12, 2, 4, 3, 6, 8, 9, 3])
self.assertEqual(gcd_segment_tree.tree, [None, 1, 1, 1, 1, 2, 3, 2, 4, 3, 6, 8, 9, 3])

def test_max_segment_tree(self):
arr = [-1, 1, 10, 2, 9, -3, 8, 4, 7, 5, 6, 0]
self.__test_all_segments(arr, max)

def test_min_segment_tree(self):
arr = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12]
self.__test_all_segments(arr, min)

def test_sum_segment_tree(self):
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, -11, -12]
self.__test_all_segments(arr, lambda a, b: a + b)

def test_gcd_segment_tree(self):
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, 11, 12, 14]
self.__test_all_segments(arr, gcd)

def test_max_segment_tree_with_updates(self):
arr = [-1, 1, 10, 2, 9, -3, 8, 4, 7, 5, 6, 0]
updates = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12}
self.__test_all_segments_with_updates(arr, max, updates)

def test_min_segment_tree_with_updates(self):
arr = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12]
updates = {0: 7, 1: 2, 2: 6, 3: -14, 4: 5, 5: 4, 6: 7, 7: -10, 8: 9, 9: 10, 10: 12, 11: 1}
self.__test_all_segments_with_updates(arr, min, updates)

def test_sum_segment_tree_with_updates(self):
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, -11, -12]
updates = {0: 12, 1: 11, 2: 10, 3: 9, 4: 8, 5: 7, 6: 6, 7: 5, 8: 4, 9: 3, 10: 2, 11: 1}
self.__test_all_segments_with_updates(arr, lambda a, b: a + b, updates)

def test_gcd_segment_tree_with_updates(self):
arr = [1, 10, 2, 9, 3, 8, 4, 7, 5, 6, 11, 12, 14]
updates = {0: 4, 1: 2, 2: 3, 3: 9, 4: 21, 5: 7, 6: 4, 7: 4, 8: 2, 9: 5, 10: 17, 11: 12, 12: 3}
self.__test_all_segments_with_updates(arr, gcd, updates)

def __test_all_segments(self, arr, fnc):
"""
Test all possible segments in the tree
:param arr: array to test
:param fnc: function of the segment tree
"""
segment_tree = SegmentTree(arr, fnc)
self.__test_segments_helper(segment_tree, fnc, arr)

def __test_all_segments_with_updates(self, arr, fnc, upd):
"""
Test all possible segments in the tree with updates
:param arr: array to test
:param fnc: function of the segment tree
:param upd: updates to test
"""
segment_tree = SegmentTree(arr, fnc)
for index, value in upd.items():
arr[index] = value
segment_tree.update(index, value)
self.__test_segments_helper(segment_tree, fnc, arr)

def __test_segments_helper(self, seg_tree, fnc, arr):
for i in range(0, len(arr)):
for j in range(i, len(arr)):
range_value = reduce(fnc, arr[i:j + 1])
self.assertEqual(seg_tree.query(i, j), range_value)

0 comments on commit 665b169

Please sign in to comment.