8000 Add empty_like function so we can be future NEP18 compatible · scikit-learn/scikit-learn@721d062 · GitHub
[go: up one dir, main page]

Skip to content

Commit 721d062

Browse files
Add empty_like function so we can be future NEP18 compatible
1 parent 5c99577 commit 721d062

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

sklearn/utils/fixes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,41 @@ def _object_dtype_isnan(X):
185185
return X != X
186186

187187

188+
def empty_like(prototype, dtype=None, order='K', subok=True, shape=None):
189+
"""Forwards call to numpy.empty_like or empty, to be compatible with NEP18.
190+
191+
Before numpy 1.17, numpy.empty_like did not take a shape argument.
192+
193+
When version of numpy < (1, 17), prototype is not an ndarray instance
194+
and shape is provided, a ValueError will be raised.
195+
When version of numpy < (1, 17), and shape is provided, the call will
196+
be forwarded to numpy.empty.
197+
"""
198+
if np_version < (1, 17):
199+
if not isinstance(prototype, np.ndarray) and shape is not None:
200+
raise ValueError('NumPy %r does not support NEP18' % (np_version,))
201+
if shape is not None:
202+
prototype = np.array(prototype, copy=False, order=order, subok=subok)
203+
if dtype is None:
204+
dtype = prototype.dtype
205+
if order == 'A':
206+
order = 'F' if prototype.flags['F_CONTIGUOUS'] else 'C'
207+
elif order == 'K':
208+
if prototype.flags['C_CONTIGUOUS'] or prototype.ndim <= 1:
209+
order = 'C'
210+
elif prototype.flags['F_CONTIGUOUS']:
211+
order = 'F'
212+
else:
213+
raise NotImplementedError('order=K not implemented for'
214+
'non contiguous C or F array')
215+
return np.empty(shape, dtype=dtype, order=order)
216+
else:
217+
return np.empty_like(prototype, dtype=dtype, order=order, subok=subok)
218+
else:
219+
return np.empty_like(prototype, dtype=dtype, order=order, subok=subok,
220+
shape=shape)
221+
222+
188223
# TODO: replace by copy=False, when only scipy > 1.1 is supported.
189224
def _astype_copy_false(X):
190225
"""Returns the copy=False parameter for

sklearn/utils/tests/test_fixes.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import math
77
import pickle
8+
from unittest.mock import MagicMock
89

910
import numpy as np
1011
import pytest
@@ -16,6 +17,8 @@
1617
from sklearn.utils.fixes import _joblib_parallel_args
1718
from sklearn.utils.fixes import _object_dtype_isnan
1819
from sklearn.utils.fixes import loguniform
20+
from sklearn.utils.fixes import empty_like
21+
from sklearn.utils.fixes import np_version
1922

2023

2124
def test_masked_array_obj_dtype_pickleable():
@@ -95,3 +98,51 @@ def test_loguniform(low, high, base):
9598
loguniform(base ** low, base ** high).rvs(random_state=0)
9699
== loguniform(base ** low, base ** high).rvs(random_state=0)
97100
)
101+
102+
103+
@pytest.mark.skipif(np_version < (1, 17),
104+
reason="NEP18 not supported before 1.17")
105+
def test_empty_like_nep18():
106+
class ArrayLike:
107+
__array_function__ = MagicMock(return_value=42)
108+
109+
# if NEP18 is supported, empty_like should be forwarded to us
110+
array_like = ArrayLike()
111+
value = empty_like(array_like, dtype=np.float32, shape=(4, 2))
112+
assert value == 42
113+
114+
115+
def test_empty_like():
116+
# Normaly arrays should just work with all versions of numpy
117+
X = np.arange(8)
118+
Y = empty_like(X.reshape((4, 2)))
119+
assert isinstance(Y, np.ndarray)
120+
assert Y.shape == (4, 2)
121+
122+
123+
@pytest.mark.skipif(np_version >= (1, 17),
124+
reason="NEP18 not supported before 1.17")
125+
def test_empty_like_no_nep18():
126+
class NotAnArray:
127+
def __array__(self):
128+
return np.arange(8, dtype=np.float64).reshape((4, 2))
129+
130+
# for numpy < 1.17, we should give an error msg, if we provide shape
131+
no_array = NotAnArray()
132+
with pytest.raises(ValueError):
133+
empty_like(no_array, dtype=np.float32, shape=(4, 2))
134+
135+
# we can pass a non-ndarray object, but without shape
136+
no_array = NotAnArray()
137+
an_array = empty_like(no_array, dtype=np.float32)
138+
assert an_array.shape == (4, 2)
139+
assert an_array.dtype == np.float32
140+
141+
# but with a ndarray, we can pass with shape
142+
second_array = empty_like(an_array, dtype=np.float64, shape=(3,5))
143+
assert second_array.shape == (3, 5)
144+
assert second_array.dtype == np.float64
145+
146+
second_array_same_type = empty_like(an_array, shape=(3,5))
147+
assert second_array_same_type.shape == (3, 5)
148+
assert second_array_same_type.dtype == np.float32

0 commit comments

Comments
 (0)
0