|
| 1 | +from array import array |
| 2 | +import os |
| 3 | + |
| 4 | + |
| 5 | +class DisjointSets(object): |
| 6 | + |
| 7 | + def __init__(self, size): |
| 8 | + self.hierarchy = array('l', [0 for _ in range(size)]) |
| 9 | + self.names = {} |
| 10 | + |
| 11 | + def add(self, item_id, name): |
| 12 | + self.hierarchy[item_id] = -1 |
| 13 | + self.names[item_id] = name |
| 14 | + |
| 15 | + def union(self, root1, root2): |
| 16 | + if self.hierarchy[root1] <= self.hierarchy[root2]: # root1 is a larger tree, since roots are -size of tree |
| 17 | + self.hierarchy[root1] += self.hierarchy[root2] # adding to increase negative value |
| 18 | + self.hierarchy[root2] = root1 |
| 19 | + else: |
| 20 | + self.hierarchy[root2] += self.hierarchy[root1] |
| 21 | + self.hierarchy[root1] = root2 |
| 22 | + |
| 23 | + def find(self, item_id): |
| 24 | + ''' |
| 25 | + Finds the root of the set of which item_id is a member. |
| 26 | + To speed up subsequent finds, updates the parent to root at each level in path to root. |
| 27 | + :param item_id: |
| 28 | + :return: integer representative of set |
| 29 | + ''' |
| 30 | + if self.hierarchy[item_id] < 0: |
| 31 | + return item_id |
| 32 | + else: |
| 33 | + # path compression |
| 34 | + self.hierarchy[item_id] = self.find(self.hierarchy[item_id]) |
| 35 | + return self.hierarchy[item_id] |
| 36 | + |
| 37 | + def __str__(self): |
| 38 | + ret_str = repr(self.hierarchy) |
| 39 | + ret_str += os.linesep + repr(self.names) |
| 40 | + return ret_str |
| 41 | + |
| 42 | + |
| 43 | +def main(): |
| 44 | + ds = DisjointSets(10) |
| 45 | + ds.add(0, "Microsoft") |
| 46 | + ds.add(1, "WebTV") |
| 47 | + ds.add(2, "Google") |
| 48 | + ds.add(3, "DeepMind") |
| 49 | + ds.add(4, "Skype") |
| 50 | + ds.add(5, "Uber") |
| 51 | + ds.add(6, "TinyCo") |
| 52 | + ds.add(7, "TeenyTinyCo") |
| 53 | + |
| 54 | + ds.union(0, 1) |
| 55 | + ds.union(2, 3) |
| 56 | + ds.union(0, 4) |
| 57 | + ds.union(4, 6) |
| 58 | + ds.union(6, 7) |
| 59 | + |
| 60 | + # print(ds) |
| 61 | + |
| 62 | + assert(ds.find(3) == 2) |
| 63 | + |
| 64 | + # print(ds) |
| 65 | + |
| 66 | + |
| 67 | +if __name__ == '__main__': |
| 68 | + main() |
0 commit comments