From 85ae9a43ca6888318fb8f0b86135c4fc76f5614e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 30 May 2022 10:42:19 +0200 Subject: [PATCH 01/17] TST Add test for quasi equa This comes before porting implementations toits. --- .../test_pairwise_distances_reduction.py | 129 ++++++++++++++---- 1 file changed, 106 insertions(+), 23 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 192f7ef43a6c6..cb698f89bcff2 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -66,7 +66,7 @@ def _get_metric_params_list(metric: str, n_features: int, seed: int = 1): return [{}] -def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices): +def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices, rtol=1e-7): assert_array_equal( ref_indices, indices, @@ -76,10 +76,69 @@ def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices): ref_dist, dist, err_msg="Query vectors have different neighbors' distances", - rtol=1e-7, + rtol=rtol, ) +def assert_argkmin_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol=1e-4 +): + + ref_dist, dist, ref_indices, indices = map( + np.ndarray.flatten, [ref_dist, dist, ref_indices, indices] + ) + + assert ( + len(ref_dist) == len(dist) == len(ref_indices) == len(indices) + ), "Arrays of results have various length." + + n = len(ref_dist) + + skip_permutation_check = False + + for i in range(n - 1): + # We test the equality of pair of adjacent indices and distances + # of the references against the results. + rd_prev, rd_current, rd_next = ref_dist[i - 1], ref_dist[i], ref_dist[i + 1] + d_prev, d_current, d_next = dist[i - 1], dist[i], dist[i + 1] + ri_prev, ri_current, ri_next = ( + ref_indices[i - 1], + ref_indices[i], + ref_indices[i + 1], + ) + i_prev, i_current, i_next = indices[i - 1], indices[i], indices[i + 1] + + assert np.isclose( + d_current, rd_current, rtol=rtol + ), "Query vectors have different neighbors' distances" + + if ri_current != i_current: + # If the current reference index and index are different, + # it might be that their were permuted because their distances + # are relatively close to each other. + # In this case, we need to check for a valid permutation. + valid_permutation = ( + np.isclose(d_current, d_next, rtol=rtol) + and i_next == ri_current + and ri_next == i_current + ) + assert skip_permutation_check or valid_permutation, ( + "Query vectors have different neighbors' indices \n" + f"(i_prev, i_current, i_next) = {i_prev, i_current, i_next} \n" + f"(ri_prev, ri_current, ri_next) = {ri_prev, ri_current, ri_next} \n" + f"(d_prev, d_current, d_next) = {d_prev, d_current, d_next} \n" + f"(rd_prev, rd_current, rd_next) = {rd_prev, rd_current, rd_next} \n" + ) + # If there's a permutation at this iteration, we need to + # skip the following permutation check. + skip_permutation_check = True + continue + + # We need to check for potential permutations for the next iterations. + if skip_permutation_check: + skip_permutation_check = False + + def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, indices): # We get arrays of arrays and we need to check for individual pairs for i in range(ref_dist.shape[0]): @@ -97,8 +156,20 @@ def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, ind ASSERT_RESULT = { - PairwiseDistancesArgKmin: assert_argkmin_results_equality, - PairwiseDistancesRadiusNeighborhood: assert_radius_neighborhood_results_equality, + # In the case of 64bit, we test for exact equality. + (PairwiseDistancesArgKmin, np.float64): assert_argkmin_results_equality, + ( + PairwiseDistancesRadiusNeighborhood, + np.float64, + ): assert_radius_neighborhood_results_equality, + # In the case of 32bit, indices can be permuted due to small difference + # in the computations of their associated distances, hence we test equality of + # results up to valid permutations. + (PairwiseDistancesArgKmin, np.float32): assert_argkmin_results_quasi_equality, + ( + PairwiseDistancesRadiusNeighborhood, + np.float32, + ): assert_radius_neighborhood_results_equality, } @@ -107,13 +178,15 @@ def test_pairwise_distances_reduction_is_usable_for(): X = rng.rand(100, 10) Y = rng.rand(100, 10) metric = "euclidean" - assert PairwiseDistancesReduction.is_usable_for(X, Y, metric) + + assert PairwiseDistancesReduction.is_usable_for( + X.astype(np.float64), X.astype(np.float64), metric + ) assert not PairwiseDistancesReduction.is_usable_for( X.astype(np.int64), Y.astype(np.int64), metric ) assert not PairwiseDistancesReduction.is_usable_for(X, Y, metric="pyfunc") - # TODO: remove once 32 bits datasets are supported assert not PairwiseDistancesReduction.is_usable_for(X.astype(np.float32), Y, metric) assert not PairwiseDistancesReduction.is_usable_for(X, Y.astype(np.int32), metric) @@ -171,7 +244,7 @@ def test_argkmin_factory_method_wrong_usages(): message = ( r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" r" case \(" - r"FastEuclideanPairwiseDistancesArgKmin\) and will be ignored." + r"FastEuclideanPairwiseDistancesArgKmin." ) with pytest.warns(UserWarning, match=message): @@ -187,23 +260,25 @@ def test_radius_neighborhood_factory_method_wrong_usages(): radius = 5 metric = "euclidean" + msg = ( + "Only 64bit float datasets are supported at this time, " + "got: X.dtype=float32 and Y.dtype=float64" + ) with pytest.raises( ValueError, - match=( - "Only 64bit float datasets are supported at this time, " - "got: X.dtype=float32 and Y.dtype=float64" - ), + match=msg, ): PairwiseDistancesRadiusNeighborhood.compute( X=X.astype(np.float32), Y=Y, radius=radius, metric=metric ) + msg = ( + "Only 64bit float datasets are supported at this time, " + "got: X.dtype=float64 and Y.dtype=int32" + ) with pytest.raises( ValueError, - match=( - "Only 64bit float datasets are supported at this time, " - "got: X.dtype=float64 and Y.dtype=int32" - ), + match=msg, ): PairwiseDistancesRadiusNeighborhood.compute( X=X, Y=Y.astype(np.int32), radius=radius, metric=metric @@ -233,8 +308,7 @@ def test_radius_neighborhood_factory_method_wrong_usages(): message = ( r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(FastEuclideanPairwiseDistancesRadiusNeighborhood\) and will be" - r" ignored." + r" case \(FastEuclideanPairwiseDistancesRadiusNeighborhood" ) with pytest.warns(UserWarning, match=message): @@ -274,6 +348,7 @@ def test_chunk_size_agnosticism( X, Y, parameter, + metric="manhattan", return_distance=True, ) @@ -282,10 +357,13 @@ def test_chunk_size_agnosticism( Y, parameter, chunk_size=chunk_size, + metric="manhattan", return_distance=True, ) - ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices) + ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( + ref_dist, dist, ref_indices, indices + ) @pytest.mark.parametrize("n_samples", [100, 1000]) @@ -327,7 +405,9 @@ def test_n_threads_agnosticism( X, Y, parameter, return_distance=True ) - ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices) + ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( + ref_dist, dist, ref_indices, indices + ) # TODO: Remove filterwarnings in 1.3 when wminkowski is removed @@ -394,7 +474,7 @@ def test_strategies_consistency( return_distance=True, ) - ASSERT_RESULT[PairwiseDistancesReduction]( + ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( dist_par_X, dist_par_Y, indices_par_X, @@ -459,8 +539,11 @@ def test_pairwise_distances_argkmin( strategy=strategy, ) - ASSERT_RESULT[PairwiseDistancesArgKmin]( - argkmin_distances, argkmin_distances_ref, argkmin_indices, argkmin_indices_ref + ASSERT_RESULT[(PairwiseDistancesArgKmin, dtype)]( + argkmin_distances, + argkmin_distances_ref, + argkmin_indices, + argkmin_indices_ref, ) @@ -526,7 +609,7 @@ def test_pairwise_distances_radius_neighbors( sort_results=True, ) - ASSERT_RESULT[PairwiseDistancesRadiusNeighborhood]( + ASSERT_RESULT[(PairwiseDistancesRadiusNeighborhood, dtype)]( neigh_distances, neigh_distances_ref, neigh_indices, neigh_indices_ref ) From 4200b994fc65e9eb3f44ff3672edb3298242024a Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 30 May 2022 12:16:23 +0200 Subject: [PATCH 02/17] TST Add a test for assert_argkmin_results_quasi_equality --- .../test_pairwise_distances_reduction.py | 123 +++++++++++------- 1 file changed, 74 insertions(+), 49 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index cb698f89bcff2..f2517be147357 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,3 +1,5 @@ +from collections import defaultdict + import numpy as np import pytest import threadpoolctl @@ -81,62 +83,48 @@ def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices, rtol=1 def assert_argkmin_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol=1e-4 + ref_dist, + dist, + ref_indices, + indices, + rtol=1e-4, + decimals=5, ): + """Assert that argkmin results are valid up to: + - relative tolerance on distances + - permutations of indices for absolutely close distances - ref_dist, dist, ref_indices, indices = map( - np.ndarray.flatten, [ref_dist, dist, ref_indices, indices] - ) + To be used for 32bits datasets tests. + """ assert ( - len(ref_dist) == len(dist) == len(ref_indices) == len(indices) - ), "Arrays of results have various length." + ref_dist.shape == dist.shape == ref_indices.shape == indices.shape + ), "Arrays of results have various shapes." - n = len(ref_dist) + n, k = ref_dist.shape - skip_permutation_check = False + # Asserting equality results one row at a time + for i in range(n): + ref_dist_row = ref_dist[i] + dist_row = dist[i] - for i in range(n - 1): - # We test the equality of pair of adjacent indices and distances - # of the references against the results. - rd_prev, rd_current, rd_next = ref_dist[i - 1], ref_dist[i], ref_dist[i + 1] - d_prev, d_current, d_next = dist[i - 1], dist[i], dist[i + 1] - ri_prev, ri_current, ri_next = ( - ref_indices[i - 1], - ref_indices[i], - ref_indices[i + 1], - ) - i_prev, i_current, i_next = indices[i - 1], indices[i], indices[i + 1] - - assert np.isclose( - d_current, rd_current, rtol=rtol - ), "Query vectors have different neighbors' distances" - - if ri_current != i_current: - # If the current reference index and index are different, - # it might be that their were permuted because their distances - # are relatively close to each other. - # In this case, we need to check for a valid permutation. - valid_permutation = ( - np.isclose(d_current, d_next, rtol=rtol) - and i_next == ri_current - and ri_next == i_current - ) - assert skip_permutation_check or valid_permutation, ( - "Query vectors have different neighbors' indices \n" - f"(i_prev, i_current, i_next) = {i_prev, i_current, i_next} \n" - f"(ri_prev, ri_current, ri_next) = {ri_prev, ri_current, ri_next} \n" - f"(d_prev, d_current, d_next) = {d_prev, d_current, d_next} \n" - f"(rd_prev, rd_current, rd_next) = {rd_prev, rd_current, rd_next} \n" - ) - # If there's a permutation at this iteration, we need to - # skip the following permutation check. - skip_permutation_check = True - continue - - # We need to check for potential permutations for the next iterations. - if skip_permutation_check: - skip_permutation_check = False + assert_allclose(ref_dist_row, dist_row, rtol) + + ref_indices_row = ref_indices[i] + indices_row = indices[i] + + # Grouping indices by distances using sets + ref_mapping = defaultdict(set) + mapping = defaultdict(set) + + for j in range(k): + rounded_dist = np.round(ref_dist_row[j], decimals=decimals) + ref_mapping[rounded_dist].add(ref_indices_row[j]) + mapping[rounded_dist].add(indices_row[j]) + + # Asserting equality of groups (sets) for each distance + for j in ref_mapping.keys(): + assert ref_mapping[j] == mapping[j] def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, indices): @@ -173,6 +161,43 @@ def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, ind } +def test_assert_argkmin_results_quasi_equality(): + + rtol = 1e-7 + atol = 1e-7 + decimals = 6 + + ref_dist = np.array( + [ + [1.2, 2.5, 6.1 + atol, 6.1, 6.1 - atol], + [1.0 + atol, 1.0 - atol, 1, 1.0 + atol, 1.0 - atol], + ] + ) + ref_indices = np.array( + [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10], + ] + ) + + assert_argkmin_results_quasi_equality( + ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals + ) + + dist = ref_dist + + indices = np.array( + [ + [1, 2, 4, 5, 3], + [6, 9, 7, 8, 10], + ] + ) + + assert_argkmin_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + def test_pairwise_distances_reduction_is_usable_for(): rng = np.random.RandomState(0) X = rng.rand(100, 10) From 651bcaed5296389f292067826815d7761781818e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 1 Jun 2022 10:56:08 +0200 Subject: [PATCH 03/17] TST Add failing assertion test Co-authored-by: Olivier Grisel --- .../tests/test_pairwise_distances_reduction.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index f2517be147357..404bd47311e45 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -186,6 +186,7 @@ def test_assert_argkmin_results_quasi_equality(): dist = ref_dist + # Apply valid permutation on indices indices = np.array( [ [1, 2, 4, 5, 3], @@ -197,6 +198,19 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist, dist, ref_indices, indices, rtol, decimals ) + # Apply invalid permutation on indices + indices = np.array( + [ + [2, 1, 3, 4, 5], + [6, 7, 8, 9, 10], + ] + ) + + msg = "Extra items in the left set" + with pytest.raises(AssertionError, match=msg): + assert_argkmin_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) def test_pairwise_distances_reduction_is_usable_for(): rng = np.random.RandomState(0) From f5937cc72a8ecedc2c8aab55de85d0cb861cd0b4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 1 Jun 2022 10:57:08 +0200 Subject: [PATCH 04/17] TST Add assert_radius_neighborhood_results_quasi_equality and test it --- .../test_pairwise_distances_reduction.py | 121 +++++++++++++++++- 1 file changed, 120 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 404bd47311e45..c3ebcdf1a79d0 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -143,6 +143,61 @@ def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, ind ) +def assert_radius_neighborhood_results_quasi_equality( + ref_dist, + dist, + ref_indices, + indices, + rtol=1e-4, + decimals=5, +): + """Assert that radius neighborhood results are valid up to: + - relative tolerance on distances + - permutations of indices for absolutely close distances + - missing last elements if the threshold + + To be used for 32bits datasets tests. + """ + + assert ( + len(ref_dist) == len(dist) == len(ref_indices) == len(indices) + ), "Arrays of results have various lengths." + + n = len(ref_dist) + + # Asserting equality of results one vector at a time + for i in range(n): + + ref_dist_row = ref_dist[i] + dist_row = dist[i] + + # Vectors' lengths might be different due to small + # numerical differences of distance w.r.t the `radius` threshold, + # so we group vectors to match the smallest. + m = min(len(ref_dist_row), len(dist_row)) + + ref_dist_row = ref_dist_row[:m] + dist_row = dist_row[:m] + + assert_allclose(ref_dist_row, dist_row, rtol) + + ref_indices_row = ref_indices[i] + indices_row = indices[i] + + # Grouping indices by distances using sets + ref_mapping = defaultdict(set) + mapping = defaultdict(set) + + for j in range(m): + rounded_dist = np.round(ref_dist_row[j], decimals=decimals) + ref_mapping[rounded_dist].add(ref_indices_row[j]) + mapping[rounded_dist].add(indices_row[j]) + + # Asserting equality of groups (sets) for each distance + for j in ref_mapping.keys(): + assert ref_mapping[j] == mapping[j] + + ASSERT_RESULT = { # In the case of 64bit, we test for exact equality. (PairwiseDistancesArgKmin, np.float64): assert_argkmin_results_equality, @@ -157,7 +212,7 @@ def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, ind ( PairwiseDistancesRadiusNeighborhood, np.float32, - ): assert_radius_neighborhood_results_equality, + ): assert_radius_neighborhood_results_quasi_equality, } @@ -212,6 +267,70 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist, dist, ref_indices, indices, rtol, decimals ) + +def test_assert_radius_neighborhood_results_quasi_equality(): + + rtol = 1e-7 + atol = 1e-7 + decimals = 6 + + ref_dist = np.array( + [ + np.array([1.2, 2.5, 6.1 + atol, 6.1, 6.1 - atol]), + np.array([1.0 + atol, 1.0 - atol, 1, 1.0 + atol]), + ] + ) + ref_indices = np.array( + [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9], + ] + ) + + assert_radius_neighborhood_results_quasi_equality( + ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals + ) + + dist = np.copy(ref_dist) + + # Apply valid permutation on indices + indices = np.array( + [ + [1, 2, 4, 5, 3], + [6, 9, 7, 8], + ] + ) + + assert_radius_neighborhood_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + # Apply invalid permutation on indices + indices = np.array( + [ + [2, 1, 3, 4, 5], + [6, 7, 8, 9], + ] + ) + + msg = "Extra items in the left set" + with pytest.raises(AssertionError, match=msg): + assert_radius_neighborhood_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + # Having missing last element is valid + dist = np.copy(ref_dist) + indices = np.copy(ref_indices) + + dist[0] = dist[0][:-1] + indices[0] = indices[0][:-1] + + assert_radius_neighborhood_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + def test_pairwise_distances_reduction_is_usable_for(): rng = np.random.RandomState(0) X = rng.rand(100, 10) From 6664bb56a5d6afff39f6a4fe901d459bbd7ea4e6 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 1 Jun 2022 11:40:33 +0200 Subject: [PATCH 05/17] TST Add more failing assertion tests cases Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 115 +++++++++++++++--- 1 file changed, 100 insertions(+), 15 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index c3ebcdf1a79d0..0a51ef815f39a 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,3 +1,4 @@ +import copy from collections import defaultdict import numpy as np @@ -96,6 +97,7 @@ def assert_argkmin_results_quasi_equality( To be used for 32bits datasets tests. """ + is_sorted = lambda a: np.all(a[:-1] - a[1:]) assert ( ref_dist.shape == dist.shape == ref_indices.shape == indices.shape @@ -108,6 +110,9 @@ def assert_argkmin_results_quasi_equality( ref_dist_row = ref_dist[i] dist_row = dist[i] + is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" + is_sorted(dist_row), f"Distances aren't sorted on row {i}" + assert_allclose(ref_dist_row, dist_row, rtol) ref_indices_row = ref_indices[i] @@ -157,7 +162,10 @@ def assert_radius_neighborhood_results_quasi_equality( - missing last elements if the threshold To be used for 32bits datasets tests. + + Input arrays must be sorted w.r.t distances. """ + is_sorted = lambda a: np.all(a[:-1] - a[1:]) assert ( len(ref_dist) == len(dist) == len(ref_indices) == len(indices) @@ -171,6 +179,9 @@ def assert_radius_neighborhood_results_quasi_equality( ref_dist_row = ref_dist[i] dist_row = dist[i] + is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" + is_sorted(dist_row), f"Distances aren't sorted on row {i}" + # Vectors' lengths might be different due to small # numerical differences of distance w.r.t the `radius` threshold, # so we group vectors to match the smallest. @@ -224,8 +235,8 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist = np.array( [ - [1.2, 2.5, 6.1 + atol, 6.1, 6.1 - atol], - [1.0 + atol, 1.0 - atol, 1, 1.0 + atol, 1.0 - atol], + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol], ] ) ref_indices = np.array( @@ -267,6 +278,43 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist, dist, ref_indices, indices, rtol, decimals ) + indices = np.copy(ref_indices) + dist = np.copy(ref_dist) + + # Indices aren't properly sorted w.r.t their distances + indices[0][0], indices[0][1] = indices[0][1], indices[0][0] + + msg = "Extra items in the left set" + with pytest.raises(AssertionError, match=msg): + assert_argkmin_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + indices = np.copy(ref_indices) + dist = np.copy(ref_dist) + + # Distances aren't properly sorted + dist[0][0], dist[0][1] = dist[0][1], dist[0][0] + + msg = "Mismatched elements: 2 / 5" + with pytest.raises(AssertionError, match=msg): + assert_argkmin_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + indices = np.copy(ref_indices) + dist = np.copy(ref_dist) + + # Indices and distances aren't properly sorted + indices[0][0], indices[0][1] = indices[0][1], indices[0][0] + dist[0][0], dist[0][1] = dist[0][1], dist[0][0] + + msg = "Mismatched elements: 2 / 5" + with pytest.raises(AssertionError, match=msg): + assert_argkmin_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + def test_assert_radius_neighborhood_results_quasi_equality(): @@ -276,14 +324,14 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ref_dist = np.array( [ - np.array([1.2, 2.5, 6.1 + atol, 6.1, 6.1 - atol]), - np.array([1.0 + atol, 1.0 - atol, 1, 1.0 + atol]), + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.0 - atol, 1, 1.0 + atol, 1.0 + atol]), ] ) ref_indices = np.array( [ - [1, 2, 3, 4, 5], - [6, 7, 8, 9], + np.array([1, 2, 3, 4, 5]), + np.array([6, 7, 8, 9]), ] ) @@ -291,13 +339,13 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals ) - dist = np.copy(ref_dist) + dist = copy.deepcopy(ref_dist) # Apply valid permutation on indices indices = np.array( [ - [1, 2, 4, 5, 3], - [6, 9, 7, 8], + np.array([1, 2, 4, 5, 3]), + np.array([6, 9, 7, 8]), ] ) @@ -308,8 +356,8 @@ def test_assert_radius_neighborhood_results_quasi_equality(): # Apply invalid permutation on indices indices = np.array( [ - [2, 1, 3, 4, 5], - [6, 7, 8, 9], + np.array([2, 1, 3, 4, 5]), + np.array([6, 7, 8, 9]), ] ) @@ -319,17 +367,54 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ref_dist, dist, ref_indices, indices, rtol, decimals ) - # Having missing last element is valid - dist = np.copy(ref_dist) - indices = np.copy(ref_indices) + # Having missing last elements is valid + indices = copy.deepcopy(ref_indices) + dist = copy.deepcopy(ref_dist) - dist[0] = dist[0][:-1] indices[0] = indices[0][:-1] + dist[0] = dist[0][:-1] assert_radius_neighborhood_results_quasi_equality( ref_dist, dist, ref_indices, indices, rtol, decimals ) + indices = copy.deepcopy(ref_indices) + dist = copy.deepcopy(ref_dist) + + # Indices aren't properly sorted w.r.t their distances + indices[0][0], indices[0][1] = indices[0][1], indices[0][0] + + msg = "Extra items in the left set" + with pytest.raises(AssertionError, match=msg): + assert_radius_neighborhood_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + indices = copy.deepcopy(ref_indices) + dist = copy.deepcopy(ref_dist) + + # Distances aren't properly sorted + dist[0][0], dist[0][1] = dist[0][1], dist[0][0] + + msg = "Mismatched elements: 2 / 5" + with pytest.raises(AssertionError, match=msg): + assert_radius_neighborhood_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + + indices = copy.deepcopy(ref_indices) + dist = copy.deepcopy(ref_dist) + + # Indices and distances aren't properly sorted + indices[0][0], indices[0][1] = indices[0][1], indices[0][0] + dist[0][0], dist[0][1] = dist[0][1], dist[0][0] + + msg = "Mismatched elements: 2 / 5" + with pytest.raises(AssertionError, match=msg): + assert_radius_neighborhood_results_quasi_equality( + ref_dist, dist, ref_indices, indices, rtol, decimals + ) + def test_pairwise_distances_reduction_is_usable_for(): rng = np.random.RandomState(0) From e5ca94c3e0db459baa8e7f7eb9902c109197d7ce Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 1 Jun 2022 11:43:06 +0200 Subject: [PATCH 06/17] TST Correct is_sorted --- sklearn/metrics/tests/test_pairwise_distances_reduction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 0a51ef815f39a..b2608113f46a3 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -97,7 +97,7 @@ def assert_argkmin_results_quasi_equality( To be used for 32bits datasets tests. """ - is_sorted = lambda a: np.all(a[:-1] - a[1:]) + is_sorted = lambda a: np.all(a[:-1] - a[1:] <= 0) assert ( ref_dist.shape == dist.shape == ref_indices.shape == indices.shape @@ -165,7 +165,7 @@ def assert_radius_neighborhood_results_quasi_equality( Input arrays must be sorted w.r.t distances. """ - is_sorted = lambda a: np.all(a[:-1] - a[1:]) + is_sorted = lambda a: np.all(a[:-1] - a[1:] <= 0) assert ( len(ref_dist) == len(dist) == len(ref_indices) == len(indices) From 1ff1b73fa79b8bb1c8f22078b8b598fa271ffae6 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 1 Jun 2022 11:51:52 +0200 Subject: [PATCH 07/17] TST Correct assertions on sort --- .../tests/test_pairwise_distances_reduction.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index b2608113f46a3..aba2e21fc1ada 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -110,8 +110,8 @@ def assert_argkmin_results_quasi_equality( ref_dist_row = ref_dist[i] dist_row = dist[i] - is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" - is_sorted(dist_row), f"Distances aren't sorted on row {i}" + assert is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" + assert is_sorted(dist_row), f"Distances aren't sorted on row {i}" assert_allclose(ref_dist_row, dist_row, rtol) @@ -179,8 +179,8 @@ def assert_radius_neighborhood_results_quasi_equality( ref_dist_row = ref_dist[i] dist_row = dist[i] - is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" - is_sorted(dist_row), f"Distances aren't sorted on row {i}" + assert is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" + assert is_sorted(dist_row), f"Distances aren't sorted on row {i}" # Vectors' lengths might be different due to small # numerical differences of distance w.r.t the `radius` threshold, @@ -250,7 +250,7 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals ) - dist = ref_dist + dist = np.copy(ref_dist) # Apply valid permutation on indices indices = np.array( @@ -296,7 +296,7 @@ def test_assert_argkmin_results_quasi_equality(): # Distances aren't properly sorted dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Mismatched elements: 2 / 5" + msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( ref_dist, dist, ref_indices, indices, rtol, decimals @@ -309,7 +309,7 @@ def test_assert_argkmin_results_quasi_equality(): indices[0][0], indices[0][1] = indices[0][1], indices[0][0] dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Mismatched elements: 2 / 5" + msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( ref_dist, dist, ref_indices, indices, rtol, decimals @@ -396,7 +396,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): # Distances aren't properly sorted dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Mismatched elements: 2 / 5" + msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( ref_dist, dist, ref_indices, indices, rtol, decimals @@ -409,7 +409,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): indices[0][0], indices[0][1] = indices[0][1], indices[0][0] dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Mismatched elements: 2 / 5" + msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( ref_dist, dist, ref_indices, indices, rtol, decimals From 57ae169ddd6902abbac628cd4bb705263342510a Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 7 Jun 2022 16:54:43 +0200 Subject: [PATCH 08/17] TST Inline values in assertion and complete comments Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 374 ++++++++++++++---- 1 file changed, 292 insertions(+), 82 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index aba2e21fc1ada..09898fe0c9ce7 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,4 +1,3 @@ -import copy from collections import defaultdict import numpy as np @@ -118,7 +117,8 @@ def assert_argkmin_results_quasi_equality( ref_indices_row = ref_indices[i] indices_row = indices[i] - # Grouping indices by distances using sets + # Grouping indices by distances using sets on + # a rounded distances up to a given number of decimals ref_mapping = defaultdict(set) mapping = defaultdict(set) @@ -195,7 +195,8 @@ def assert_radius_neighborhood_results_quasi_equality( ref_indices_row = ref_indices[i] indices_row = indices[i] - # Grouping indices by distances using sets + # Grouping indices by distances using sets on + # a rounded distances up to a given number of decimals ref_mapping = defaultdict(set) mapping = defaultdict(set) @@ -250,69 +251,166 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals ) - dist = np.copy(ref_dist) - # Apply valid permutation on indices - indices = np.array( - [ - [1, 2, 4, 5, 3], - [6, 9, 7, 8, 10], - ] + assert_argkmin_results_quasi_equality( + ref_dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + ref_indices=np.array( + [ + [1, 2, 3, 4, 5], + ] + ), + indices=np.array( + [ + [1, 2, 4, 5, 3], + ] + ), + rtol=rtol, + decimals=decimals, ) - assert_argkmin_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + [1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol], + ] + ), + dist=np.array( + [ + [1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol], + ] + ), + ref_indices=np.array( + [ + [6, 7, 8, 9, 10], + ] + ), + indices=np.array( + [ + [6, 9, 7, 8, 10], + ] + ), + rtol=rtol, + decimals=decimals, ) # Apply invalid permutation on indices - indices = np.array( - [ - [2, 1, 3, 4, 5], - [6, 7, 8, 9, 10], - ] - ) - msg = "Extra items in the left set" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + ref_indices=np.array( + [ + [1, 2, 3, 4, 5], + ] + ), + indices=np.array( + [ + [2, 1, 3, 4, 5], + ] + ), + rtol=rtol, + decimals=decimals, ) - indices = np.copy(ref_indices) - dist = np.copy(ref_dist) - # Indices aren't properly sorted w.r.t their distances - indices[0][0], indices[0][1] = indices[0][1], indices[0][0] - msg = "Extra items in the left set" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + ref_indices=np.array( + [ + [1, 2, 3, 4, 5], + ] + ), + indices=np.array( + [ + [2, 1, 4, 5, 3], + ] + ), + rtol=rtol, + decimals=decimals, ) - indices = np.copy(ref_indices) - dist = np.copy(ref_dist) - # Distances aren't properly sorted - dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + dist=np.array( + [ + [2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + ref_indices=np.array( + [ + [1, 2, 3, 4, 5], + ] + ), + indices=np.array( + [ + [1, 2, 4, 5, 3], + ] + ), + rtol=rtol, + decimals=decimals, ) - indices = np.copy(ref_indices) - dist = np.copy(ref_dist) - # Indices and distances aren't properly sorted - indices[0][0], indices[0][1] = indices[0][1], indices[0][0] - dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + dist=np.array( + [ + [2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol], + ] + ), + ref_indices=np.array( + [ + [1, 2, 3, 4, 5], + ] + ), + indices=np.array( + [ + [2, 1, 4, 5, 3], + ] + ), + rtol=rtol, + decimals=decimals, ) @@ -339,80 +437,192 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals ) - dist = copy.deepcopy(ref_dist) - # Apply valid permutation on indices - indices = np.array( - [ - np.array([1, 2, 4, 5, 3]), - np.array([6, 9, 7, 8]), - ] + assert_radius_neighborhood_results_quasi_equality( + ref_dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + ref_indices=np.array( + [ + np.array([1, 2, 3, 4, 5]), + ] + ), + indices=np.array( + [ + np.array([1, 2, 4, 5, 3]), + ] + ), + rtol=rtol, + decimals=decimals, ) - assert_radius_neighborhood_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + np.array([1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol]), + ] + ), + dist=np.array( + [ + np.array([1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol]), + ] + ), + ref_indices=np.array( + [ + np.array([6, 7, 8, 9, 10]), + ] + ), + indices=np.array( + [ + np.array([6, 9, 7, 8, 10]), + ] + ), + rtol=rtol, + decimals=decimals, ) # Apply invalid permutation on indices - indices = np.array( - [ - np.array([2, 1, 3, 4, 5]), - np.array([6, 7, 8, 9]), - ] - ) - msg = "Extra items in the left set" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + ref_indices=np.array( + [ + np.array([1, 2, 3, 4, 5]), + ] + ), + indices=np.array( + [ + np.array([2, 1, 3, 4, 5]), + ] + ), + rtol=rtol, + decimals=decimals, ) # Having missing last elements is valid - indices = copy.deepcopy(ref_indices) - dist = copy.deepcopy(ref_dist) - - indices[0] = indices[0][:-1] - dist[0] = dist[0][:-1] - assert_radius_neighborhood_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1]), + ] + ), + ref_indices=np.array( + [ + np.array([1, 2, 3, 4, 5]), + ] + ), + indices=np.array( + [ + np.array([1, 2, 3, 4]), + ] + ), + rtol=rtol, + decimals=decimals, ) - indices = copy.deepcopy(ref_indices) - dist = copy.deepcopy(ref_dist) - # Indices aren't properly sorted w.r.t their distances - indices[0][0], indices[0][1] = indices[0][1], indices[0][0] - msg = "Extra items in the left set" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + ref_indices=np.array( + [ + np.array([1, 2, 3, 4, 5]), + ] + ), + indices=np.array( + [ + np.array([2, 1, 4, 5, 3]), + ] + ), + rtol=rtol, + decimals=decimals, ) - indices = copy.deepcopy(ref_indices) - dist = copy.deepcopy(ref_dist) - # Distances aren't properly sorted - dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + dist=np.array( + [ + np.array([2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + ref_indices=np.array( + [ + np.array([1, 2, 3, 4, 5]), + ] + ), + indices=np.array( + [ + np.array([1, 2, 4, 5, 3]), + ] + ), + rtol=rtol, + decimals=decimals, ) - indices = copy.deepcopy(ref_indices) - dist = copy.deepcopy(ref_dist) - # Indices and distances aren't properly sorted - indices[0][0], indices[0][1] = indices[0][1], indices[0][0] - dist[0][0], dist[0][1] = dist[0][1], dist[0][0] - msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist, dist, ref_indices, indices, rtol, decimals + ref_dist=np.array( + [ + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + dist=np.array( + [ + np.array([2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol]), + ] + ), + ref_indices=np.array( + [ + np.array([1, 2, 3, 4, 5]), + ] + ), + indices=np.array( + [ + np.array([2, 1, 4, 5, 3]), + ] + ), + rtol=rtol, + decimals=decimals, ) From b1ef2f573d81238370cef357fb6325b0bb74138d Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 7 Jun 2022 17:31:07 +0200 Subject: [PATCH 09/17] TST Complete assertion with radius Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 82 ++++++++++++------- 1 file changed, 52 insertions(+), 30 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 09898fe0c9ce7..c291716629374 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -132,9 +132,12 @@ def assert_argkmin_results_quasi_equality( assert ref_mapping[j] == mapping[j] -def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, indices): +def assert_radius_neighborhood_results_equality( + ref_dist, dist, ref_indices, indices, radius +): # We get arrays of arrays and we need to check for individual pairs for i in range(ref_dist.shape[0]): + assert (ref_dist[i] <= radius).all() assert_array_equal( ref_indices[i], indices[i], @@ -153,6 +156,7 @@ def assert_radius_neighborhood_results_quasi_equality( dist, ref_indices, indices, + radius, rtol=1e-4, decimals=5, ): @@ -183,14 +187,21 @@ def assert_radius_neighborhood_results_quasi_equality( assert is_sorted(dist_row), f"Distances aren't sorted on row {i}" # Vectors' lengths might be different due to small - # numerical differences of distance w.r.t the `radius` threshold, - # so we group vectors to match the smallest. - m = min(len(ref_dist_row), len(dist_row)) + # numerical differences of distance w.r.t the `radius` threshold. + largest_row = ref_dist_row if len(ref_dist_row) > len(dist_row) else dist_row - ref_dist_row = ref_dist_row[:m] - dist_row = dist_row[:m] + # For the longest distances vector, we check that elements that aren't present + # in the other vector are all in: [radius ± atol] + atol = 10 ** (-decimals) + min_length = min(len(ref_dist_row), len(dist_row)) + assert np.all(radius - atol <= largest_row[min_length:] <= radius + atol) - assert_allclose(ref_dist_row, dist_row, rtol) + # We truncate the neighbors results list on the smallest length to + # be able to compare them, ignoring the elements checked above. + ref_dist_row = ref_dist_row[:min_length] + dist_row = dist_row[:min_length] + + assert_allclose(ref_dist_row, dist_row, rtol=rtol) ref_indices_row = ref_indices[i] indices_row = indices[i] @@ -200,7 +211,7 @@ def assert_radius_neighborhood_results_quasi_equality( ref_mapping = defaultdict(set) mapping = defaultdict(set) - for j in range(m): + for j in range(min_length): rounded_dist = np.round(ref_dist_row[j], decimals=decimals) ref_mapping[rounded_dist].add(ref_indices_row[j]) mapping[rounded_dist].add(indices_row[j]) @@ -419,6 +430,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): rtol = 1e-7 atol = 1e-7 decimals = 6 + radius = 6.1 ref_dist = np.array( [ @@ -459,6 +471,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([1, 2, 4, 5, 3]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -483,6 +496,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([6, 9, 7, 8, 10]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -511,6 +525,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([2, 1, 3, 4, 5]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -537,6 +552,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([1, 2, 3, 4]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -565,6 +581,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([2, 1, 4, 5, 3]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -593,6 +610,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([1, 2, 4, 5, 3]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -621,6 +639,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([2, 1, 4, 5, 3]), ] ), + radius=radius, rtol=rtol, decimals=decimals, ) @@ -790,12 +809,14 @@ def test_chunk_size_agnosticism( X = rng.rand(n_samples, n_features).astype(dtype) * spread Y = rng.rand(n_samples, n_features).astype(dtype) * spread - parameter = ( - 10 - if PairwiseDistancesReduction is PairwiseDistancesArgKmin + if PairwiseDistancesReduction is PairwiseDistancesArgKmin: + parameter = 10 + check_parameters = {} + else: # Scaling the radius slightly with the numbers of dimensions - else 10 ** np.log(n_features) - ) + radius = 10 ** np.log(n_features) + parameter = radius + check_parameters = {"radius": radius} ref_dist, ref_indices = PairwiseDistancesReduction.compute( X, @@ -815,7 +836,7 @@ def test_chunk_size_agnosticism( ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( - ref_dist, dist, ref_indices, indices + ref_dist, dist, ref_indices, indices, **check_parameters ) @@ -839,12 +860,14 @@ def test_n_threads_agnosticism( X = rng.rand(n_samples, n_features).astype(dtype) * spread Y = rng.rand(n_samples, n_features).astype(dtype) * spread - parameter = ( - 10 - if PairwiseDistancesReduction is PairwiseDistancesArgKmin + if PairwiseDistancesReduction is PairwiseDistancesArgKmin: + parameter = 10 + check_parameters = {} + else: # Scaling the radius slightly with the numbers of dimensions - else 10 ** np.log(n_features) - ) + radius = 10 ** np.log(n_features) + parameter = radius + check_parameters = {"radius": radius} ref_dist, ref_indices = PairwiseDistancesReduction.compute( X, @@ -859,7 +882,7 @@ def test_n_threads_agnosticism( ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( - ref_dist, dist, ref_indices, indices + ref_dist, dist, ref_indices, indices, **check_parameters ) @@ -890,12 +913,14 @@ def test_strategies_consistency( X = np.ascontiguousarray(X[:, :2]) Y = np.ascontiguousarray(Y[:, :2]) - parameter = ( - 10 - if PairwiseDistancesReduction is PairwiseDistancesArgKmin + if PairwiseDistancesReduction is PairwiseDistancesArgKmin: + parameter = 10 + check_parameters = {} + else: # Scaling the radius slightly with the numbers of dimensions - else 10 ** np.log(n_features) - ) + radius = 10 ** np.log(n_features) + parameter = radius + check_parameters = {"radius": radius} dist_par_X, indices_par_X = PairwiseDistancesReduction.compute( X, @@ -928,10 +953,7 @@ def test_strategies_consistency( ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( - dist_par_X, - dist_par_Y, - indices_par_X, - indices_par_Y, + dist_par_X, dist_par_Y, indices_par_X, indices_par_Y, **check_parameters ) @@ -1063,7 +1085,7 @@ def test_pairwise_distances_radius_neighbors( ) ASSERT_RESULT[(PairwiseDistancesRadiusNeighborhood, dtype)]( - neigh_distances, neigh_distances_ref, neigh_indices, neigh_indices_ref + neigh_distances, neigh_distances_ref, neigh_indices, neigh_indices_ref, radius ) From 5d4fb18369504ba3b23202e31738ac03db400725 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 7 Jun 2022 17:33:11 +0200 Subject: [PATCH 10/17] TST Rename variables Co-authored-by: Olivier Grisel --- sklearn/metrics/tests/test_pairwise_distances_reduction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index c291716629374..eee660b66c980 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -102,10 +102,10 @@ def assert_argkmin_results_quasi_equality( ref_dist.shape == dist.shape == ref_indices.shape == indices.shape ), "Arrays of results have various shapes." - n, k = ref_dist.shape + n_queries, n_neighbors = ref_dist.shape # Asserting equality results one row at a time - for i in range(n): + for i in range(n_queries): ref_dist_row = ref_dist[i] dist_row = dist[i] @@ -122,7 +122,7 @@ def assert_argkmin_results_quasi_equality( ref_mapping = defaultdict(set) mapping = defaultdict(set) - for j in range(k): + for j in range(n_neighbors): rounded_dist = np.round(ref_dist_row[j], decimals=decimals) ref_mapping[rounded_dist].add(ref_indices_row[j]) mapping[rounded_dist].add(indices_row[j]) From 15bcf6210fb6afa03969cbc74bcb751fe60207a8 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 7 Jun 2022 17:59:21 +0200 Subject: [PATCH 11/17] Improve comments Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index eee660b66c980..d52202363fe8d 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -92,9 +92,12 @@ def assert_argkmin_results_quasi_equality( ): """Assert that argkmin results are valid up to: - relative tolerance on distances - - permutations of indices for absolutely close distances + - permutations of indices for distances values that differ up to + a precision level set by `decimals`. - To be used for 32bits datasets tests. + To be used for testing neighbor queries on float32 datasets: we + accept neighbors rank swaps only if they are caused by small + rounding errors on the distance computations. """ is_sorted = lambda a: np.all(a[:-1] - a[1:] <= 0) @@ -112,7 +115,7 @@ def assert_argkmin_results_quasi_equality( assert is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" assert is_sorted(dist_row), f"Distances aren't sorted on row {i}" - assert_allclose(ref_dist_row, dist_row, rtol) + assert_allclose(ref_dist_row, dist_row, rtol=rtol) ref_indices_row = ref_indices[i] indices_row = indices[i] @@ -262,7 +265,9 @@ def test_assert_argkmin_results_quasi_equality(): ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals ) - # Apply valid permutation on indices + # Apply valid permutation on indices: the last 3 points are + # all very close to one another so we accept any permutation + # on their rankings. assert_argkmin_results_quasi_equality( ref_dist=np.array( [ @@ -271,7 +276,7 @@ def test_assert_argkmin_results_quasi_equality(): ), dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, 6.1, 6.1, 6.1], ] ), ref_indices=np.array( @@ -287,6 +292,8 @@ def test_assert_argkmin_results_quasi_equality(): rtol=rtol, decimals=decimals, ) + # All points are have close distances so any ranking permutation + # is valid for this query result. assert_argkmin_results_quasi_equality( ref_dist=np.array( [ @@ -312,7 +319,9 @@ def test_assert_argkmin_results_quasi_equality(): decimals=decimals, ) - # Apply invalid permutation on indices + # Apply invalid permutation on indices: permuting the ranks + # of the 2 nearest neighbors is invalid because the distance + # values are too different. msg = "Extra items in the left set" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( From d9013e936b4307895ce3b6a9618fd6567a7ccf5f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 7 Jun 2022 17:56:10 +0200 Subject: [PATCH 12/17] TST Improve some more --- .../test_pairwise_distances_reduction.py | 98 ++++++++----------- 1 file changed, 42 insertions(+), 56 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index d52202363fe8d..0b3c83b1b7e3c 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,3 +1,4 @@ +import re from collections import defaultdict import numpy as np @@ -131,8 +132,12 @@ def assert_argkmin_results_quasi_equality( mapping[rounded_dist].add(indices_row[j]) # Asserting equality of groups (sets) for each distance + msg = ( + f"Neighbors indices for query {i} are not matching " + f"when rounding distances at decimals={decimals}" + ) for j in ref_mapping.keys(): - assert ref_mapping[j] == mapping[j] + assert ref_mapping[j] == mapping[j], msg def assert_radius_neighborhood_results_equality( @@ -193,11 +198,15 @@ def assert_radius_neighborhood_results_quasi_equality( # numerical differences of distance w.r.t the `radius` threshold. largest_row = ref_dist_row if len(ref_dist_row) > len(dist_row) else dist_row - # For the longest distances vector, we check that elements that aren't present - # in the other vector are all in: [radius ± atol] + # For the longest distances vector, we check that last extra elements + # that aren't present in the other vector are all in: [radius ± atol] atol = 10 ** (-decimals) min_length = min(len(ref_dist_row), len(dist_row)) - assert np.all(radius - atol <= largest_row[min_length:] <= radius + atol) + last_extra_elements = largest_row[min_length:] + assert np.all(radius - atol <= last_extra_elements <= radius + atol), ( + f"The last extra elements ({last_extra_elements}) aren't in [radius ±" + f" atol]=[{radius} ± {atol}]" + ) # We truncate the neighbors results list on the smallest length to # be able to compare them, ignoring the elements checked above. @@ -220,8 +229,12 @@ def assert_radius_neighborhood_results_quasi_equality( mapping[rounded_dist].add(indices_row[j]) # Asserting equality of groups (sets) for each distance + msg = ( + f"Neighbors indices for query {i} are not matching " + f"when rounding distances at decimals={decimals}" + ) for j in ref_mapping.keys(): - assert ref_mapping[j] == mapping[j] + assert ref_mapping[j] == mapping[j], msg ASSERT_RESULT = { @@ -322,7 +335,7 @@ def test_assert_argkmin_results_quasi_equality(): # Apply invalid permutation on indices: permuting the ranks # of the 2 nearest neighbors is invalid because the distance # values are too different. - msg = "Extra items in the left set" + msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( ref_dist=np.array( @@ -350,7 +363,7 @@ def test_assert_argkmin_results_quasi_equality(): ) # Indices aren't properly sorted w.r.t their distances - msg = "Extra items in the left set" + msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( ref_dist=np.array( @@ -379,34 +392,6 @@ def test_assert_argkmin_results_quasi_equality(): # Distances aren't properly sorted msg = "Distances aren't sorted on row 0" - with pytest.raises(AssertionError, match=msg): - assert_argkmin_results_quasi_equality( - ref_dist=np.array( - [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], - ] - ), - dist=np.array( - [ - [2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol], - ] - ), - ref_indices=np.array( - [ - [1, 2, 3, 4, 5], - ] - ), - indices=np.array( - [ - [1, 2, 4, 5, 3], - ] - ), - rtol=rtol, - decimals=decimals, - ) - - # Indices and distances aren't properly sorted - msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( ref_dist=np.array( @@ -439,7 +424,6 @@ def test_assert_radius_neighborhood_results_quasi_equality(): rtol = 1e-7 atol = 1e-7 decimals = 6 - radius = 6.1 ref_dist = np.array( [ @@ -480,7 +464,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([1, 2, 4, 5, 3]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) @@ -505,13 +489,13 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([6, 9, 7, 8, 10]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) # Apply invalid permutation on indices - msg = "Extra items in the left set" + msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( @@ -534,12 +518,12 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([2, 1, 3, 4, 5]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) - # Having missing last elements is valid + # Having extra last elements is valid if they are in: [radius ± atol] assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ @@ -561,42 +545,44 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([1, 2, 3, 4]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) - # Indices aren't properly sorted w.r.t their distances - msg = "Extra items in the left set" + # Having extra last elements is invalid if they are lesser than radius - atol + msg = re.escape( + "The last extra elements ([6.]) aren't in [radius ± atol]=[6.1 ± 1e-06]" + ) with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, 6]), ] ), dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5]), ] ), ref_indices=np.array( [ - np.array([1, 2, 3, 4, 5]), + np.array([1, 2, 3]), ] ), indices=np.array( [ - np.array([2, 1, 4, 5, 3]), + np.array([1, 2]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) - # Distances aren't properly sorted - msg = "Distances aren't sorted on row 0" + # Indices aren't properly sorted w.r.t their distances + msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( @@ -606,7 +592,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), dist=np.array( [ - np.array([2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), ] ), ref_indices=np.array( @@ -616,15 +602,15 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), indices=np.array( [ - np.array([1, 2, 4, 5, 3]), + np.array([2, 1, 4, 5, 3]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) - # Indices and distances aren't properly sorted + # Distances aren't properly sorted msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( @@ -648,7 +634,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): np.array([2, 1, 4, 5, 3]), ] ), - radius=radius, + radius=6.1, rtol=rtol, decimals=decimals, ) From 93496ae0c145f23f55670dbb84f21516a1fecf22 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 8 Jun 2022 11:42:59 +0200 Subject: [PATCH 13/17] TST Apply review comments Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 120 ++++++++++++------ 1 file changed, 83 insertions(+), 37 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 0b3c83b1b7e3c..c7a882e063786 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -4,6 +4,7 @@ import numpy as np import pytest import threadpoolctl +from math import log10, floor from numpy.testing import assert_array_equal, assert_allclose from scipy.sparse import csr_matrix from scipy.spatial.distance import cdist @@ -83,6 +84,23 @@ def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices, rtol=1 ) +def adaptive_rounding(scalar, n_significant_digits): + """Round a scalar to a number of significant digits adaptively to its value.""" + magnitude = int(floor(log10(abs(scalar)))) + 1 + return round(scalar, n_significant_digits - magnitude) + + +def test_adaptive_rounding(): + + assert adaptive_rounding(123456789, 2) == 120000000 + assert adaptive_rounding(123456789, 3) == 123000000 + assert adaptive_rounding(123456789, 10) == 123456789 + + assert adaptive_rounding(1.23456789, 2) == 1.2 + assert adaptive_rounding(1.23456789, 3) == 1.23 + assert adaptive_rounding(1.23456789, 10) == 1.23456789 + + def assert_argkmin_results_quasi_equality( ref_dist, dist, @@ -92,11 +110,11 @@ def assert_argkmin_results_quasi_equality( decimals=5, ): """Assert that argkmin results are valid up to: - - relative tolerance on distances + - relative tolerance on computed distance values - permutations of indices for distances values that differ up to a precision level set by `decimals`. - To be used for testing neighbor queries on float32 datasets: we + To be used for testing neighbors queries on float32 datasets: we accept neighbors rank swaps only if they are caused by small rounding errors on the distance computations. """ @@ -123,21 +141,26 @@ def assert_argkmin_results_quasi_equality( # Grouping indices by distances using sets on # a rounded distances up to a given number of decimals - ref_mapping = defaultdict(set) - mapping = defaultdict(set) + reference_neighbors_groups = defaultdict(set) + effective_neighbors_groups = defaultdict(set) - for j in range(n_neighbors): - rounded_dist = np.round(ref_dist_row[j], decimals=decimals) - ref_mapping[rounded_dist].add(ref_indices_row[j]) - mapping[rounded_dist].add(indices_row[j]) + for neighbor_rank in range(n_neighbors): + rounded_dist = adaptive_rounding( + ref_dist_row[neighbor_rank], n_significant_digits=decimals + ) + reference_neighbors_groups[rounded_dist].add(ref_indices_row[neighbor_rank]) + effective_neighbors_groups[rounded_dist].add(indices_row[neighbor_rank]) # Asserting equality of groups (sets) for each distance msg = ( f"Neighbors indices for query {i} are not matching " f"when rounding distances at decimals={decimals}" ) - for j in ref_mapping.keys(): - assert ref_mapping[j] == mapping[j], msg + for rounded_distance in reference_neighbors_groups.keys(): + assert ( + reference_neighbors_groups[rounded_distance] + == effective_neighbors_groups[rounded_distance] + ), msg def assert_radius_neighborhood_results_equality( @@ -169,11 +192,15 @@ def assert_radius_neighborhood_results_quasi_equality( decimals=5, ): """Assert that radius neighborhood results are valid up to: - - relative tolerance on distances - - permutations of indices for absolutely close distances - - missing last elements if the threshold + - relative tolerance on computed distance values + - permutations of indices for distances values that + differ up to a precision level set by `decimals` + - missing or extra last elements if their distance is + close to the radius - To be used for 32bits datasets tests. + To be used for testing neighbors queries on float32 datasets: we + accept neighbors rank swaps only if they are caused by small + rounding errors on the distance computations. Input arrays must be sorted w.r.t distances. """ @@ -183,16 +210,18 @@ def assert_radius_neighborhood_results_quasi_equality( len(ref_dist) == len(dist) == len(ref_indices) == len(indices) ), "Arrays of results have various lengths." - n = len(ref_dist) + n_queries = len(ref_dist) # Asserting equality of results one vector at a time - for i in range(n): + for query_idx in range(n_queries): - ref_dist_row = ref_dist[i] - dist_row = dist[i] + ref_dist_row = ref_dist[query_idx] + dist_row = dist[query_idx] - assert is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" - assert is_sorted(dist_row), f"Distances aren't sorted on row {i}" + assert is_sorted( + ref_dist_row + ), f"Reference distances aren't sorted on row {query_idx}" + assert is_sorted(dist_row), f"Distances aren't sorted on row {query_idx}" # Vectors' lengths might be different due to small # numerical differences of distance w.r.t the `radius` threshold. @@ -203,42 +232,51 @@ def assert_radius_neighborhood_results_quasi_equality( atol = 10 ** (-decimals) min_length = min(len(ref_dist_row), len(dist_row)) last_extra_elements = largest_row[min_length:] - assert np.all(radius - atol <= last_extra_elements <= radius + atol), ( - f"The last extra elements ({last_extra_elements}) aren't in [radius ±" - f" atol]=[{radius} ± {atol}]" - ) + if last_extra_elements.size > 0: + assert np.all(radius - atol <= last_extra_elements <= radius + atol), ( + f"The last extra elements ({last_extra_elements}) aren't in [radius ±" + f" atol]=[{radius} ± {atol}]" + ) # We truncate the neighbors results list on the smallest length to # be able to compare them, ignoring the elements checked above. ref_dist_row = ref_dist_row[:min_length] dist_row = dist_row[:min_length] + print(type(ref_dist_row)) + print(type(dist_row)) assert_allclose(ref_dist_row, dist_row, rtol=rtol) - ref_indices_row = ref_indices[i] - indices_row = indices[i] + ref_indices_row = ref_indices[query_idx] + indices_row = indices[query_idx] # Grouping indices by distances using sets on # a rounded distances up to a given number of decimals - ref_mapping = defaultdict(set) - mapping = defaultdict(set) + reference_neighbors_groups = defaultdict(set) + effective_neighbors_groups = defaultdict(set) - for j in range(min_length): - rounded_dist = np.round(ref_dist_row[j], decimals=decimals) - ref_mapping[rounded_dist].add(ref_indices_row[j]) - mapping[rounded_dist].add(indices_row[j]) + for neighbor_rank in range(min_length): + rounded_dist = adaptive_rounding( + ref_dist_row[neighbor_rank], n_significant_digits=decimals + ) + reference_neighbors_groups[rounded_dist].add(ref_indices_row[neighbor_rank]) + effective_neighbors_groups[rounded_dist].add(indices_row[neighbor_rank]) # Asserting equality of groups (sets) for each distance msg = ( - f"Neighbors indices for query {i} are not matching " + f"Neighbors indices for query {query_idx} are not matching " f"when rounding distances at decimals={decimals}" ) - for j in ref_mapping.keys(): - assert ref_mapping[j] == mapping[j], msg + for rounded_distance in reference_neighbors_groups.keys(): + assert ( + reference_neighbors_groups[rounded_distance] + == effective_neighbors_groups[rounded_distance] + ), msg ASSERT_RESULT = { - # In the case of 64bit, we test for exact equality. + # In the case of 64bit, we test for exact equality of the results rankings + # and standard tolerance levels for the computed distance values. (PairwiseDistancesArgKmin, np.float64): assert_argkmin_results_equality, ( PairwiseDistancesRadiusNeighborhood, @@ -274,6 +312,7 @@ def test_assert_argkmin_results_quasi_equality(): ] ) + # Sanity check: compare the reference results to themselves. assert_argkmin_results_quasi_equality( ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals ) @@ -438,8 +477,15 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ] ) + # Sanity check: compare the reference results to themselves. assert_radius_neighborhood_results_quasi_equality( - ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals + ref_dist, + ref_dist, + ref_indices, + ref_indices, + radius=6.1, + rtol=rtol, + decimals=decimals, ) # Apply valid permutation on indices From 589ecae1df31d9ba9c624b184b972dcfe1c1095c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 8 Jun 2022 16:24:38 +0200 Subject: [PATCH 14/17] Apply review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 140 ++++++++++-------- 1 file changed, 77 insertions(+), 63 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index c7a882e063786..56e01017cb5f2 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -5,7 +5,6 @@ import pytest import threadpoolctl from math import log10, floor -from numpy.testing import assert_array_equal, assert_allclose from scipy.sparse import csr_matrix from scipy.spatial.distance import cdist @@ -18,6 +17,10 @@ from sklearn.metrics import euclidean_distances from sklearn.utils.fixes import sp_version, parse_version +from sklearn.utils._testing import ( + assert_array_equal, + assert_allclose, +) # Common supported metric between scipy.spatial.distance.cdist # and PairwiseDistancesReduction. @@ -84,21 +87,27 @@ def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices, rtol=1 ) -def adaptive_rounding(scalar, n_significant_digits): - """Round a scalar to a number of significant digits adaptively to its value.""" +def relative_rounding(scalar, n_significant_digits): + """Round a scalar to a number of significant digits relatively to its value.""" magnitude = int(floor(log10(abs(scalar)))) + 1 return round(scalar, n_significant_digits - magnitude) -def test_adaptive_rounding(): +def test_relative_rounding(): + + assert relative_rounding(123456789, 0) == 0 + assert relative_rounding(123456789, 2) == 120000000 + assert relative_rounding(123456789, 3) == 123000000 + assert relative_rounding(123456789, 10) == 123456789 + assert relative_rounding(123456789, 20) == 123456789 - assert adaptive_rounding(123456789, 2) == 120000000 - assert adaptive_rounding(123456789, 3) == 123000000 - assert adaptive_rounding(123456789, 10) == 123456789 + assert relative_rounding(1.23456789, 2) == 1.2 + assert relative_rounding(1.23456789, 3) == 1.23 + assert relative_rounding(1.23456789, 10) == 1.23456789 - assert adaptive_rounding(1.23456789, 2) == 1.2 - assert adaptive_rounding(1.23456789, 3) == 1.23 - assert adaptive_rounding(1.23456789, 10) == 1.23456789 + assert relative_rounding(123.456789, 3) == 123.0 + assert relative_rounding(123.456789, 9) == 123.456789 + assert relative_rounding(123.456789, 10) == 123.456789 def assert_argkmin_results_quasi_equality( @@ -107,18 +116,19 @@ def assert_argkmin_results_quasi_equality( ref_indices, indices, rtol=1e-4, - decimals=5, ): """Assert that argkmin results are valid up to: - relative tolerance on computed distance values - permutations of indices for distances values that differ up to - a precision level set by `decimals`. + a precision level To be used for testing neighbors queries on float32 datasets: we accept neighbors rank swaps only if they are caused by small rounding errors on the distance computations. """ - is_sorted = lambda a: np.all(a[:-1] - a[1:] <= 0) + is_sorted = lambda a: np.all(a[:-1] <= a[1:]) + + n_significant_digits = -(int(floor(log10(abs(rtol)))) + 1) assert ( ref_dist.shape == dist.shape == ref_indices.shape == indices.shape @@ -145,8 +155,9 @@ def assert_argkmin_results_quasi_equality( effective_neighbors_groups = defaultdict(set) for neighbor_rank in range(n_neighbors): - rounded_dist = adaptive_rounding( - ref_dist_row[neighbor_rank], n_significant_digits=decimals + rounded_dist = relative_rounding( + ref_dist_row[neighbor_rank], + n_significant_digits=n_significant_digits, ) reference_neighbors_groups[rounded_dist].add(ref_indices_row[neighbor_rank]) effective_neighbors_groups[rounded_dist].add(indices_row[neighbor_rank]) @@ -154,7 +165,7 @@ def assert_argkmin_results_quasi_equality( # Asserting equality of groups (sets) for each distance msg = ( f"Neighbors indices for query {i} are not matching " - f"when rounding distances at decimals={decimals}" + f"when rounding distances at n_significant_digits={n_significant_digits}" ) for rounded_distance in reference_neighbors_groups.keys(): assert ( @@ -189,7 +200,6 @@ def assert_radius_neighborhood_results_quasi_equality( indices, radius, rtol=1e-4, - decimals=5, ): """Assert that radius neighborhood results are valid up to: - relative tolerance on computed distance values @@ -204,7 +214,9 @@ def assert_radius_neighborhood_results_quasi_equality( Input arrays must be sorted w.r.t distances. """ - is_sorted = lambda a: np.all(a[:-1] - a[1:] <= 0) + is_sorted = lambda a: np.all(a[:-1] <= a[1:]) + + n_significant_digits = -(int(floor(log10(abs(rtol)))) + 1) assert ( len(ref_dist) == len(dist) == len(ref_indices) == len(indices) @@ -229,7 +241,7 @@ def assert_radius_neighborhood_results_quasi_equality( # For the longest distances vector, we check that last extra elements # that aren't present in the other vector are all in: [radius ± atol] - atol = 10 ** (-decimals) + atol = 10 ** (-n_significant_digits) min_length = min(len(ref_dist_row), len(dist_row)) last_extra_elements = largest_row[min_length:] if last_extra_elements.size > 0: @@ -256,8 +268,9 @@ def assert_radius_neighborhood_results_quasi_equality( effective_neighbors_groups = defaultdict(set) for neighbor_rank in range(min_length): - rounded_dist = adaptive_rounding( - ref_dist_row[neighbor_rank], n_significant_digits=decimals + rounded_dist = relative_rounding( + ref_dist_row[neighbor_rank], + n_significant_digits=n_significant_digits, ) reference_neighbors_groups[rounded_dist].add(ref_indices_row[neighbor_rank]) effective_neighbors_groups[rounded_dist].add(indices_row[neighbor_rank]) @@ -265,7 +278,7 @@ def assert_radius_neighborhood_results_quasi_equality( # Asserting equality of groups (sets) for each distance msg = ( f"Neighbors indices for query {query_idx} are not matching " - f"when rounding distances at decimals={decimals}" + f"when rounding distances at n_significant_digits={n_significant_digits}" ) for rounded_distance in reference_neighbors_groups.keys(): assert ( @@ -277,6 +290,11 @@ def assert_radius_neighborhood_results_quasi_equality( ASSERT_RESULT = { # In the case of 64bit, we test for exact equality of the results rankings # and standard tolerance levels for the computed distance values. + # + # XXX: Note that in the future we might be interested in using quasi equality + # checks also for float64 data (with a larger number of significant digits) + # as the tests could be unstable because of numerically tied distances on + # some datasets (e.g. uniform grids). (PairwiseDistancesArgKmin, np.float64): assert_argkmin_results_equality, ( PairwiseDistancesRadiusNeighborhood, @@ -297,12 +315,16 @@ def test_assert_argkmin_results_quasi_equality(): rtol = 1e-7 atol = 1e-7 - decimals = 6 + _1m = 1.0 - atol + _1p = 1.0 + atol + + _6_1m = 6.1 - atol + _6_1p = 6.1 + atol ref_dist = np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], - [1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], + [_1m, _1m, 1, _1p, _1p], ] ) ref_indices = np.array( @@ -314,7 +336,7 @@ def test_assert_argkmin_results_quasi_equality(): # Sanity check: compare the reference results to themselves. assert_argkmin_results_quasi_equality( - ref_dist, ref_dist, ref_indices, ref_indices, rtol, decimals + ref_dist, ref_dist, ref_indices, ref_indices, rtol ) # Apply valid permutation on indices: the last 3 points are @@ -323,7 +345,7 @@ def test_assert_argkmin_results_quasi_equality(): assert_argkmin_results_quasi_equality( ref_dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], ] ), dist=np.array( @@ -342,19 +364,18 @@ def test_assert_argkmin_results_quasi_equality(): ] ), rtol=rtol, - decimals=decimals, ) # All points are have close distances so any ranking permutation # is valid for this query result. assert_argkmin_results_quasi_equality( ref_dist=np.array( [ - [1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol], + [_1m, _1m, 1, _1p, _1p], ] ), dist=np.array( [ - [1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol], + [_1m, _1m, 1, _1p, _1p], ] ), ref_indices=np.array( @@ -368,7 +389,6 @@ def test_assert_argkmin_results_quasi_equality(): ] ), rtol=rtol, - decimals=decimals, ) # Apply invalid permutation on indices: permuting the ranks @@ -379,12 +399,12 @@ def test_assert_argkmin_results_quasi_equality(): assert_argkmin_results_quasi_equality( ref_dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], ] ), dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], ] ), ref_indices=np.array( @@ -398,7 +418,6 @@ def test_assert_argkmin_results_quasi_equality(): ] ), rtol=rtol, - decimals=decimals, ) # Indices aren't properly sorted w.r.t their distances @@ -407,12 +426,12 @@ def test_assert_argkmin_results_quasi_equality(): assert_argkmin_results_quasi_equality( ref_dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], ] ), dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], ] ), ref_indices=np.array( @@ -426,7 +445,6 @@ def test_assert_argkmin_results_quasi_equality(): ] ), rtol=rtol, - decimals=decimals, ) # Distances aren't properly sorted @@ -435,12 +453,12 @@ def test_assert_argkmin_results_quasi_equality(): assert_argkmin_results_quasi_equality( ref_dist=np.array( [ - [1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol], + [1.2, 2.5, _6_1m, 6.1, _6_1p], ] ), dist=np.array( [ - [2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol], + [2.5, 1.2, _6_1m, 6.1, _6_1p], ] ), ref_indices=np.array( @@ -454,7 +472,6 @@ def test_assert_argkmin_results_quasi_equality(): ] ), rtol=rtol, - decimals=decimals, ) @@ -462,12 +479,17 @@ def test_assert_radius_neighborhood_results_quasi_equality(): rtol = 1e-7 atol = 1e-7 - decimals = 6 + + _1m = 1.0 - atol + _1p = 1.0 + atol + + _6_1m = 6.1 - atol + _6_1p = 6.1 + atol ref_dist = np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), - np.array([1.0 - atol, 1, 1.0 + atol, 1.0 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), + np.array([_1m, 1, _1p, _1p]), ] ) ref_indices = np.array( @@ -485,19 +507,18 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ref_indices, radius=6.1, rtol=rtol, - decimals=decimals, ) # Apply valid permutation on indices assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), ref_indices=np.array( @@ -512,17 +533,16 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol]), + np.array([_1m, _1m, 1, _1p, _1p]), ] ), dist=np.array( [ - np.array([1.0 - atol, 1.0 - atol, 1, 1.0 + atol, 1.0 + atol]), + np.array([_1m, _1m, 1, _1p, _1p]), ] ), ref_indices=np.array( @@ -537,7 +557,6 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) # Apply invalid permutation on indices @@ -546,12 +565,12 @@ def test_assert_radius_neighborhood_results_quasi_equality(): assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), ref_indices=np.array( @@ -566,19 +585,18 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) # Having extra last elements is valid if they are in: [radius ± atol] assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1]), + np.array([1.2, 2.5, _6_1m, 6.1]), ] ), ref_indices=np.array( @@ -593,7 +611,6 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) # Having extra last elements is invalid if they are lesser than radius - atol @@ -624,7 +641,6 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) # Indices aren't properly sorted w.r.t their distances @@ -633,12 +649,12 @@ def test_assert_radius_neighborhood_results_quasi_equality(): assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), ref_indices=np.array( @@ -653,7 +669,6 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) # Distances aren't properly sorted @@ -662,12 +677,12 @@ def test_assert_radius_neighborhood_results_quasi_equality(): assert_radius_neighborhood_results_quasi_equality( ref_dist=np.array( [ - np.array([1.2, 2.5, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), ] ), dist=np.array( [ - np.array([2.5, 1.2, 6.1 - atol, 6.1, 6.1 + atol]), + np.array([2.5, 1.2, _6_1m, 6.1, _6_1p]), ] ), ref_indices=np.array( @@ -682,7 +697,6 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ), radius=6.1, rtol=rtol, - decimals=decimals, ) From badc58e5a15c89d87cf788eb0e7233fec45dee67 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 8 Jun 2022 17:10:14 +0200 Subject: [PATCH 15/17] Improve readability Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 264 +++--------------- 1 file changed, 44 insertions(+), 220 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 56e01017cb5f2..15b8d8ab6db94 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -343,51 +343,19 @@ def test_assert_argkmin_results_quasi_equality(): # all very close to one another so we accept any permutation # on their rankings. assert_argkmin_results_quasi_equality( - ref_dist=np.array( - [ - [1.2, 2.5, _6_1m, 6.1, _6_1p], - ] - ), - dist=np.array( - [ - [1.2, 2.5, 6.1, 6.1, 6.1], - ] - ), - ref_indices=np.array( - [ - [1, 2, 3, 4, 5], - ] - ), - indices=np.array( - [ - [1, 2, 4, 5, 3], - ] - ), + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), + np.array([[1.2, 2.5, 6.1, 6.1, 6.1]]), + np.array([[1, 2, 3, 4, 5]]), + np.array([[1, 2, 4, 5, 3]]), rtol=rtol, ) # All points are have close distances so any ranking permutation # is valid for this query result. assert_argkmin_results_quasi_equality( - ref_dist=np.array( - [ - [_1m, _1m, 1, _1p, _1p], - ] - ), - dist=np.array( - [ - [_1m, _1m, 1, _1p, _1p], - ] - ), - ref_indices=np.array( - [ - [6, 7, 8, 9, 10], - ] - ), - indices=np.array( - [ - [6, 9, 7, 8, 10], - ] - ), + np.array([[_1m, _1m, 1, _1p, _1p]]), + np.array([[_1m, _1m, 1, _1p, _1p]]), + np.array([[6, 7, 8, 9, 10]]), + np.array([[6, 9, 7, 8, 10]]), rtol=rtol, ) @@ -397,26 +365,10 @@ def test_assert_argkmin_results_quasi_equality(): msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist=np.array( - [ - [1.2, 2.5, _6_1m, 6.1, _6_1p], - ] - ), - dist=np.array( - [ - [1.2, 2.5, _6_1m, 6.1, _6_1p], - ] - ), - ref_indices=np.array( - [ - [1, 2, 3, 4, 5], - ] - ), - indices=np.array( - [ - [2, 1, 3, 4, 5], - ] - ), + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), + np.array([[1, 2, 3, 4, 5]]), + np.array([[2, 1, 3, 4, 5]]), rtol=rtol, ) @@ -424,26 +376,10 @@ def test_assert_argkmin_results_quasi_equality(): msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist=np.array( - [ - [1.2, 2.5, _6_1m, 6.1, _6_1p], - ] - ), - dist=np.array( - [ - [1.2, 2.5, _6_1m, 6.1, _6_1p], - ] - ), - ref_indices=np.array( - [ - [1, 2, 3, 4, 5], - ] - ), - indices=np.array( - [ - [2, 1, 4, 5, 3], - ] - ), + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), + np.array([[1, 2, 3, 4, 5]]), + np.array([[2, 1, 4, 5, 3]]), rtol=rtol, ) @@ -511,50 +447,18 @@ def test_assert_radius_neighborhood_results_quasi_equality(): # Apply valid permutation on indices assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - ref_indices=np.array( - [ - np.array([1, 2, 3, 4, 5]), - ] - ), - indices=np.array( - [ - np.array([1, 2, 4, 5, 3]), - ] - ), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([1, 2, 4, 5, 3])]), radius=6.1, rtol=rtol, ) assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([_1m, _1m, 1, _1p, _1p]), - ] - ), - dist=np.array( - [ - np.array([_1m, _1m, 1, _1p, _1p]), - ] - ), - ref_indices=np.array( - [ - np.array([6, 7, 8, 9, 10]), - ] - ), - indices=np.array( - [ - np.array([6, 9, 7, 8, 10]), - ] - ), + np.array([np.array([_1m, _1m, 1, _1p, _1p])]), + np.array([np.array([_1m, _1m, 1, _1p, _1p])]), + np.array([np.array([6, 7, 8, 9, 10])]), + np.array([np.array([6, 9, 7, 8, 10])]), radius=6.1, rtol=rtol, ) @@ -563,52 +467,20 @@ def test_assert_radius_neighborhood_results_quasi_equality(): msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - ref_indices=np.array( - [ - np.array([1, 2, 3, 4, 5]), - ] - ), - indices=np.array( - [ - np.array([2, 1, 3, 4, 5]), - ] - ), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([2, 1, 3, 4, 5])]), radius=6.1, rtol=rtol, ) # Having extra last elements is valid if they are in: [radius ± atol] assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1]), - ] - ), - ref_indices=np.array( - [ - np.array([1, 2, 3, 4, 5]), - ] - ), - indices=np.array( - [ - np.array([1, 2, 3, 4]), - ] - ), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1.2, 2.5, _6_1m, 6.1])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([1, 2, 3, 4])]), radius=6.1, rtol=rtol, ) @@ -619,26 +491,10 @@ def test_assert_radius_neighborhood_results_quasi_equality(): ) with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([1.2, 2.5, 6]), - ] - ), - dist=np.array( - [ - np.array([1.2, 2.5]), - ] - ), - ref_indices=np.array( - [ - np.array([1, 2, 3]), - ] - ), - indices=np.array( - [ - np.array([1, 2]), - ] - ), + np.array([np.array([1.2, 2.5, 6])]), + np.array([np.array([1.2, 2.5])]), + np.array([np.array([1, 2, 3])]), + np.array([np.array([1, 2])]), radius=6.1, rtol=rtol, ) @@ -647,26 +503,10 @@ def test_assert_radius_neighborhood_results_quasi_equality(): msg = "Neighbors indices for query 0 are not matching" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - ref_indices=np.array( - [ - np.array([1, 2, 3, 4, 5]), - ] - ), - indices=np.array( - [ - np.array([2, 1, 4, 5, 3]), - ] - ), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([2, 1, 4, 5, 3])]), radius=6.1, rtol=rtol, ) @@ -675,26 +515,10 @@ def test_assert_radius_neighborhood_results_quasi_equality(): msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality( - ref_dist=np.array( - [ - np.array([1.2, 2.5, _6_1m, 6.1, _6_1p]), - ] - ), - dist=np.array( - [ - np.array([2.5, 1.2, _6_1m, 6.1, _6_1p]), - ] - ), - ref_indices=np.array( - [ - np.array([1, 2, 3, 4, 5]), - ] - ), - indices=np.array( - [ - np.array([2, 1, 4, 5, 3]), - ] - ), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([2.5, 1.2, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([2, 1, 4, 5, 3])]), radius=6.1, rtol=rtol, ) From 8e2f82a98493fc5180f636a2b4dfe1d0c13672d8 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 8 Jun 2022 17:56:18 +0200 Subject: [PATCH 16/17] TST Review comments Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 79 ++++++++----------- sklearn/neighbors/tests/test_neighbors.py | 14 +++- 2 files changed, 44 insertions(+), 49 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 15b8d8ab6db94..9ade4400288c6 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -137,20 +137,22 @@ def assert_argkmin_results_quasi_equality( n_queries, n_neighbors = ref_dist.shape # Asserting equality results one row at a time - for i in range(n_queries): - ref_dist_row = ref_dist[i] - dist_row = dist[i] + for query_idx in range(n_queries): + ref_dist_row = ref_dist[query_idx] + dist_row = dist[query_idx] - assert is_sorted(ref_dist_row), f"Reference distances aren't sorted on row {i}" - assert is_sorted(dist_row), f"Distances aren't sorted on row {i}" + assert is_sorted( + ref_dist_row + ), f"Reference distances aren't sorted on row {query_idx}" + assert is_sorted(dist_row), f"Distances aren't sorted on row {query_idx}" assert_allclose(ref_dist_row, dist_row, rtol=rtol) - ref_indices_row = ref_indices[i] - indices_row = indices[i] + ref_indices_row = ref_indices[query_idx] + indices_row = indices[query_idx] - # Grouping indices by distances using sets on - # a rounded distances up to a given number of decimals + # Grouping indices by distances using sets on a rounded distances up + # to a given number of decimals of significant digits derived from rtol. reference_neighbors_groups = defaultdict(set) effective_neighbors_groups = defaultdict(set) @@ -164,8 +166,9 @@ def assert_argkmin_results_quasi_equality( # Asserting equality of groups (sets) for each distance msg = ( - f"Neighbors indices for query {i} are not matching " - f"when rounding distances at n_significant_digits={n_significant_digits}" + f"Neighbors indices for query {query_idx} are not matching " + f"when rounding distances at {n_significant_digits} significant digits " + f"derived from rtol={rtol:.1e}" ) for rounded_distance in reference_neighbors_groups.keys(): assert ( @@ -254,16 +257,14 @@ def assert_radius_neighborhood_results_quasi_equality( # be able to compare them, ignoring the elements checked above. ref_dist_row = ref_dist_row[:min_length] dist_row = dist_row[:min_length] - print(type(ref_dist_row)) - print(type(dist_row)) assert_allclose(ref_dist_row, dist_row, rtol=rtol) ref_indices_row = ref_indices[query_idx] indices_row = indices[query_idx] - # Grouping indices by distances using sets on - # a rounded distances up to a given number of decimals + # Grouping indices by distances using sets on a rounded distances up + # to a given number of significant digits derived from rtol. reference_neighbors_groups = defaultdict(set) effective_neighbors_groups = defaultdict(set) @@ -278,7 +279,8 @@ def assert_radius_neighborhood_results_quasi_equality( # Asserting equality of groups (sets) for each distance msg = ( f"Neighbors indices for query {query_idx} are not matching " - f"when rounding distances at n_significant_digits={n_significant_digits}" + f"when rounding distances at {n_significant_digits} significant digits " + f"derived from rtol={rtol:.1e}" ) for rounded_distance in reference_neighbors_groups.keys(): assert ( @@ -314,12 +316,12 @@ def assert_radius_neighborhood_results_quasi_equality( def test_assert_argkmin_results_quasi_equality(): rtol = 1e-7 - atol = 1e-7 - _1m = 1.0 - atol - _1p = 1.0 + atol + eps = 1e-7 + _1m = 1.0 - eps + _1p = 1.0 + eps - _6_1m = 6.1 - atol - _6_1p = 6.1 + atol + _6_1m = 6.1 - eps + _6_1p = 6.1 + eps ref_dist = np.array( [ @@ -387,26 +389,10 @@ def test_assert_argkmin_results_quasi_equality(): msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): assert_argkmin_results_quasi_equality( - ref_dist=np.array( - [ - [1.2, 2.5, _6_1m, 6.1, _6_1p], - ] - ), - dist=np.array( - [ - [2.5, 1.2, _6_1m, 6.1, _6_1p], - ] - ), - ref_indices=np.array( - [ - [1, 2, 3, 4, 5], - ] - ), - indices=np.array( - [ - [2, 1, 4, 5, 3], - ] - ), + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), + np.array([[2.5, 1.2, _6_1m, 6.1, _6_1p]]), + np.array([[1, 2, 3, 4, 5]]), + np.array([[2, 1, 4, 5, 3]]), rtol=rtol, ) @@ -414,13 +400,12 @@ def test_assert_argkmin_results_quasi_equality(): def test_assert_radius_neighborhood_results_quasi_equality(): rtol = 1e-7 - atol = 1e-7 - - _1m = 1.0 - atol - _1p = 1.0 + atol + eps = 1e-7 + _1m = 1.0 - eps + _1p = 1.0 + eps - _6_1m = 6.1 - atol - _6_1p = 6.1 + atol + _6_1m = 6.1 - eps + _6_1p = 6.1 + eps ref_dist = np.array( [ diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 337e777191475..faffa8bf85265 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2152,7 +2152,12 @@ def test_neighbors_distance_metric_deprecation(): "metric", sorted(set(neighbors.VALID_METRICS["brute"]) - set(["precomputed"])) ) def test_radius_neighbors_brute_backend( - metric, n_samples=2000, n_features=30, n_query_pts=100, n_neighbors=5 + metric, + n_samples=2000, + n_features=30, + n_query_pts=100, + n_neighbors=5, + radius=1.0, ): # Both backends for the 'brute' algorithm of radius_neighbors # must give identical results. @@ -2179,6 +2184,7 @@ def test_radius_neighbors_brute_backend( neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, + radius=radius, algorithm="brute", metric=metric, p=p, @@ -2199,7 +2205,11 @@ def test_radius_neighbors_brute_backend( ) assert_radius_neighborhood_results_equality( - legacy_brute_dst, pdr_brute_dst, legacy_brute_idx, pdr_brute_idx + legacy_brute_dst, + pdr_brute_dst, + legacy_brute_idx, + pdr_brute_idx, + radius=radius, ) From a412ac8d6d04ceef3cc6984a0928eedf281439bd Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 9 Jun 2022 10:31:21 +0200 Subject: [PATCH 17/17] Use rtol directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- .../test_pairwise_distances_reduction.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 9ade4400288c6..fa475134c7a9f 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -206,10 +206,10 @@ def assert_radius_neighborhood_results_quasi_equality( ): """Assert that radius neighborhood results are valid up to: - relative tolerance on computed distance values - - permutations of indices for distances values that - differ up to a precision level set by `decimals` + - permutations of indices for distances values that differ up to + a precision level - missing or extra last elements if their distance is - close to the radius + close to the radius To be used for testing neighbors queries on float32 datasets: we accept neighbors rank swaps only if they are caused by small @@ -243,14 +243,13 @@ def assert_radius_neighborhood_results_quasi_equality( largest_row = ref_dist_row if len(ref_dist_row) > len(dist_row) else dist_row # For the longest distances vector, we check that last extra elements - # that aren't present in the other vector are all in: [radius ± atol] - atol = 10 ** (-n_significant_digits) + # that aren't present in the other vector are all in: [radius ± rtol] min_length = min(len(ref_dist_row), len(dist_row)) last_extra_elements = largest_row[min_length:] if last_extra_elements.size > 0: - assert np.all(radius - atol <= last_extra_elements <= radius + atol), ( + assert np.all(radius - rtol <= last_extra_elements <= radius + rtol), ( f"The last extra elements ({last_extra_elements}) aren't in [radius ±" - f" atol]=[{radius} ± {atol}]" + f" rtol]=[{radius} ± {rtol}]" ) # We truncate the neighbors results list on the smallest length to @@ -460,7 +459,7 @@ def test_assert_radius_neighborhood_results_quasi_equality(): rtol=rtol, ) - # Having extra last elements is valid if they are in: [radius ± atol] + # Having extra last elements is valid if they are in: [radius ± rtol] assert_radius_neighborhood_results_quasi_equality( np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([1.2, 2.5, _6_1m, 6.1])]), @@ -470,9 +469,9 @@ def test_assert_radius_neighborhood_results_quasi_equality(): rtol=rtol, ) - # Having extra last elements is invalid if they are lesser than radius - atol + # Having extra last elements is invalid if they are lesser than radius - rtol msg = re.escape( - "The last extra elements ([6.]) aren't in [radius ± atol]=[6.1 ± 1e-06]" + "The last extra elements ([6.]) aren't in [radius ± rtol]=[6.1 ± 1e-07]" ) with pytest.raises(AssertionError, match=msg): assert_radius_neighborhood_results_quasi_equality(