8000 Complete lab3 · hacheyz/HIT-AdvancedAlgorithm@0d3d718 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0d3d718

Browse files
committed
Complete lab3
1 parent 39f2f64 commit 0d3d718

File tree

4 files changed

+107
-0
lines changed

4 files changed

+107
-0
lines changed
Loading
116 KB
Loading

lab3/graph.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
3+
class RandomGraph:
4+
def __init__(self, n: int):
5+
"""
6+
:param n: 图的顶点数
7+
"""
8+
self.n = n
9+
self.graph = np.zeros((n, n))
10+
11+
def randomize(self):
12+
"""
13+
生成一个 n 顶点随机图,任意两个顶点之间边的权值均匀分布于 (0, 1)
14+
"""
15+
self.graph = np.random.rand(self.n, self.n)
16+
self.graph = np.tril(self.graph, -1) # 取下三角矩阵
17+
self.graph += self.graph.T # 对称
18+
np.fill_diagonal(self.graph, np.inf) # 设置对角线为无穷大
19+
20+
def prim(self):
21+
"""
22+
Prim 算法计算最小生成树权值
23+
:return: 最小生成树权值
24+
"""
25+
n = self.n
26+
visited = [False] * n
27+
visited[0] = True
28+
dist = self.graph[0].copy()
29+
mst = 0
30+
for _ in range(n - 1):
31+
u = np.argmin(dist)
32+
mst += dist[u]
33+
dist[u] = np.inf
34+
visited[u] = True
35+
for v in range(n):
36+
if not visited[v] and self.graph[u, v] < dist[v]:
37+
dist[v] = self.graph[u, v]
38+
return mst

lab3/main.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
实现一种随机图,在随机图上实现对最小生成树的抽样过程,
3+
由抽样过程实现蒙特卡罗方法计算最小生成树权值的数学期望估计,
4+
比较估计结果的准确性
5+
6+
步骤:
7+
1. 实现算法产生 n 顶点随机图的生成
8+
输入:n
9+
输出:一个 n 顶点随机图,任意两个顶点之间边的权值均匀分布于 (0, 1)
10+
2. 调用第 1 步实现的算法,实现对 n 顶点图的均匀抽样
11+
3. 在抽样样本上计算最小生成树并计算其权值的数学期望
12+
4. 在第 2 步和第 3 步的基础上,建立 n 与最小生成树权值数学期望之间的关系
13+
5. 对 n = 16, 32, 64, 128, 256, 512, 1024,... 展开实验,考察算法运行时间的变化,
14+
并检验所建立的关系的一般性
15+
6. 尝试用理论分析解释实验结果
16+
7. 撰写实验报告
17+
"""
18+
19+
from graph import *
20+
import time
21+
import matplotlib.pyplot as plt
22+
from mpmath import zeta
23+
24+
def main():
25+
n_list = np.arange(16, 1040, 16)
26+
iter_num = 10
27+
runtimes = []
28+
mst_weights = []
29+
30+
g = RandomGraph(8)
31+
g.randomize()
32+
ret = g.prim()
33+
print()
34+
35+
# for n in n_list:
36+
# graph = RandomGraph(n)
37+
# mst_weight = 0
38+
#
39+
# start_time = time.time()
40+
# for _ in range(iter_num):
41+
# graph.randomize()
42+
# mst_weight += graph.prim()
43+
# end_time = time.time()
44+
#
45+
# runtimes.append((end_time - start_time)/iter_num)
46+
# mst_weights.append(mst_weight/iter_num)
47+
#
48+
# fig = plt.figure(dpi=400)
49+
# ax = fig.add_subplot(111)
50+
# ax.plot(n_list, runtimes, label='runtime')
51+
# ax.set_ylabel('Runtime (s)')
52+
# ax.set_xlabel('Vertex num n')
53+
# ax.set_title('Runtime of Prim Algorithm')
54+
# plt.show()
55+
#
56+
# fig = plt.figure(dpi=400)
57+
# ax = fig.add_subplot(111)
58+
# ax.plot(n_list, mst_weights, label='mst_weights')
59+
# Apery_const = zeta(3) # Apery's constant
60+
# ax.plot([0, n_list[-1]], [Apery_const, Apery_const], linestyle='--', c='gray')
61+
# ax.set_ylabel('Mean weight of MST')
62+
# ax.set_xlabel('Vertex num n')
63+
# ax.set_title('Relation between n and mean weight of MST')
64+
# ax.set_ylim(1.0, 1.4)
65+
# plt.show()
66+
67+
68+
if __name__ == '__main__':
69+
main()

0 commit comments

Comments
 (0)
0