forked from mubastan/segment-py
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDSF.py
98 lines (85 loc) · 2.67 KB
/
DSF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Disjoint Set data structure with union by rank and path compression
class Node:
    """One element of a disjoint-set forest.

    Attributes:
        parent: index of the parent element; an element is a root when
            ``parent`` equals its own index.
        rank: union-by-rank bookkeeping; only incremented on roots.
        size: number of elements in the set. Only kept up to date on
            roots (``DSF.union`` updates the surviving root only).
    """

    # DSF allocates one Node per element, so drop the per-instance
    # __dict__ to reduce memory for large forests.
    __slots__ = ("parent", "rank", "size")

    def __init__(self, parent, rank=0):
        self.parent = parent
        self.rank = rank
        self.size = 1


class DSF:
    """Disjoint-set forest over the elements ``0 .. numElements - 1``.

    Every element starts in its own singleton set. ``union`` uses union by
    rank; ``find`` applies full path compression.

    Attributes:
        numElements: total number of elements.
        numSets: current number of disjoint sets.
        nodes: list of ``Node`` objects, indexed by element.
    """

    def __init__(self, numElements):
        self.numElements = numElements
        self.numSets = numElements
        self.nodes = [Node(i) for i in range(numElements)]

    def find_nopc(self, x):
        """Return the root of the set containing ``x`` without full compression.

        Only ``x`` itself is re-pointed directly at the root. The
        intermediate elements on the path keep their old parents.
        """
        root = x
        while root != self.nodes[root].parent:
            root = self.nodes[root].parent
        self.nodes[x].parent = root
        return root

    def find(self, x):
        """Return the root (representative) of the set containing ``x``.

        Applies full path compression: every element on the path from ``x``
        to the root is re-pointed directly at the root. Implemented as two
        iterative passes instead of recursion, so the cost does not include
        one Python call per tree level.
        """
        nodes = self.nodes
        # Pass 1: walk up to the root.
        root = x
        while nodes[root].parent != root:
            root = nodes[root].parent
        # Pass 2: point every element on the path straight at the root.
        while x != root:
            next_x = nodes[x].parent
            nodes[x].parent = root
            x = next_x
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Does nothing if they are already in the same set. Otherwise the root
        of strictly higher rank absorbs the other root. On a tie, ``x``'s
        root is attached under ``y``'s root and that root's rank grows by
        one. The surviving root's ``size`` is updated and ``numSets`` is
        decremented.
        """
        if x == y:
            return
        xr = self.find(x)  # root of x's set
        yr = self.find(y)  # root of y's set
        if xr == yr:
            return
        nx = self.nodes[xr]
        ny = self.nodes[yr]
        if nx.rank > ny.rank:
            ny.parent = xr
            nx.size += ny.size
        else:
            nx.parent = yr
            ny.size += nx.size
            # Equal ranks: the new tree is one level taller.
            if nx.rank == ny.rank:
                ny.rank += 1
        self.numSets -= 1

    def setSize(self, id):  # ``id`` shadows the builtin; kept for caller compatibility
        """Return the number of elements in the set containing ``id``.

        Sizes are only maintained on root nodes, so resolve ``id`` to its
        root first. Reading ``self.nodes[id].size`` on a non-root element
        would return a stale value. This also path-compresses ``id``.
        """
        return self.nodes[self.find(id)].size

    def reset(self):
        """Put every element back into its own singleton set."""
        for index, node in enumerate(self.nodes):
            node.rank = 0
            node.size = 1
            node.parent = index
        self.numSets = self.numElements

    def printSet(self):
        """Print the element/set counts and every node's size, parent and rank.

        Values are space-separated so multi-digit entries stay readable.
        """
        print('\nnum elements: ', self.numElements)
        print('num sets: ', self.numSets)
        print('Size: ')
        print(*(node.size for node in self.nodes))
        print('Parent: ')
        print(*(node.parent for node in self.nodes))
        print('Rank: ')
        print(*(node.rank for node in self.nodes))
if __name__ == "__main__":
    # Demo / smoke test: build ten singleton sets, merge a few of them and
    # dump the forest state along the way.
    forest = DSF(10)

    def show_roots(*elems):
        """Print each element's representative, then the current set count."""
        for elem in elems:
            print(f'find {elem}: ', forest.find(elem))
        print('num sets:', forest.numSets)

    print('dsf.numElements: ', forest.numElements)
    print('dsf.numSets: ', forest.numSets)
    show_roots(3, 4)

    for left, right in ((3, 4), (3, 5)):
        forest.union(left, right)
        print(f'union {left},{right}')
        show_roots(left, right)
    forest.printSet()

    for left, right in ((9, 4), (0, 1), (0, 5)):
        forest.union(left, right)
    for elem in range(forest.numElements):
        print('Find', elem, forest.find(elem))
    forest.printSet()

    forest.reset()
    forest.printSet()