8000 Revert changes to test_bayesian_mixture.py · lesteve/scikit-learn@ce214a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit ce214a6

Browse files
committed
Revert changes to test_bayesian_mixture.py
1 parent 4fe3766 commit ce214a6

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

sklearn/mixture/tests/test_bayesian_mixture.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from sklearn.mixture import BayesianGaussianMixture
1313
from sklearn.mixture._bayesian_mixture import _log_dirichlet_norm, _log_wishart_norm
1414
from sklearn.mixture.tests.test_gaussian_mixture import RandomData
15-
from sklearn.utils._array_api import get_namespace
1615
from sklearn.utils._testing import (
1716
assert_almost_equal,
1817
assert_array_equal,
@@ -260,7 +259,6 @@ def test_compare_covar_type():
260259
rand_data = RandomData(rng, scale=7)
261260
X = rand_data.X["full"]
262261
n_components = rand_data.n_components
263-
xp, _ = get_namespace(X)
264262

265263
for prior_type in PRIOR_TYPE:
266264
# Computation of the full_covariance
@@ -273,7 +271,7 @@ def test_compare_covar_type():
273271
tol=1e-7,
274272
)
275273
bgmm._check_parameters(X)
276-
bgmm._initialize_parameters(X, np.random.RandomState(0), xp=xp)
274+
bgmm._initialize_parameters(X, np.random.RandomState(0))
277275
full_covariances = (
278276
bgmm.covariances_ * bgmm.degrees_of_freedom_[:, np.newaxis, np.newaxis]
279277
)
@@ -288,7 +286,7 @@ def test_compare_covar_type():
288286
tol=1e-7,
289287
)
290288
bgmm._check_parameters(X)
291-
bgmm._initialize_parameters(X, np.random.RandomState(0), xp=xp)
289+
bgmm._initialize_parameters(X, np.random.RandomState(0))
292290

293291
tied_covariance = bgmm.covariances_ * bgmm.degrees_of_freedom_
294292
assert_almost_equal(tied_covariance, np.mean(full_covariances, 0))
@@ -303,7 +301,7 @@ def test_compare_covar_type():
303301
tol=1e-7,
304302
)
305303
bgmm._check_parameters(X)
306-
bgmm._initialize_parameters(X, np.random.RandomState(0), xp=xp)
304+
bgmm._initialize_parameters(X, np.random.RandomState(0))
307305

308306
diag_covariances = bgmm.covariances_ * bgmm.degrees_of_freedom_[:, np.newaxis]
309307
assert_almost_equal(
@@ -320,7 +318,7 @@ def test_compare_covar_type():
320318
tol=1e-7,
321319
)
322320
bgmm._check_parameters(X)
323-
bgmm._initialize_parameters(X, np.random.RandomState(0), xp=xp)
321+
bgmm._initialize_parameters(X, np.random.RandomState(0))
324322

325323
spherical_covariances = bgmm.covariances_ * bgmm.degrees_of_freedom_
326324
assert_almost_equal(spherical_covariances, np.mean(diag_covariances, 1))

0 commit comments

Comments
 (0)
0