8000 DOC update notebook style for plot_lda_qda (#22528) · scikit-learn/scikit-learn@a3893d3 · GitHub
[go: up one dir, main page]

Skip to content

Commit a3893d3

Browse files
DOC update notebook style for plot_lda_qda (#22528)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent f5d3a8b commit a3893d3

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

examples/classification/plot_lda_qda.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,14 @@ class has its own standard deviation with QDA.
1111
1212
"""
1313

14-
from scipy import linalg
15-
import numpy as np
14+
# %%
15+
# Colormap
16+
# --------
17+
1618
import matplotlib.pyplot as plt
1719
import matplotlib as mpl
1820
from matplotlib import colors
1921

20-
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
21-
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
22-
23-
# #############################################################################
24-
# Colormap
2522
cmap = colors.LinearSegmentedColormap(
2623
"red_blue_classes",
2724
{
@@ -33,8 +30,13 @@ class has its own standard deviation with QDA.
3330
plt.cm.register_cmap(cmap=cmap)
3431

3532

36-
# #############################################################################
37-
# Generate datasets
33+
# %%
34+
# Datasets generation functions
35+
# -----------------------------
36+
37+
import numpy as np
38+
39+
3840
def dataset_fixed_cov():
3941
"""Generate 2 Gaussians samples with the same covariance matrix"""
4042
n, dim = 300, 2
@@ -61,8 +63,13 @@ def dataset_cov():
6163
return X, y
6264

6365

64-
# #############################################################################
66+
# %%
6567
# Plot functions
68+
# --------------
69+
70+
from scipy import linalg
71+
72+
6673
def plot_data(lda, X, y, y_pred, fig_index):
6774
splot = plt.subplot(2, 2, fig_index)
6875
if fig_index == 1:
@@ -154,12 +161,20 @@ def plot_qda_cov(qda, splot):
154161
plot_ellipse(splot, qda.means_[1], qda.covariance_[1], "blue")
155162

156163

164+
# %%
165+
# Plot
166+
# ----
167+
157168
plt.figure(figsize=(10, 8), facecolor="white")
158169
plt.suptitle(
159170
"Linear Discriminant Analysis vs Quadratic Discriminant Analysis",
160171
y=0.98,
161172
fontsize=15,
162173
)
174+
175+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
176+
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
177+
163178
for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
164179
# Linear Discriminant Analysis
165180
lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
@@ -174,6 +189,7 @@ def plot_qda_cov(qda, splot):
174189
splot = plot_data(qda, X, y, y_pred, fig_index=2 * i + 2)
175190
plot_qda_cov(qda, splot)
176191
plt.axis("tight")
192+
177193
plt.tight_layout()
178194
plt.subplots_adjust(top=0.92)
179195
plt.show()

0 commit comments

Comments
 (0)
0