8000 API pairwise_distances will require explicit V/VI param if Y is given… · thomasjpfan/scikit-learn@5b2c931 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5b2c931

Browse files
API pairwise_distances will require explicit V/VI param if Y is given (scikit-learn#16993)
* API pairwise_distances will require explicit V/VI param if Y is given Deprecation until version 0.25. The current approach in `_precompute_metric_params` (https://github.com/scikit-learn/scikit-learn/blob/f82a2cb33871a67b36150647ece1c7e56d3132bb/sklearn/metrics/pairwise.py#L1429-L1444) means that we may be applying a different metric at training and test time. Ideally we'd have a framework for fitting a metric on some specific training data, but in the meantime, this deprecation stops users making mistakes. * DOC update what's new * Update sklearn/metrics/tests/test_pairwise.py Co-Authored-By: Thomas J Fan <thomasjpfan@gmail.com> * Update sklearn/metrics/pairwise.py Co-Authored-By: Thomas J Fan <thomasjpfan@gmail.com> * Update sklearn/metrics/pairwise.py Co-Authored-By: Thomas J Fan <thomasjpfan@gmail.com> * Update sklearn/metrics/tests/test_pairwise.py Co-Authored-By: Thomas J Fan <thomasjpfan@gmail.com> Co-authored-by: Thomas J Fan <thomasjpfan@gmail.com>
1 parent 1ba0651 commit 5b2c931

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

doc/whats_new/v0.23.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,12 @@ Changelog
397397
or 'd'). :pr:`16159` by :user:`Rick Mackenbach <Rick-Mackenbach>` and
398398
`Thomas Fan`_.
399399

400+
- |API| From version 0.25, :func:`metrics.pairwise.pairwise_distances` will no
401+
longer automatically compute the ``VI`` parameter for Mahalanobis distance
402+
and the ``V`` parameter for seuclidean distance if ``Y`` is passed. The user
403+
will be expected to compute this parameter on the training data of their
404+
choice and pass it to `pairwise_distances`. :pr:`16993` by `Joel Nothman`_.
405+
400406
:mod:`sklearn.model_selection`
401407
..............................
402408

sklearn/metrics/pairwise.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,12 +1441,18 @@ def _precompute_metric_params(X, Y, metric=None, **kwds):
14411441
if X is Y:
14421442
V = np.var(X, axis=0, ddof=1)
14431443
else:
1444+
warnings.warn("from version 0.25, pairwise_distances for "
1445+
"metric='seuclidean' will require V to be "
1446+
"specified if Y is passed.", FutureWarning)
14441447
V = np.var(np.vstack([X, Y]), axis=0, ddof=1)
14451448
return {'V': V}
14461449
if metric == "mahalanobis" and 'VI' not in kwds:
14471450
if X is Y:
14481451
VI = np.linalg.inv(np.cov(X.T)).T
14491452
else:
1453+
warnings.warn("from version 0.25, pairwise_distances for "
1454+
"metric='mahalanobis' will require VI to be "
1455+
"specified if Y is passed.", FutureWarning)
14501456
VI = np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T
14511457
return {'VI': VI}
14521458
return {}

sklearn/metrics/tests/test_pairwise.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,8 +1281,16 @@ def test_pairwise_distances_data_derived_params(n_jobs, metric, dist_function,
12811281
params = {'VI': np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T}
12821282

12831283
expected_dist_explicit_params = cdist(X, Y, metric=metric, **params)
1284-
dist = np.vstack(tuple(dist_function(X, Y,
1285-
metric=metric, n_jobs=n_jobs)))
1284+
# TODO: Remove warn_checker in 0.25
1285+
if y_is_x:
1286+
warn_checker = pytest.warns(None)
1287+
else:
1288+
warn_checker = pytest.warns(FutureWarning,
1289+
match="to be specified if Y is passed")
1290+
with warn_checker:
1291+
dist = np.vstack(tuple(dist_function(X, Y,
1292+
metric=metric,
1293+
n_jobs=n_jobs)))
12861294

12871295
assert_allclose(dist, expected_dist_explicit_params)
12881296
assert_allclose(dist, expected_dist_default_params)

0 commit comments

Comments
 (0)
0