diff --git a/examples/cluster/plot_digits_linkage.py b/examples/cluster/plot_digits_linkage.py index e13e83047fee3..730f85c543356 100644 --- a/examples/cluster/plot_digits_linkage.py +++ b/examples/cluster/plot_digits_linkage.py @@ -37,7 +37,8 @@ from sklearn import manifold, datasets -X, y = datasets.load_digits(return_X_y=True) +digits = datasets.load_digits() +X, y = digits.data, digits.target n_samples, n_features = X.shape np.random.seed(0) @@ -50,13 +51,13 @@ def plot_clustering(X_red, labels, title=None): X_red = (X_red - x_min) / (x_max - x_min) plt.figure(figsize=(6, 4)) - for i in range(X_red.shape[0]): - plt.text( - X_red[i, 0], - X_red[i, 1], - str(y[i]), - color=plt.cm.nipy_spectral(labels[i] / 10.0), - fontdict={"weight": "bold", "size": 9}, + for digit in digits.target_names: + plt.scatter( + *X_red[y == digit].T, + marker=f"${digit}$", + s=50, + c=plt.cm.nipy_spectral(labels[y == digit] / 10), + alpha=0.5, ) plt.xticks([])