8000 FIX ensure object array are properly casted when dtype=object (#16076) · jeremiedbb/scikit-learn@c4ea377 · GitHub
[go: up one dir, main page]

Skip to content

Commit c4ea377

Browse files
alexshackedTomDLT
authored andcommitted
FIX ensure object array are properly casted when dtype=object (scikit-learn#16076)
1 parent 88eadf0 commit c4ea377

File tree

6 files changed

+92
-12
lines changed

6 files changed

+92
-12
lines changed

doc/whats_new/v0.23.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ Changelog
136136
differs between `predict` and `fit`.
137137
:pr:`16090` by :user:`Madhura Jayaratne <madhuracj>`.
138138

139+
:mod:`sklearn.neighbors`
140+
..............................
141+
142+
- |Fix| Fix a bug which converted a list of arrays into a 2-D object
143+
array instead of a 1-D array containing NumPy arrays. This bug
144+
was affecting :meth:`neighbors.NearestNeighbors.radius_neighbors`.
145+
:pr:`16076` by :user:`Guillaume Lemaitre <glemaitre>` and
146+
:user:`Alex Shacked <alexshacked>`.
147+
139148
:mod:`sklearn.preprocessing`
140149
............................
141150

sklearn/neighbors/_base.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..metrics import pairwise_distances_chunked
2525
from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS
2626
from ..utils import check_X_y, check_array, gen_even_slices
27+
from ..utils import _to_object_array
2728
from ..utils.multiclass import check_classification_targets
2829
from ..utils.validation import check_is_fitted
2930
from ..utils.validation import check_non_negative
@@ -276,8 +277,8 @@ def _radius_neighbors_from_graph(graph, radius, return_distance):
276277
indices = indices.astype(np.intp, copy=no_filter_needed)
277278

278279
if return_distance:
279-
neigh_dist = np.array(np.split(data, indptr[1:-1]), dtype=object)
280-
neigh_ind = np.array(np.split(indices, indptr[1:-1]), dtype=object)
280+
neigh_dist = _to_object_array(np.split(data, indptr[1:-1]))
281+
neigh_ind = _to_object_array(np.split(indices, indptr[1:-1]))
281282

282283
if return_distance:
283284
return neigh_dist, neigh_ind
@@ -940,17 +941,12 @@ class from an array representing our data set and ask who's
940941
neigh_dist_chunks, neigh_ind_chunks = zip(*chunked_results)
941942
neigh_dist_list = sum(neigh_dist_chunks, [])
942943
neigh_ind_list = sum(neigh_ind_chunks, [])
943-
# See https://github.com/numpy/numpy/issues/5456
944-
# to understand why this is initialized this way.
945-
neigh_dist = np.empty(len(neigh_dist_list), dtype='object')
946-
neigh_dist[:] = neigh_dist_list
947-
neigh_ind = np.empty(len(neigh_ind_list), dtype='object')
948-
neigh_ind[:] = neigh_ind_list
944+
neigh_dist = _to_object_array(neigh_dist_list)
945+
neigh_ind = _to_object_array(neigh_ind_list)
9 8000 49946
results = neigh_dist, neigh_ind
950947
else:
951948
neigh_ind_list = sum(chunked_results, [])
952-
results = np.empty(len(neigh_ind_list), dtype='object')
953-
results[:] = neigh_ind_list
949+
results = _to_object_array(neigh_ind_list)
954950

955951
elif self._fit_method in ['ball_tree', 'kd_tree']:
956952
if issparse(X):

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,30 @@ def test_radius_neighbors_boundary_handling():
649649
assert_array_equal(results[0], [0, 1])
650650

651651

652+
def test_radius_neighbors_returns_array_of_objects():
653+
# check that we can pass precomputed distances to
654+
# NearestNeighbors.radius_neighbors()
655+
# non-regression test for
656+
# https://github.com/scikit-learn/scikit-learn/issues/16036
657+
X = csr_matrix(np.ones((4, 4)))
658+
X.setdiag([0, 0, 0, 0])
659+
660+
nbrs = neighbors.NearestNeighbors(radius=0.5, algorithm='auto',
661+
leaf_size=30,
662+
metric='precomputed').fit(X)
663+
neigh_dist, neigh_ind = nbrs.radius_neighbors(X, return_distance=True)
664+
665+
expected_dist = np.empty(X.shape[0], dtype=object)
666+
expected_dist[:] = [np.array([0]), np.array([0]), np.array([0]),
667+
np.array([0])]
668+
expected_ind = np.empty(X.shape[0], dtype=object)
669+
expected_ind[:] = [np.array([0]), np.array([1]), np.array([2]),
670+
np.array([3])]
671+
672+
assert_array_equal(neigh_dist, expected_dist)
673+
assert_array_equal(neigh_ind, expec A3E2 ted_ind)
674+
675+
652676
def test_RadiusNeighborsClassifier_multioutput():
653677
# Test k-NN classifier on multioutput data
654678
rng = check_random_state(0)

sklearn/preprocessing/tests/test_label.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils._testing import assert_array_equal
1515
from sklearn.utils._testing import assert_warns_message
1616
from sklearn.utils._testing import ignore_warnings
17+
from sklearn.utils import _to_object_array
1718

1819
from sklearn.preprocessing._label import LabelBinarizer
1920
from sklearn.preprocessing._label import MultiLabelBinarizer
@@ -433,8 +434,7 @@ def test_multilabel_binarizer_same_length_sequence():
433434

434435

435436
def test_multilabel_binarizer_non_integer_labels():
436-
tuple_classes = np.empty(3, dtype=object)
437-
tuple_classes[:] = [(1,), (2,), (3,)]
437+
tuple_classes = _to_object_array([(1,), (2,), (3,)])
438438
inputs = [
439439
([('2', '3'), ('1',), ('1', '2')], ['1', '2', '3']),
440440
([('b', 'c'), ('a',), ('a', 'b')], ['a', 'b', 'c']),

sklearn/utils/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,45 @@ def tosequence(x):
819819
return list(x)
820820

821821

822+
def _to_object_array(sequence):
823+
"""Convert sequence to a 1-D NumPy array of object dtype.
824+
825+
numpy.array constructor has a similar use but it's output
826+
is ambiguous. It can be 1-D NumPy array of object dtype if
827+
the input is a ragged array, but if the input is a list of
828+
equal length arrays, then the output is a 2D numpy.array.
829+
_to_object_array solves this ambiguity by guarantying that
830+
the output is a 1-D NumPy array of objects for any input.
831+
832+
Parameters
833+
----------
834+
sequence : array-like of shape (n_elements,)
835+
The sequence to be converted.
836+
837+
Returns
838+
-------
839+
out : ndarray of shape (n_elements,), dtype=object
840+
The converted sequence into a 1-D NumPy array of object dtype.
841+
842+
Examples
843+
--------
844+
>>> import numpy as np
845+
>>> from sklearn.utils import _to_object_array
846+
>>> _to_object_array([np.array([0]), np.array([1])])
847+
array([array([0]), array([1])], dtype=object)
848+
>>> _to_object_array([np.array([0]), np.array([1, 2])])
849+
array([array([0]), array([1, 2])], dtype=object)
850+
>>> np.array([np.array([0]), np.array([1])])
851+
array([[0],
852+
[1]])
853+
>>> np.array([np.array([0]), np.array([1, 2])])
854+
array([array([0]), array([1, 2])], dtype=object)
855+
"""
856+
out = np.empty(len(sequence), dtype=object)
857+
out[:] = sequence
858+
return out
859+
860+
822861
def indices_to_mask(indices, mask_length):
823862
"""Convert list of indices to boolean mask.
824863

sklearn/utils/tests/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sklearn.utils import _message_with_time, _print_elapsed_time
2929
from sklearn.utils import get_chunk_n_rows
3030
from sklearn.utils import is_scalar_nan
31+
from sklearn.utils import _to_object_array
3132
from sklearn.utils._mocking import MockDataFrame
3233
from sklearn import config_context
3334

@@ -646,3 +647,14 @@ def test_deprecation_joblib_api(tmpdir):
646647

647648
from sklearn.utils._joblib import joblib
648649
del joblib.parallel.BACKENDS['failing']
650+
651+
652+
@pytest.mark.parametrize(
653+
"sequence",
654+
[[np.array(1), np.array(2)], [[1, 2], [3, 4]]]
655+
)
656+
def test_to_object_array(sequence):
657+
out = _to_object_array(sequence)
658+
assert isinstance(out, np.ndarray)
659+
assert out.dtype.kind == 'O'
660+
assert out.ndim == 1

0 commit comments

Comments
 (0)
0