8000 MISC: Make sure that nosetests doesn't try to run the bench · seckcoder/scikit-learn@1fb8ea5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1fb8ea5

Browse files
committed
MISC: Make sure that nosetests doesn't try to run the bench
1 parent c4b9b00 commit 1fb8ea5

File tree

2 files changed

+34
-33
lines changed

2 files changed

+34
-33
lines changed

benchmarks/__init__.py

Whitespace-only changes.

benchmarks/bench_plot_balltree.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,37 @@ def compare_nbrs(nbrs1, nbrs2):
2626
elif(nbrs1.ndim == 1):
2727
return np.all(nbrs1 == nbrs2)
2828

29-
n_samples = 1000
30-
leaf_size = 1 # leaf size
31-
k = 20
32-
BT_results = []
33-
KDT_results = []
34-
35-
for i in range(1, 10):
36-
print 'Iteration %s' %i
37-
n_features = i*100
38-
X = np.random.random([n_samples, n_features])
39-
40-
t0 = time()
41-
BT = BallTree(X, leaf_size)
42-
d, nbrs1 = BT.query(X, k)
43-
delta = time() - t0
44-
BT_results.append(delta)
45-
46-
t0 = time()
47-
KDT = cKDTree(X, leaf_size)
48-
d, nbrs2 = KDT.query(X, k)
49-
delta = time() - t0
50-
KDT_results.append(delta)
51-
52-
# this checks we get the correct result
53-
assert compare_nbrs(nbrs1, nbrs2)
54-
55-
xx = 100 * np.arange(1, 10)
56-
pl.plot(xx, BT_results, label='scikits.learn (BallTree)')
57-
pl.plot(xx, KDT_results, label='scipy (cKDTree)')
58-
pl.xlabel('number of dimensions')
59-
pl.ylabel('time (seconds)')
60-
pl.legend()
61-
pl.show()
29+
if __name__ == '__main__':
30+
n_samples = 1000
31+
leaf_size = 1 # leaf size
32+
k = 20
33+
BT_results = []
34+
KDT_results = []
35+
36+
for i in range(1, 10):
37+
print 'Iteration %s' %i
38+
n_features = i*100
39+
X = np.random.random([n_samples, n_features])
40+
41+
t0 = time()
42+
BT = BallTree(X, leaf_size)
43+
d, nbrs1 = BT.query(X, k)
44+
delta = time() - t0
45+
BT_results.append(delta)
46+
47+
t0 = time()
48+
KDT = cKDTree(X, leaf_size)
49+
d, nbrs2 = KDT.query(X, k)
50+
delta = time() - t0
51+
KDT_results.append(delta)
52+
53+
# this checks we get the correct result
54+
assert compare_nbrs(nbrs1, nbrs2)
55+
56+
xx = 100 * np.arange(1, 10)
57+
pl.plot(xx, BT_results, label='scikits.learn (BallTree)')
58+
pl.plot(xx, KDT_results, label='scipy (cKDTree)')
59+
pl.xlabel('number of dimensions')
60+
pl.ylabel('time (seconds)')
61+
pl.legend()
62+
pl.show()

0 commit comments

Comments
 (0)
0