8000 Added Prim's algorithm for minimum spanning tree. · pythonpeixun/practice-python@a896b36 · GitHub
[go: up one dir, main page]

Skip to content

Commit a896b36

Browse files
committed
Added Prim's algorithm for minimum spanning tree.
1 parent 6816597 commit a896b36

File tree

2 files changed

+141
-99
lines changed

2 files changed

+141
-99
lines changed

graphs/dijkstra.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

graphs/undirected_graph_weighted.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import queue
2+
from collections import namedtuple
3+
4+
Edge = namedtuple('Edge', ['vertex', 'weight'])
5+
6+
7+
class GraphUndirectedWeighted(object):
8+
def __init__(self, vertex_count):
9+
self.vertex_count = vertex_count
10+
self.adjacency_list = [[] for _ in range(vertex_count)]
11+
12+
def add_edge(self, source, dest, weight):
13+
assert source < self.vertex_count
14+
assert dest < self.vertex_count
15+
self.adjacency_list[source].append(Edge(dest, weight))
16+
self.adjacency_list[dest].append(Edge(source, weight))
17+
18+
def get_neighbor(self, vertex):
19+
"""
20+
Returns the next neighbor to vertex
21+
:param vertex:
22+
:rtype: Edge
23+
"""
24+
for e in self.adjacency_list[vertex]:
25+
yield e
26+
27+
def get_vertex(self):
28+
for v in range(self.vertex_count):
29+
yield v
30+
31+
def dijkstra(self, source, dest):
32+
q = queue.PriorityQueue()
33+
parents = []
34+
distances = []
35< 8000 /td>+
start_weight = float("inf")
36+
37+
for i in self.get_vertex():
38+
weight = start_weight
39+
if source == i:
40+
weight = 0
41+
distances.append(weight)
42+
parents.append(None)
43+
44+
q.put(([0, source]))
45+
46+
while not q.empty():
47+
v_tuple = q.get()
48+
v = v_tuple[1]
49+
50+
for e in self.get_neighbor(v):
51+
candidate_distance = distances[v] + e.weight
52+
if distances[e.vertex] > candidate_distance:
53+
distances[e.vertex] = candidate_distance
54+
parents[e.vertex] = v
55+
# primitive but effective negative cycle detection
56+
if candidate_distance < -1000:
57+
raise Exception("Negative cycle detected")
58+
q.put(([distances[e.vertex], e.vertex]))
59+
60+
shortest_path = []
61+
end = dest
62+
while end is not None:
63+
shortest_path.append(end)
64+
end = parents[end]
65+
66+
shortest_path.reverse()
67+
68+
return shortest_path, distances[dest]
69+
70+
def prim(self):
71+
"""
72+
Returns a dictionary of parents of vertices in a minimum spanning tree
73+
:rtype: dict
74+
"""
75+
s = set()
76+
q = queue.PriorityQueue()
77+
parents = {}
78+
start_weight = float("inf")
79+
weights = {} # since we can't peek into queue
80+
81+
for i in self.get_vertex():
82+
weight = start_weight
83+
if i == 0:
84+
weight = 0
85+
q.put(([weight, i]))
86+
weights[i] = weight
87+
parents[i] = None
88+
89+
while not q.empty():
90+
v_tuple = q.get()
91+
vertex = v_tuple[1]
92+
93+
s.add(vertex)
94+
95+
for u in self.get_neighbor(vertex):
96+
if u.vertex not in s:
97+
if u.weight < weights[u.vertex]:
98+
parents[u.vertex] = vertex
99+
weights[u.vertex] = u.weight
100+
q.put(([u.weight, u.vertex]))
101+
102+
return parents
103+
104+
105+
def main():
106+
g = GraphUndirectedWeighted(9)
107+
g.add_edge(0, 1, 4)
108+
g.add_edge(1, 7, 6)
109+
g.add_edge(1, 2, 1)
110+
g.add_edge(2, 3, 3)
111+
g.add_edge(3, 7, 1)
112+
g.add_edge(3, 4, 2)
113+
g.add_edge(3, 5, 1)
114+
g.add_edge(4, 5, 1)
115+
g.add_edge(5, 6, 1)
116+
g.add_edge(6, 7, 2)
117+
g.add_edge(6, 8, 2)
118+
g.add_edge(7, 8, 2)
119+
# for testing negative cycles
120+
# g.add_edge(1, 9, -5)
121+
# g.add_edge(9, 7, -4)
122+
123+
shortest_path, distance = g.dijkstra(0, 1)
124+
assert shortest_path == [0, 1] and distance == 4
125+
126+
shortest_path, distance = g.dijkstra(0, 8)
127+
assert shortest_path == [0, 1, 2, 3, 7, 8] and distance == 11
128+
129+
shortest_path, distance = g.dijkstra(5, 0)
130+
assert shortest_path == [5, 3, 2, 1, 0] and distance == 9
131+
132+
shortest_path, distance = g.dijkstra(1, 1)
133+
assert shortest_path == [1] and distance == 0
134+
135+
msp = g.prim()
136+
print(msp)
137+
assert(msp == {0: None, 1: 0, 2: 1, 3: 2, 4: 5, 5: 3, 6: 5, 7: 3, 8: 6})
138+
139+
140+
if __name__ == "__main__":
141+
main()

0 commit comments

Comments
 (0)
0