8000 DOC improve iris example (#26973) · scikit-learn/scikit-learn@a31e108 · GitHub
[go: up one dir, main page]

Skip to content

Commit a31e108

Browse files
eguentherElisabeth Güntherglemaitre
authored
DOC improve iris example (#26973)
Co-authored-by: Elisabeth Günther <eguenther@MacBook-Pro-von-Elisabeth.local> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 844b087 commit a31e108

File tree

2 files changed

+49
-32
lines changed

2 files changed

+49
-32
lines changed
Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
2-
=========================================================
2+
================
33
The Iris Dataset
4-
=========================================================
4+
================
55
This data set consists of 3 different types of irises'
66
(Setosa, Versicolour, and Virginica) petal and sepal
77
length, stored in a 150x4 numpy.ndarray
@@ -19,37 +19,47 @@
1919
# Modified for documentation by Jaques Grobler
2020
# License: BSD 3 clause
2121

22-
import matplotlib.pyplot as plt
23-
24-
# unused but required import for doing 3d projections with matplotlib < 3.2
25-
import mpl_toolkits.mplot3d # noqa: F401
26-
22+
# %%
23+
# Loading the iris dataset
24+
# ------------------------
2725
from sklearn import datasets
28-
from sklearn.decomposition import PCA
2926

30-
# import some data to play with
3127
iris = datasets.load_iris()
32-
X = iris.data[:, :2] # we only take the first two features.
33-
y = iris.target
3428

35-
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
36-
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
3729

38-
plt.figure(2, figsize=(8, 6))
39-
plt.clf()
30+
# %%
31+
# Scatter Plot of the Iris dataset
32+
# --------------------------------
33+
import matplotlib.pyplot as plt
34+
35+
_, ax = plt.subplots()
36+
scatter = ax.scatter(iris.data[:, 0], iris.data[:, 1], c=iris.target)
37+
ax.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
38+
_ = ax.legend(
39+
scatter.legend_elements()[0], iris.target_names, loc="lower right", title="Classes"
40+
)
4041

41-
# Plot the training points
42-
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1, edgecolor="k")
43-
plt.xlabel("Sepal length")
44-
plt.ylabel("Sepal width")
42+
# %%
43+
# Each point in the scatter plot refers to one of the 150 iris flowers
44+
# in the dataset, with the color indicating their respective type
45+
# (Setosa, Versicolour, and Virginica).
46+
# You can already see a pattern regarding the Setosa type, which is
47+
# easily identifiable based on its short and wide sepal. Only
48+
# considering these 2 dimensions, sepal width and length, there's still
49+
# overlap between the Versicolour and Virginica types.
50+
51+
# %%
52+
# Plot a PCA representation
53+
# -------------------------
54+
# Let's apply a Principal Component Analysis (PCA) to the iris dataset
55+
# and then plot the irises across the first three PCA dimensions.
56+
# This will allow us to better differentiate between the three types!
4557

46-
plt.xlim(x_min, x_max)
47-
plt.ylim(y_min, y_max)
48-
plt.xticks(())
49-
plt.yticks(())
58+
# unused but required import for doing 3d projections with matplotlib < 3.2
59+
import mpl_toolkits.mplot3d # noqa: F401
60+
61+
from sklearn.decomposition import PCA
5062

51-
# To getter a better understanding of interaction of the dimensions
52-
# plot the first three PCA dimensions
5363
fig = plt.figure(1, figsize=(8, 6))
5464
ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)
5565

@@ -58,18 +68,22 @@
5868
X_reduced[:, 0],
5969
X_reduced[:, 1],
6070
X_reduced[:, 2],
61-
c=y,
62-
cmap=plt.cm.Set1,
63-
edgecolor="k",
71+
c=iris.target,
6472
s=40,
6573
)
6674

67-
ax.set_title("First three PCA directions")
68-
ax.set_xlabel("1st eigenvector")
75+
ax.set_title("First three PCA dimensions")
76+
ax.set_xlabel("1st Eigenvector")
6977
ax.xaxis.set_ticklabels([])
70-
ax.set_ylabel("2nd eigenvector")
78+
ax.set_ylabel("2nd Eigenvector")
7179
ax.yaxis.set_ticklabels([])
72-
ax.set_zlabel("3rd eigenvector")
80+
ax.set_zlabel("3rd Eigenvector")
7381
ax.zaxis.set_ticklabels([])
7482

7583
plt.show()
84+
85+
# %%
86+
# PCA will create 3 new features that are a linear combination of the
87+
# 4 original features. In addition, this transform maximizes the variance.
88+
# With this transformation, we see that we can identify each species using
89+
# only the first feature (i.e., the first eigenvector).

sklearn/datasets/_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,9 @@ def load_iris(*, return_X_y=False, as_frame=False):
667667
array([0, 0, 1])
668668
>>> list(data.target_names)
669669
['setosa', 'versicolor', 'virginica']
670+
671+
See :ref:`sphx_glr_auto_examples_datasets_plot_iris_dataset.py` for a more
672+
detailed example of how to work with the iris dataset.
670673
"""
671674
data_file_name = "iris.csv"
672675
data, target, target_names, fdescr = load_csv_data(

0 commit comments

Comments
 (0)
0