8000 [MRG] Updating StratifiedKFold user_guide example (#14809) · jeremiedbb/scikit-learn@ee5bc2c · GitHub
[go: up one dir, main page]

10000
Skip to content

Commit ee5bc2c

Browse files
venkyyuvyqinhanmin2014
authored andcommitted
[MRG] Updating StratifiedKFold user_guide example (scikit-learn#14809)
1 parent cad5a92 commit ee5bc2c

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

doc/modules/cross_validation.rst

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -534,19 +534,30 @@ Stratified k-fold
534534
folds: each set contains approximately the same percentage of samples of each
535535
target class as the complete set.
536536

537-
Example of stratified 3-fold cross-validation on a dataset with 10 samples from
538-
two slightly unbalanced classes::
537+
Here is an example of stratified 3-fold cross-validation on a dataset with 50 samples from
538+
two unbalanced classes. We show the number of samples in each class and compare with
539+
:class:`KFold`.
539540

540-
>>> from sklearn.model_selection import StratifiedKFold
541-
542-
>>> X = np.ones(10)
543-
>>> y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
544-
>>> skf = StratifiedKFold(n_splits=3)
545-
>>> for train, test in skf.split(X, y):
546-
... print("%s %s" % (train, test))
547-
[2 3 6 7 8 9] [0 1 4 5]
548-
[0 1 3 4 5 8 9] [2 6 7]
549-
[0 1 2 4 5 6 7] [3 8 9]
541+
>>> from sklearn.model_selection import StratifiedKFold, KFold
542+
>>> import numpy as np
543+
>>> X, y = np.ones((50, 1)), np.hstack(([0] * 45, [1] * 5))
544+
>>> skf = StratifiedKFold(n_splits=3)
545+
>>> for train, test in skf.split(X, y):
546+
... print('train - {} | test - {}'.format(
547+
... np.bincount(y[train]), np.bincount(y[test])))
548+
train - [30 3] | test - [15 2]
549+
train - [30 3] | test - [15 2]
550+
train - [30 4] | test - [15 1]
551+
>>> kf = KFold(n_splits=3)
552+
>>> for train, test in kf.split(X, y):
553+
... print('train - {} | test - {}'.format(
554+
... np.bincount(y[train]), np.bincount(y[test])))
555+
train - [28 5] | test - [17]
556+
train - [28 5] | test - [17]
557+
train - [34] | test - [11 5]
558+
559+
We can see that :class:`StratifiedKFold` preserves the class ratios
560+
(approximately 1 / 10) in both train and test dataset.
550561

551562
Here is a visualization of the cross-validation behavior.
552563

0 commit comments

Comments
 (0)
0