8000 EXA/MAINT Simplify code in manifold learning example (#15949) · scikit-learn/scikit-learn@c393bdc · GitHub
[go: up one dir, main page]

Skip to content

Commit c393bdc

Browse files
DavidBreuerrth
authored andcommitted
EXA/MAINT Simplify code in manifold learning example (#15949)
1 parent 2a185f9 commit c393bdc

File tree

1 file changed

+29
-65
lines changed

1 file changed

+29
-65
lines changed

examples/manifold/plot_compare_methods.py

+29-65
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
print(__doc__)
2525

26+
from collections import OrderedDict
27+
from functools import partial
2628
from time import time
2729

2830
import matplotlib.pyplot as plt
@@ -39,81 +41,43 @@
3941
n_neighbors = 10
4042
n_components = 2
4143

44+
# Create figure
4245
fig = plt.figure(figsize=(15, 8))
43-
plt.suptitle("Manifold Learning with %i points, %i neighbors"
46+
fig.suptitle("Manifold Learning with %i points, %i neighbors"
4447
% (1000, n_neighbors), fontsize=14)
4548

46-
49+
# Add 3d scatter plot
4750
ax = fig.add_subplot(251, projection='3d')
4851
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=color, cmap=plt.cm.Spectral)
4952
ax.view_init(4, -72)
5053

51-
methods = ['standard', 'ltsa', 'hessian', 'modified']
52-
labels = ['LLE', 'LTSA', 'Hessian LLE', 'Modified LLE']
53-
54-
for i, method in enumerate(methods):
54+
# Set-up manifold methods
55+
LLE = partial(manifold.LocallyLinearEmbedding,
56+
n_neighbors, n_components, eigen_solver='auto')
57+
58+
methods = OrderedDict()
59+
methods['LLE'] = LLE(method='standard')
60+
methods['LTSA'] = LLE(method='ltsa')
61+
methods['Hessian LLE'] = LLE(method='hessian')
62+
methods['Modified LLE'] = LLE(method='modified')
63+
methods['Isomap'] = manifold.Isomap(n_neighbors, n_components)
64+
methods['MDS'] = manifold.MDS(n_components, max_iter=100, n_init=1)
65+
methods['SE'] = manifold.SpectralEmbedding(n_components=n_components,
66+
n_neighbors=n_neighbors)
67+
methods['t-SNE'] = manifold.TSNE(n_components=n_components, init='pca',
68+
random_state=0)
69+
70+
# Plot results
71+
for i, (label, method) in enumerate(methods.items()):
5572
t0 = time()
56-
Y = manifold.LocallyLinearEmbedding(n_neighbors, n_components,
57-
eigen_solver='auto',
58-
method=method).fit_transform(X)
73+
Y = method.fit_transform(X)
5974
t1 = time()
60-
print("%s: %.2g sec" % (methods[i], t1 - t0))
61-
62-
ax = fig.add_subplot(252 + i)
63-
plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
64-
plt.title("%s (%.2g sec)" % (labels[i], t1 - t0))
75+
print("%s: %.2g sec" % (label, t1 - t0))
76+
ax = fig.add_subplot(2, 5, 2 + i + (i > 3))
77+
ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
78+
ax.set_title("%s (%.2g sec)" % (label, t1 - t0))
6579
ax.xaxis.set_major_formatter(NullFormatter())
6680
ax.yaxis.set_major_formatter(NullFormatter())
67-
plt.axis('tight')
68-
69< 8000 /code>-
t0 = time()
70-
Y = manifold.Isomap(n_neighbors, n_components).fit_transform(X)
71-
t1 = time()
72-
print("Isomap: %.2g sec" % (t1 - t0))
73-
ax = fig.add_subplot(257)
74-
plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
75-
plt.title("Isomap (%.2g sec)" % (t1 - t0))
76-
ax.xaxis.set_major_formatter(NullFormatter())
77-
ax.yaxis.set_major_formatter(NullFormatter())
78-
plt.axis('tight')
79-
80-
81-
t0 = time()
82-
mds = manifold.MDS(n_components, max_iter=100, n_init=1)
83-
Y = mds.fit_transform(X)
84-
t1 = time()
85-
print("MDS: %.2g sec" % (t1 - t0))
86-
ax = fig.add_subplot(258)
87-
plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
88-
plt.title("MDS (%.2g sec)" % (t1 - t0))
89-
ax.xaxis.set_major_formatter(NullFormatter())
90-
ax.yaxis.set_major_formatter(NullFormatter())
91-
plt.axis('tight')
92-
93-
94-
t0 = time()
95-
se = manifold.SpectralEmbedding(n_components=n_components,
96-
n_neighbors=n_neighbors)
97-
Y = se.fit_transform(X)
98-
t1 = time()
99-
print("SpectralEmbedding: %.2g sec" % (t1 - t0))
100-
ax = fig.add_subplot(259)
101-
plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
102-
plt.title("SpectralEmbedding (%.2g sec)" % (t1 - t0))
103-
ax.xaxis.set_major_formatter(NullFormatter())
104-
ax.yaxis.set_major_formatter(NullFormatter())
105-
plt.axis('tight')
106-
107-
t0 = time()
108-
tsne = manifold.TSNE(n_components=n_components, init='pca', random_state=0)
109-
Y = tsne.fit_transform(X)
110-
t1 = time()
111-
print("t-SNE: %.2g sec" % (t1 - t0))
112-
ax = fig.add_subplot(2, 5, 10)
113-
plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
114-
plt.title("t-SNE (%.2g sec)" % (t1 - t0))
115-
ax.xaxis.set_major_formatter(NullFormatter())
116-
ax.yaxis.set_major_formatter(NullFormatter())
117-
plt.axis('tight')
81+
ax.axis('tight')
11882

11983
plt.show()

0 commit comments

Comments
 (0)
0