diff --git a/python/12_sorts/merge_sort.py b/python/12_sorts/merge_sort.py index 9bcacecf..f1c5e71a 100644 --- a/python/12_sorts/merge_sort.py +++ b/python/12_sorts/merge_sort.py @@ -4,21 +4,23 @@ from typing import List + def merge_sort(a: List[int]): - _merge_sort_between(a, 0, len(a)-1) + _merge_sort_between(a, 0, len(a) - 1) + def _merge_sort_between(a: List[int], low: int, high: int): # The indices are inclusive for both low and high. - if low >= high: return - mid = low + (high - low)//2 - _merge_sort_between(a, low, mid) - _merge_sort_between(a, mid+1, high) + if low < high: + mid = low + (high - low) // 2 + _merge_sort_between(a, low, mid) + _merge_sort_between(a, mid + 1, high) + _merge(a, low, mid, high) - _merge(a, low, mid, high) def _merge(a: List[int], low: int, mid: int, high: int): # a[low:mid], a[mid+1, high] are sorted. - i, j = low, mid+1 + i, j = low, mid + 1 tmp = [] while i <= mid and j <= high: if a[i] <= a[j]: @@ -29,8 +31,23 @@ def _merge(a: List[int], low: int, mid: int, high: int): j += 1 start = i if i <= mid else j end = mid if i <= mid else high - tmp.extend(a[start:end+1]) - a[low:high+1] = tmp + tmp.extend(a[start:end + 1]) + a[low:high + 1] = tmp + + +def test_merge_sort(): + a1 = [3, 5, 6, 7, 8] + merge_sort(a1) + assert a1 == [3, 5, 6, 7, 8] + a2 = [2, 2, 2, 2] + merge_sort(a2) + assert a2 == [2, 2, 2, 2] + a3 = [4, 3, 2, 1] + merge_sort(a3) + assert a3 == [1, 2, 3, 4] + a4 = [5, -1, 9, 3, 7, 8, 3, -2, 9] + merge_sort(a4) + assert a4 == [-2, -1, 3, 3, 5, 7, 8, 9, 9] if __name__ == "__main__": @@ -45,4 +62,4 @@ def _merge(a: List[int], low: int, mid: int, high: int): merge_sort(a3) print(a3) merge_sort(a4) - print(a4) \ No newline at end of file + print(a4) diff --git a/python/12_sorts/quick_sort.py b/python/12_sorts/quick_sort.py index f4267512..b4f923cf 100644 --- a/python/12_sorts/quick_sort.py +++ b/python/12_sorts/quick_sort.py @@ -5,22 +5,25 @@ from typing import List import random + def quick_sort(a: List[int]): - _quick_sort_between(a, 0, len(a)-1) + _quick_sort_between(a, 0, len(a) - 1) + def _quick_sort_between(a: List[int], low: int, high: int): - if low >= high: return - # get a random position as the pivot - k = random.randint(low, high) - a[low], a[k] = a[k], a[low] + if low < high: + # get a random position as the pivot + k = random.randint(low, high) + a[low], a[k] = a[k], a[low] + + m = _partition(a, low, high) # a[m] is in final position + _quick_sort_between(a, low, m - 1) + _quick_sort_between(a, m + 1, high) - m = _partition(a, low, high) # a[m] is in final position - _quick_sort_between(a, low, m-1) - _quick_sort_between(a, m+1, high) def _partition(a: List[int], low: int, high: int): pivot, j = a[low], low - for i in range(low+1, high+1): + for i in range(low + 1, high + 1): if a[i] <= pivot: j += 1 a[j], a[i] = a[i], a[j] # swap @@ -28,6 +31,21 @@ def _partition(a: List[int], low: int, high: int): return j +def test_quick_sort(): + a1 = [3, 5, 6, 7, 8] + quick_sort(a1) + assert a1 == [3, 5, 6, 7, 8] + a2 = [2, 2, 2, 2] + quick_sort(a2) + assert a2 == [2, 2, 2, 2] + a3 = [4, 3, 2, 1] + quick_sort(a3) + assert a3 == [1, 2, 3, 4] + a4 = [5, -1, 9, 3, 7, 8, 3, -2, 9] + quick_sort(a4) + assert a4 == [-2, -1, 3, 3, 5, 7, 8, 9, 9] + + if __name__ == "__main__": a1 = [3, 5, 6, 7, 8] a2 = [2, 2, 2, 2] @@ -40,4 +58,4 @@ def _partition(a: List[int], low: int, high: int): quick_sort(a3) print(a3) quick_sort(a4) - print(a4) \ No newline at end of file + print(a4)