8000 Added disjoint sets data structure. · pythonpeixun/practice-python@47cc6ff · GitHub
[go: up one dir, main page]

Skip to content

Commit 47cc6ff

Browse files
committed
Added disjoint sets data structure.
1 parent 31e6905 commit 47cc6ff

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

disjoint-sets/disjoint-sets.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from array import array
2+
import os
3+
4+
5+
class DisjointSets(object):
6+
7+
def __init__(self, size):
8+
self.hierarchy = array('l', [0 for _ in range(size)])
9+
self.names = {}
10+
11+
def add(self, item_id, name):
12+
self.hierarchy[item_id] = -1
13+
self.names[item_id] = name
14+
15+
def union(self, root1, root2):
16+
if self.hierarchy[root1] <= self.hierarchy[root2]: # root1 is a larger tree, since roots are -size of tree
17+
self.hierarchy[root1] += self.hierarchy[root2] # adding to increase negative value
18+
self.hierarchy[root2] = root1
19+
else:
20+
self.hierarchy[root2] += self.hierarchy[root1]
21+
self.hierarchy[root1] = root2
22+
23+
def find(self, item_id):
24+
'''
25+
Finds the root of the set of which item_id is a member.
26+
To speed up subsequent finds, updates the parent to root at each level in path to root.
27+
:param item_id:
28+
:return: integer representative of set
29+
'''
30+
if self.hierarchy[item_id] < 0:
31+
return item_id
32+
else:
33+
# path compression
34+
self.hierarchy[item_id] = self.find(self.hierarchy[item_id])
35+
return self.hierarchy[item_id]
36+
37+
def __str__(self):
38+
ret_str = repr(self.hierarchy)
39+
ret_str += os.linesep + repr(self.names)
40+
return ret_str
41+
42+
43+
def main():
44+
ds = DisjointSets(10)
45+
ds.add(0, "Microsoft")
46+
ds.add(1, "WebTV")
47+
ds.add(2, "Google")
48+
ds.add(3, "DeepMind")
49+
ds.add(4, "Skype")
50+
ds.add(5, "Uber")
51+
ds.add(6, "TinyCo")
52+
ds.add(7, "TeenyTinyCo")
53+
54+
ds.union(0, 1)
55+
ds.union(2, 3)
56+
ds.union(0, 4)
57+
ds.union(4, 6)
58+
ds.union(6, 7)
59+
60+
# print(ds)
61+
62+
assert(ds.find(3) == 2)
63+
64+
# print(ds)
65+
66+
67+
if __name__ == '__main__':
68+
main()

0 commit comments

Comments
 (0)
0