@@ -11,17 +11,14 @@ class has its own standard deviation with QDA.
11
11
12
12
"""
13
13
14
- from scipy import linalg
15
- import numpy as np
14
+ # %%
15
+ # Colormap
16
+ # --------
17
+
16
18
import matplotlib .pyplot as plt
17
19
import matplotlib as mpl
18
20
from matplotlib import colors
19
21
20
- from sklearn .discriminant_analysis import LinearDiscriminantAnalysis
21
- from sklearn .discriminant_analysis import QuadraticDiscriminantAnalysis
22
-
23
- # #############################################################################
24
- # Colormap
25
22
cmap = colors .LinearSegmentedColormap (
26
23
"red_blue_classes" ,
27
24
{
@@ -33,8 +30,13 @@ class has its own standard deviation with QDA.
33
30
plt .cm .register_cmap (cmap = cmap )
34
31
35
32
36
- # #############################################################################
37
- # Generate datasets
33
+ # %%
34
+ # Datasets generation functions
35
+ # -----------------------------
36
+
37
+ import numpy as np
38
+
39
+
38
40
def dataset_fixed_cov ():
39
41
"""Generate 2 Gaussians samples with the same covariance matrix"""
40
42
n , dim = 300 , 2
@@ -61,8 +63,13 @@ def dataset_cov():
61
63
return X , y
62
64
63
65
64
- # #############################################################################
66
+ # %%
65
67
# Plot functions
68
+ # --------------
69
+
70
+ from scipy import linalg
71
+
72
+
66
73
def plot_data (lda , X , y , y_pred , fig_index ):
67
74
splot = plt .subplot (2 , 2 , fig_index )
68
75
if fig_index == 1 :
@@ -154,12 +161,20 @@ def plot_qda_cov(qda, splot):
154
161
plot_ellipse (splot , qda .means_ [1 ], qda .covariance_ [1 ], "blue" )
155
162
156
163
164
+ # %%
165
+ # Plot
166
+ # ----
167
+
157
168
plt .figure (figsize = (10 , 8 ), facecolor = "white" )
158
169
plt .suptitle (
159
170
"Linear Discriminant Analysis vs Quadratic Discriminant Analysis" ,
160
171
y = 0.98 ,
161
172
fontsize = 15 ,
162
173
)
174
+
175
+ from sklearn .discriminant_analysis import LinearDiscriminantAnalysis
176
+ from sklearn .discriminant_analysis import QuadraticDiscriminantAnalysis
177
+
163
178
for i , (X , y ) in enumerate ([dataset_fixed_cov (), dataset_cov ()]):
164
179
# Linear Discriminant Analysis
165
180
lda = LinearDiscriminantAnalysis (solver = "svd" , store_covariance = True )
@@ -174,6 +189,7 @@ def plot_qda_cov(qda, splot):
174
189
splot = plot_data (qda , X , y , y_pred , fig_index = 2 * i + 2 )
175
190
plot_qda_cov (qda , splot )
176
191
plt .axis ("tight" )
192
+
177
193
plt .tight_layout ()
178
194
plt .subplots_adjust (top = 0.92 )
179
195
plt .show ()
0 commit comments