DOC use train/test split in GaussianNB example (#14080) · scikit-learn/scikit-learn@e2a69b7 · GitHub
Commit e2a69b7

CYHSM authored and qinhanmin2014 committed
DOC use train/test split in GaussianNB example (#14080)
1 parent bec8308 commit e2a69b7

1 file changed: +10 -8 lines changed

doc/modules/naive_bayes.rst

Lines changed: 10 additions & 8 deletions
@@ -90,14 +90,16 @@ classification. The likelihood of the features is assumed to be Gaussian:
 The parameters :math:`\sigma_y` and :math:`\mu_y`
 are estimated using maximum likelihood.

->>> from sklearn import datasets
->>> iris = datasets.load_iris()
->>> from sklearn.naive_bayes import GaussianNB
->>> gnb = GaussianNB()
->>> y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
->>> print("Number of mislabeled points out of a total %d points : %d"
-...       % (iris.data.shape[0],(iris.target != y_pred).sum()))
-Number of mislabeled points out of a total 150 points : 6
+>>> from sklearn.datasets import load_iris
+>>> from sklearn.model_selection import train_test_split
+>>> from sklearn.naive_bayes import GaussianNB
+>>> X, y = load_iris(return_X_y=True)
+>>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
+>>> gnb = GaussianNB()
+>>> y_pred = gnb.fit(X_train, y_train).predict(X_test)
+>>> print("Number of mislabeled points out of a total %d points : %d"
+...       % (X_test.shape[0], (y_test != y_pred).sum()))
+Number of mislabeled points out of a total 75 points : 4

 .. _multinomial_naive_bayes:
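
For readers who want to run the new doctest outside the docs, here is a minimal standalone sketch of the same steps. It also counts errors on the training half, which is not part of the commit, to illustrate why scoring on the data used for fitting (as the old example did) can look more favorable than scoring on held-out data. The variable names `held_out_errors` and `train_errors` are illustrative only and do not appear in the diff.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Same data and split as in the updated doctest.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0
)

# Fit on the training half only.
gnb = GaussianNB().fit(X_train, y_train)

# Errors on held-out data (what the new doctest reports).
n_test = X_test.shape[0]
held_out_errors = int((y_test != gnb.predict(X_test)).sum())

# Errors on the data the model was fitted on (illustrative addition,
# analogous to what the old doctest measured on the full dataset).
train_errors = int((y_train != gnb.predict(X_train)).sum())

print("Number of mislabeled points out of a total %d points : %d"
      % (n_test, held_out_errors))
print("Mislabeled training points out of %d : %d"
      % (X_train.shape[0], train_errors))
```

The doctest in the diff reports 4 mislabeled points out of 75 for this split; exact counts may depend on the installed scikit-learn version.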
