8000 FIX euclidean_distances float32 numerical instabilities (#13554) · marcelobeckmann/scikit-learn@8cc70af · GitHub
[go: up one dir, main page]

Skip to content

Commit 8cc70af

Browse files
jeremiedbbmarcelobeckmann
authored andcommitted
FIX euclidean_distances float32 numerical instabilities (scikit-learn#13554)
1 parent fb21d0f commit 8cc70af

File tree

4 files changed

+203
-33
lines changed

4 files changed

+203
-33
lines changed

doc/whats_new/v0.21.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,9 +543,14 @@ Support for Python 3.4 and below has been officially dropped.
543543
:pr:`13447` by :user:`Dan Ellis <dpwe>`.
544544

545545
- |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated
546-
in version 0.21 and will be removed in version 0.23.
547-
:pr:`10580` by :user:`Reshama Shaikh <reshamas>` and :user:`Sandra
548-
Mitrovic <SandraMNE>`.
546+
in version 0.21 and will be removed in version 0.23. :pr:`10580` by
547+
:user:`Reshama Shaikh <reshamas>` and :user:`Sandra Mitrovic <SandraMNE>`.
548+
549+
- |Fix| The function :func:`euclidean_distances`, and therefore
550+
several estimators with ``metric='euclidean'``, suffered from numerical
551+
precision issues with ``float32`` features. Precision has been increased at the
552+
cost of a small drop of performance. :pr:`13554` by :user:`Celelibi` and
553+
:user:`Jérémie du Boisberranger <jeremiedbb>`.
549554

550555
- |API| :func:`metrics.jaccard_similarity_score` is deprecated in favour of
551556
the more consistent :func:`metrics.jaccard_score`. The former behavior for

sklearn/metrics/pairwise.py

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,17 +211,24 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
211211
Y_norm_squared : array-like, shape (n_samples_2, ), optional
212212
Pre-computed dot-products of vectors in Y (e.g.,
213213
``(Y**2).sum(axis=1)``)
214+
May be ignored in some cases, see the note below.
214215
215216
squared : boolean, optional
216217
Return squared Euclidean distances.
217218
218219
X_norm_squared : array-like, shape = [n_samples_1], optional
219220
Pre-computed dot-products of vectors in X (e.g.,
220221
``(X**2).sum(axis=1)``)
222+
May be ignored in some cases, see the note below.
223+
224+
Notes
225+
-----
226+
To achieve better accuracy, `X_norm_squared` and `Y_norm_squared` may be
227+
unused if they are passed as ``float32``.
221228
222229
Returns
223230
-------
224-
distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
231+
distances : array, shape (n_samples_1, n_samples_2)
225232
226233
Examples
227234
--------
@@ -242,41 +249,125 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
242249
"""
243250
X, Y = check_pairwise_arrays(X, Y)
244251

252+
# If norms are passed as float32, they are unused. If arrays are passed as
253+
# float32, norms needs to be recomputed on upcast chunks.
254+
# TODO: use a float64 accumulator in row_norms to avoid the latter.
245255
if X_norm_squared is not None:
246256
XX = check_array(X_norm_squared)
247257
if XX.shape == (1, X.shape[0]):
248258
XX = XX.T
249259
elif XX.shape != (X.shape[0], 1):
250260
raise ValueError(
251261
"Incompatible dimensions for X and X_norm_squared")
262+
if XX.dtype == np.float32:
263+
XX = None
264+
elif X.dtype == np.float32:
265+
XX = None
252266
else:
253267
XX = row_norms(X, squared=True)[:, np.newaxis]
254268

255-
if X is Y: # shortcut in the common case euclidean_distances(X, X)
269+
if X is Y and XX is not None:
270+
# shortcut in the common case euclidean_distances(X, X)
256271
YY = XX.T
257272
elif Y_norm_squared is not None:
258273
YY = np.atleast_2d(Y_norm_squared)
259274

260275
if YY.shape != (1, Y.shape[0]):
261276
raise ValueError(
262277
"Incompatible dimensions for Y and Y_norm_squared")
278+
if YY.dtype == np.float32:
279+
YY = None
280+
elif Y.dtype == np.float32:
281+
YY = None
263282
else:
264283
YY = row_norms(Y, squared=True)[np.newaxis, :]
265284

266-
distances = safe_sparse_dot(X, Y.T, dense_output=True)
267-
distances *= -2
268-
distances += XX
269-
distances += YY
285+
if X.dtype == np.float32:
286+
# To minimize precision issues with float32, we compute the distance
287+
# matrix on chunks of X and Y upcast to float64
288+
distances = _euclidean_distances_upcast(X, XX, Y, YY)
289+
else:
290+
# if dtype is already float64, no need to chunk and upcast
291+
distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True)
292+
distances += XX
293+
distances += YY
270294
np.maximum(distances, 0, out=distances)
271295

296+
# Ensure that distances between vectors and themselves are set to 0.0.
297+
# This may not be the case due to floating point rounding errors.
272298
if X is Y:
273-
# Ensure that distances between vectors and themselves are set to 0.0.
274-
# This may not be the case due to floating point rounding errors.
275-
distances.flat[::distances.shape[0] + 1] = 0.0
299+
np.fill_diagonal(distances, 0)
276300

277301
return distances if squared else np.sqrt(distances, out=distances)
278302

279303

304+
def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None):
305+
"""Euclidean distances between X and Y
306+
307+
Assumes X and Y have float32 dtype.
308+
Assumes XX and YY have float64 dtype or are None.
309+
310+
X and Y are upcast to float64 by chunks, which size is chosen to limit
311+
memory increase by approximately 10% (at least 10MiB).
312+
"""
313+
n_samples_X = X.shape[0]
314+
n_samples_Y = Y.shape[0]
315+
n_features = X.shape[1]
316+
317+
distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)
318+
319+
x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
320+
y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1
321+
322+
# Allow 10% more memory than X, Y and the distance matrix take (at least
323+
# 10MiB)
324+
maxmem = max(
325+
((x_density * n_samples_X + y_density * n_samples_Y) * n_features
326+
+ (x_density * n_samples_X * y_density * n_samples_Y)) / 10,
327+
10 * 2**17)
328+
329+
# The increase amount of memory in 8-byte blocks is:
330+
# - x_density * batch_size * n_features (copy of chunk of X)
331+
# - y_density * batch_size * n_features (copy of chunk of Y)
332+
# - batch_size * batch_size (chunk of distance matrix)
333+
# Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
334+
# xd=x_density and yd=y_density
335+
tmp = (x_density + y_density) * n_features
336+
batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
337+
batch_size = max(int(batch_size), 1)
338+
339+
x_batches = gen_batches(X.shape[0], batch_size)
340+
y_batches = gen_batches(Y.shape[0], batch_size)
341+
342+
for i, x_slice in enumerate(x_batches):
343+
X_chunk = X[x_slice].astype(np.float64)
344+
if XX is None:
345+
XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
346+
else:
347+
XX_chunk = XX[x_slice]
348+
349+
for j, y_slice in enumerate(y_batches):
350+
if X is Y and j < i:
351+
# when X is Y the distance matrix is symmetric so we only need
352+
# to compute half of it.
353+
d = distances[y_slice, x_slice].T
354+
355+
else:
356+
Y_chunk = Y[y_slice].astype(np.float64)
357+
if YY is None:
358+
YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
359+
else:
360+
YY_chunk = YY[:, y_slice]
361+
362+
d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
363+
d += XX_chunk
364+
d += YY_chunk
365+
366+
distances[x_slice, y_slice] = d.astype(np.float32, copy=False)
367+
368+
return distances
369+
370+
280371
def _argmin_min_reduce(dist, start):
281372
indices = dist.argmin(axis=1)
282373
values = dist[np.arange(dist.shape[0]), indices]

sklearn/metrics/pairwise_fast.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
#
88
# License: BSD 3 clause
99

10-
from libc.string cimport memset
1110
import numpy as np
1211
cimport numpy as np
1312
from cython cimport floating
13+
from libc.string cimport memset
1414

1515
from ..utils._cython_blas cimport _asum
1616

sklearn/metrics/tests/test_pairwise.py

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -585,41 +585,115 @@ def test_pairwise_distances_chunked():
585585
assert_raises(StopIteration, next, gen)
586586

587587

588-
def test_euclidean_distances():
589-
# Check the pairwise Euclidean distances computation
590-
X = [[0]]
591-
Y = [[1], [2]]
588+
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
589+
ids=["dense", "sparse"])
590+
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
591+
ids=["dense", "sparse"])
592+
def test_euclidean_distances_known_result(x_array_constr, y_array_constr):
593+
# Check the pairwise Euclidean distances computation on known result
594+
X = x_array_constr([[0]])
595+
Y = y_array_constr([[1], [2]])
592596
D = euclidean_distances(X, Y)
593-
assert_array_almost_equal(D, [[1., 2.]])
597+
assert_allclose(D, [[1., 2.]])
594598

595-
X = csr_matrix(X)
596-
Y = csr_matrix(Y)
597-
D = euclidean_distances(X, Y)
598-
assert_array_almost_equal(D, [[1., 2.]])
599599

600+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
601+
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
602+
ids=["dense", "sparse"])
603+
def test_euclidean_distances_with_norms(dtype, y_array_constr):
604+
# check that we still get the right answers with {X,Y}_norm_squared
605+
# and that we get a wrong answer with wrong {X,Y}_norm_squared
600606
rng = np.random.RandomState(0)
601-
X = rng.random_sample((10, 4))
602-
Y = rng.random_sample((20, 4))
603-
X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1)
604-
Y_norm_sq = (Y ** 2).sum(axis=1).reshape(1, -1)
607+
X = rng.random_sample((10, 10)).astype(dtype, copy=False)
608+
Y = rng.random_sample((20, 10)).astype(dtype, copy=False)
609+
610+
# norms will only be used if their dtype is float64
611+
X_norm_sq = (X.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
612+
Y_norm_sq = (Y.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
613+
614+
Y = y_array_constr(Y)
605615

606-
# check that we still get the right answers with {X,Y}_norm_squared
607616
D1 = euclidean_distances(X, Y)
608617
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
609618
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
610619
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
611620
Y_norm_squared=Y_norm_sq)
612-
assert_array_almost_equal(D2, D1)
613-
assert_array_almost_equal(D3, D1)
614-
assert_array_almost_equal(D4, D1)
621+
assert_allclose(D2, D1)
622+
assert_allclose(D3, D1)
623+
assert_allclose(D4, D1)
615624

616625
# check we get the wrong answer with wrong {X,Y}_norm_squared
617-
X_norm_sq *= 0.5
618-
Y_norm_sq *= 0.5
619626
wrong_D = euclidean_distances(X, Y,
620627
X_norm_squared=np.zeros_like(X_norm_sq),
621628
Y_norm_squared=np.zeros_like(Y_norm_sq))
622-
assert_greater(np.max(np.abs(wrong_D - D1)), .01)
629+
with pytest.raises(AssertionError):
630+
assert_allclose(wrong_D, D1)
631+
632+
633+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
634+
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
635+
ids=["dense", "sparse"])
636+
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
637+
ids=["dense", "sparse"])
638+
def test_euclidean_distances(dtype, x_array_constr, y_array_constr):
639+
# check that euclidean distances gives same result as scipy cdist
640+
# when X and Y != X are provided
641+
rng = np.random.RandomState(0)
642+
X = rng.random_sample((100, 10)).astype(dtype, copy=False)
643+
X[X < 0.8] = 0
644+
Y = rng.random_sample((10, 10)).astype(dtype, copy=False)
645+
Y[Y < 0.8] = 0
646+
647+
expected = cdist(X, Y)
648+
649+
X = x_array_constr(X)
650+
Y = y_array_constr(Y)
651+
distances = euclidean_distances(X, Y)
652+
653+
# the default rtol=1e-7 is too close to the float32 precision
654+
# and fails due too rounding errors.
655+
assert_allclose(distances, expected, rtol=1e-6)
656+
assert distances.dtype == dtype
657+
658+
659+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
660+
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
661+
ids=["dense", "sparse"])
662+
def test_euclidean_distances_sym(dtype, x_array_constr):
663+
# check that euclidean distances gives same result as scipy pdist
664+
# when only X is provided
665+
rng = np.random.RandomState(0)
666+
X = rng.random_sample((100, 10)).astype(dtype, copy=False)
667+
X[X < 0.8] = 0
668+
669+
expected = squareform(pdist(X))
670+
671+
X = x_array_constr(X)
672+
distances = euclidean_distances(X)
673+
674+
# the default rtol=1e-7 is too close to the float32 precision
675+
# and fails due too rounding errors.
676+
assert_allclose(distances, expected, rtol=1e-6)
677+
assert distances.dtype == dtype
678+
679+
680+
@pytest.mark.parametrize(
681+
"dtype, eps, rtol",
682+
[(np.float32, 1e-4, 1e-5),
683+
pytest.param(
684+
np.float64, 1e-8, 0.99,
685+
marks=pytest.mark.xfail(reason='failing due to lack of precision'))])
686+
@pytest.mark.parametrize("dim", [1, 1000000])
687+
def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim):
688+
# check that euclidean distances is correct with float32 input thanks to
689+
# upcasting. On float64 there are still precision issues.
690+
X = np.array([[1.] * dim], dtype=dtype)
691+
Y = np.array([[1. + eps] * dim], dtype=dtype)
692+
693+
distances = euclidean_distances(X, Y)
694+
expected = cdist(X, Y)
695+
696+
assert_allclose(distances, expected, rtol=1e-5)
623697

624698

625699
def test_cosine_distances():

0 commit comments

Comments
 (0)
0