8000 use config_context · scikit-learn/scikit-learn@5820938 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5820938

Browse files
committed
use config_context
1 parent c86d9e3 commit 5820938

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

sklearn/metrics/tests/test_pairwise.py

Lines changed: 18 additions & 22 deletions
-
Y = rng.random_sample((1000, 10))
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111

12-
from sklearn import set_config, get_config
12+
from sklearn import config_context
1313

1414
from sklearn.utils.testing import assert_greater
1515
from sklearn.utils.testing import assert_array_almost_equal
@@ -907,30 +907,26 @@ def test_pairwise_distances_data_derived_params(n_jobs, metric, dist_function,
907907
y_is_x):
908908
# check that pairwise_distances give the same result in sequential and
909909
# parallel, when metric has data-derived parameters.
910-
wm = get_config()['working_memory']
911-
set_config(working_memory=0.1) # to have more than 1 chunk
910+
with config_context(working_memory=0.1): # to have more than 1 chunk
911+
rng = np.random.RandomState(0)
912912

913-
rng = np.random.RandomState(0)
914-
915-
X = rng.random_sample((1000, 10))
913+
X = rng.random_sample((1000, 10))
916914

917-
if y_is_x:
918-
Y = X
919-
if metric == "seuclidean":
920-
params = {'V': np.var(X, axis=0, ddof=1)}
921-
else:
922-
params = {'VI': np.linalg.inv(np.cov(X.T)).T}
923-
else:
924
925-
if metric == "seuclidean":
926-
params = {'V': np.var(np.vstack([X, Y]), axis=0, ddof=1)}
915+
if y_is_x:
916+
Y = X
917+
if metric == "seuclidean":
918+
params = {'V': np.var(X, axis=0, ddof=1)}
919+
else:
920+
params = {'VI': np.linalg.inv(np.cov(X.T)).T}
927921
else:
928-
params = {'VI': np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T}
929-
930-
expected_dist = cdist(X, Y, metric=metric, **params)
922+
Y = rng.random_sample((1000, 10))
923+
if metric == "seuclidean":
924+
params = {'V': np.var(np.vstack([X, Y]), axis=0, ddof=1)}
925+
else:
926+
params = {'VI': np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T}
931927

932-
dist = np.vstack(dist_function(X, Y, metric=metric, n_jobs=n_jobs))
928+
expected_dist = cdist(X, Y, metric=metric, **params)
933929

934-
assert_allclose(dist, expected_dist)
930+
dist = np.vstack(dist_function(X, Y, metric=metric, n_jobs=n_jobs))
935931

936-
set_config(working_memory=wm) # reset working memory to initial value
932+
assert_allclose(dist, expected_dist)

0 commit comments

Comments
 (0)
0