diff --git a/sklearn/utils/_fast_dict.pyx b/sklearn/utils/_fast_dict.pyx index 74aaa16b020eb..5fe642b14c626 100644 --- a/sklearn/utils/_fast_dict.pyx +++ b/sklearn/utils/_fast_dict.pyx @@ -12,13 +12,6 @@ from libcpp.map cimport map as cpp_map import numpy as np -# Import the C-level symbols of numpy -cimport numpy as cnp - -# Numpy must be initialized. When using numpy from C or Cython you must -# _always_ do that, or you will have segfaults -cnp.import_array() - #DTYPE = np.float64 #ctypedef cnp.float64_t DTYPE_t @@ -35,8 +28,11 @@ cnp.import_array() cdef class IntFloatDict: - def __init__(self, cnp.ndarray[ITYPE_t, ndim=1] keys, - cnp.ndarray[DTYPE_t, ndim=1] values): + def __init__( + self, + ITYPE_t[:] keys, + DTYPE_t[:] values, + ): cdef int i cdef int size = values.size # Should check that sizes for keys and values are equal, and @@ -91,10 +87,8 @@ cdef class IntFloatDict: The values of the data points """ cdef int size = self.my_map.size() - cdef cnp.ndarray[ITYPE_t, ndim=1] keys = np.empty(size, - dtype=np.intp) - cdef cnp.ndarray[DTYPE_t, ndim=1] values = np.empty(size, - dtype=np.float64) + keys = np.empty(size, dtype=np.intp) + values = np.empty(size, dtype=np.float64) self._to_arrays(keys, values) return keys, values diff --git a/sklearn/utils/tests/test_fast_dict.py b/sklearn/utils/tests/test_fast_dict.py index 050df133a2d24..96c14068f0db1 100644 --- a/sklearn/utils/tests/test_fast_dict.py +++ b/sklearn/utils/tests/test_fast_dict.py @@ -1,6 +1,7 @@ """ Test fast_dict. """ import numpy as np +from numpy.testing import assert_array_equal, assert_allclose from sklearn.utils._fast_dict import IntFloatDict, argmin @@ -29,3 +30,18 @@ def test_int_float_dict_argmin(): values = np.arange(100, dtype=np.float64) d = IntFloatDict(keys, values) assert argmin(d) == (0, 0) + + +def test_to_arrays(): + # Test that an IntFloatDict is converted into arrays + # of keys and values correctly + keys_in = np.array([1, 2, 3], dtype=np.intp) + values_in = np.array([4, 5, 6], dtype=np.float64) + + d = IntFloatDict(keys_in, values_in) + keys_out, values_out = d.to_arrays() + + assert keys_out.dtype == keys_in.dtype + assert values_in.dtype == values_out.dtype + assert_array_equal(keys_out, keys_in) + assert_allclose(values_out, values_in)