8000 MAINT replace cnp.ndarray with memory views in _fast_dict (#25754) · scikit-learn/scikit-learn@dfda968 · GitHub
[go: up one dir, main page]

Skip to content

Commit dfda968

Browse files
authored
MAINT replace cnp.ndarray with memory views in _fast_dict (#25754)
1 parent 4180b07 commit dfda968

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

sklearn/utils/_fast_dict.pyx

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ from libcpp.map cimport map as cpp_map
1212

1313
import numpy as np
1414

15-
# Import the C-level symbols of numpy
16-
cimport numpy as cnp
17-
18-
# Numpy must be initialized. When using numpy from C or Cython you must
19-
# _always_ do that, or you will have segfaults
20-
cnp.import_array()
21-
2215
#DTYPE = np.float64
2316
#ctypedef cnp.float64_t DTYPE_t
2417

@@ -35,8 +28,11 @@ cnp.import_array()
3528

3629
cdef class IntFloatDict:
3730

38-
def __init__(self, cnp.ndarray[ITYPE_t, ndim=1] keys,
39-
cnp.ndarray[DTYPE_t, ndim=1] values):
31+
def __init__(
32+
self,
33+
ITYPE_t[:] keys,
34+
DTYPE_t[:] values,
35+
):
4036
cdef int i
4137
cdef int size = values.size
4238
# Should check that sizes for keys and values are equal, and
@@ -91,10 +87,8 @@ cdef class IntFloatDict:
9187
The values of the data points
9288
"""
9389
cdef int size = self.my_map.size()
94-
cdef cnp.ndarray[ITYPE_t, ndim=1] keys = np.empty(size,
95-
dtype=np.intp)
96-
cdef cnp.ndarray[DTYPE_t, ndim=1] values = np.empty(size,
97-
dtype=np.float64)
90+
keys = np.empty(size, dtype=np.intp)
91+
values = np.empty(size, dtype=np.float64)
9892
self._to_arrays(keys, values)
9993
return keys, values
10094

sklearn/utils/tests/test_fast_dict.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
""" Test fast_dict.
22
"""
33
import numpy as np
4+
from numpy.testing import assert_array_equal, assert_allclose
45

56
from sklearn.utils._fast_dict import IntFloatDict, argmin
67

@@ -29,3 +30,18 @@ def test_int_float_dict_argmin():
2930
values = np.arange(100, dtype=np.float64)
3031
d = IntFloatDict(keys, values)
3132
assert argmin(d) == (0, 0)
33+
34+
35+
def test_to_arrays():
36+
# Test that an IntFloatDict is converted into arrays
37+
# of keys and values correctly
38+
keys_in = np.array([1, 2, 3], dtype=np.intp)
39+
values_in = np.array([4, 5, 6], dtype=np.float64)
40+
41+
d = IntFloatDict(keys_in, values_in)
42+
keys_out, values_out = d.to_arrays()
43+
44+
assert keys_out.dtype == keys_in.dtype
45+
assert values_in.dtype == values_out.dtype
46+
assert_array_equal(keys_out, keys_in)
47+
assert_allclose(values_out, values_in)

0 commit comments

Comments
 (0)
0