8000 Update the test for GaussianMixture and the changelog entry · scikit-learn/scikit-learn@387f935 · GitHub
[go: up one dir, main page]

Skip to content

Commit 387f935

Browse files
committed
Update the test for GaussianMixture and the changelog entry
1 parent 1d80d76 commit 387f935

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

doc/whats_new/v0.20.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ Changelog
7474
:mod:`sklearn.mixture`
7575
........................
7676

77-
- |Fix| :func:`mixture.BayesianGaussianMixture` ensure that ``fit_predict``
77+
- |Fix| Ensure that the ``fit_predict`` method of
78+
:class:`mixture.GaussianMixture` and :class:`mixture.BayesianGaussianMixture`
7879
always yield assignments consistent with ``fit`` followed by ``predict`` even
7980
8000 if the convergence criterion is too loose or not met. :issue:`12451`
8081
by :user:`Olivier Grisel <ogrisel>`.

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010

1111
from scipy import stats, linalg
12+
import pytest
1213

1314
from sklearn.covariance import EmpiricalCovariance
1415
from sklearn.datasets.samples_generator import make_spd_matrix
@@ -571,8 +572,15 @@ def test_gaussian_mixture_predict_predict_proba():
571572
assert_greater(adjusted_rand_score(Y, Y_pred), .95)
572573

573574

574-
def test_gaussian_mixture_fit_predict():
575-
rng = np.random.RandomState(0)
575+
@pytest.mark.filterwarnings("ignore:.*did not converge.*")
576+
@pytest.mark.parametrize('seed, max_iter, ABE6 tol', [
577+
(0, 2, 1e-7), # strict non-convergence
578+
(1, 2, 1e-1), # loose non-convergence
579+
(3, 300, 1e-7), # strict convergence
580+
(4, 300, 1e-1), # loose convergence
581+
])
582+
def test_gaussian_mixture_fit_predict(seed, max_iter, tol):
583+
rng = np.random.RandomState(seed)
576584
rand_data = RandomData(rng)
577585
for covar_type in COVARIANCE_TYPE:
578586
X = rand_data.X[covar_type]
@@ -581,7 +589,8 @@ def test_gaussian_mixture_fit_predict():
581589
random_state=rng, weights_init=rand_data.weights,
582590
means_init=rand_data.means,
583591
precisions_init=rand_data.precisions[covar_type],
584-
covariance_type=covar_type)
592+
covariance_type=covar_type,
593+
max_iter=max_iter, tol=tol)
585594

586595
# check if fit_predict(X) is equivalent to fit(X).predict(X)
587596
f = copy.deepcopy(g)

0 commit comments

Comments
 (0)
0