8000 also test against pdist/cdist default params · scikit-learn/scikit-learn@0eea3f1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0eea3f1

Browse files
committed
also test against pdist/cdist default params
1 parent 421c0ed commit 0eea3f1

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

sklearn/metrics/tests/test_pairwise.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from scipy.sparse import dok_matrix, csr_matrix, issparse
77
from scipy.spatial.distance import cosine, cityblock, minkowski, wminkowski
8-
from scipy.spatial.distance import cdist
8+
from scipy.spatial.distance import cdist, pdist, squareform
99

1010
import pytest
1111

@@ -914,19 +914,22 @@ def test_pairwise_distances_data_derived_params(n_jobs, metric, dist_function,
914914

915915
if y_is_x:
916916
Y = X
917+
expected_dist_default_params = squareform(pdist(X, metric=metric))
917918
if metric == "seuclidean":
918919
params = {'V': np.var(X, axis=0, ddof=1)}
919920
else:
920921
params = {'VI': np.linalg.inv(np.cov(X.T)).T}
921922
else:
922923
Y = rng.random_sample((1000, 10))
924+
expected_dist_default_params = cdist(X, Y, metric=metric)
923925
if metric == "seuclidean":
924926
params = {'V': np.var(np.vstack([X, Y]), axis=0, ddof=1)}
925927
else:
926928
params = {'VI': np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T}
927929

928-
expected_dist = cdist(X, Y, metric=metric, **params)
930+
expected_dist_explicit_params = cdist(X, Y, metric=metric, **params)
929931

930932
dist = np.vstack(dist_function(X, Y, metric=metric, n_jobs=n_jobs))
931933

932-
assert_allclose(dist, expected_dist)
934+
assert_allclose(dist, expected_dist_explicit_params)
935+
assert_allclose(dist, expected_dist_default_params)

0 commit comments

Comments
 (0)
0