8000 ENH reduce memory consumption in nan_euclidean_distances (#15615) · adrinjalali/scikit-learn@c8b0dc3 · GitHub
[go: up one dir, main page]

Skip to content

Commit c8b0dc3

Browse files
jnothmanadrinjalali
authored andcommitted
ENH reduce memory consumption in nan_euclidean_distances (scikit-learn#15615)
1 parent 78f0be6 commit c8b0dc3

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

sklearn/metrics/pairwise.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,20 +406,23 @@ def nan_euclidean_distances(X, Y=None, squared=False,
406406
distances -= np.dot(XX, missing_Y.T)
407407
distances -= np.dot(missing_X, YY.T)
408408

409-
present_coords_cnt = np.dot(1 - missing_X, 1 - missing_Y.T)
410-
present_mask = (present_coords_cnt != 0)
411-
distances[present_mask] *= (X.shape[1] / present_coords_cnt[present_mask])
412-
413409
if X is Y:
414410
# Ensure that distances between vectors and themselves are set to 0.0.
415411
# This may not be the case due to floating point rounding errors.
416412
np.fill_diagonal(distances, 0.0)
417413

414+
present_X = 1 - missing_X
415+
present_Y = present_X if Y is X else ~missing_Y
416+
present_count = np.dot(present_X, present_Y.T)
417+
distances[present_count == 0] = np.nan
418+
# avoid divide by zero
419+
np.maximum(1, present_count, out=present_count)
420+
distances /= present_count
421+
distances *= X.shape[1]
422+
418423
if not squared:
419424
np.sqrt(distances, out=distances)
420425

421-
# coordinates with no common coordinates have a nan distance
422-
distances[~present_mask] = np.nan
423426
return distances
424427

425428

0 commit comments

Comments
 (0)
0