8000 ENH Forces shape to be tuple when using Array API's reshape (#26030) · scikit-learn/scikit-learn@87ec740 · GitHub
[go: up one dir, main page]

Skip to content

Commit 87ec740

Browse files
thomasjpfanogriselbetatim
authored
ENH Forces shape to be tuple when using Array API's reshape (#26030)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Tim Head <betatim@gmail.com>
1 parent 7f1e15d commit 87ec740

File tree

5 files changed

+34
-4
lines changed

5 files changed

+34
-4
lines changed

sklearn/discriminant_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def fit(self, X, y):
640640
intercept_ = xp.asarray(
641641
self.intercept_[1] - self.intercept_[0], dtype=X.dtype
642642
)
643-
self.intercept_ = xp.reshape(intercept_, 1)
643+
self.intercept_ = xp.reshape(intercept_, (1,))
644644
self._n_features_out = self._max_components
645645
return self
646646

sklearn/linear_model/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def decision_function(self, X):
399399

400400
X = self._validate_data(X, accept_sparse="csr", reset=False)
401401
scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
402-
return xp.reshape(scores, -1) if scores.shape[1] == 1 else scores
402+
return xp.reshape(scores, (-1,)) if scores.shape[1] == 1 else scores
403403

404404
def predict(self, X):
405405
"""

sklearn/utils/_array_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,21 @@ def unique_values(self, x):
166166
def concat(self, arrays, *, axis=None):
167167
return numpy.concatenate(arrays, axis=axis)
168168

169+
def reshape(self, x, shape, *, copy=None):
170+
"""Gives a new shape to an array without changing its data.
171+
172+
The Array API specification requires shape to be a tuple.
173+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html
174+
"""
175+
if not isinstance(shape, tuple):
176+
raise TypeError(
177+
f"shape must be a tuple, got {shape!r} of type {type(shape)}"
178+
)
179+
180+
if copy is True:
181+
x = x.copy()
182+
return numpy.reshape(x, shape)
183+
169184
def isdtype(self, dtype, kind):
170185
return isdtype(dtype, kind, xp=self)
171186

sklearn/utils/tests/test_array_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,21 @@ def test_convert_estimator_to_array_api():
189189
assert hasattr(new_est.X_, "__array_namespace__")
190190

191191

192+
def test_reshape_behavior():
193+
"""Check reshape behavior with copy and is strict with non-tuple shape."""
194+
xp = _NumPyAPIWrapper()
195+
X = xp.asarray([[1, 2, 3], [3, 4, 5]])
196+
197+
X_no_copy = xp.reshape(X, (-1,), copy=False)
198+
assert X_no_copy.base is X
199+
200+
X_copy = xp.reshape(X, (6, 1), copy=True)
201+
assert X_copy.base is not X.base
202+
203+
with pytest.raises(TypeError, match="shape must be a tuple"):
204+
xp.reshape(X, -1)
205+
206+
192207
@pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyAPIWrapper])
193208
def test_get_namespace_array_api_isdtype(wrapper):
194209
"""Test isdtype implementation from _ArrayAPIWrapper and _NumPyAPIWrapper."""

sklearn/utils/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ def column_or_1d(y, *, dtype=None, warn=False):
12091209

12101210
shape = y.shape
12111211
if len(shape) == 1:
1212-
return _asarray_with_order(xp.reshape(y, -1), order="C", xp=xp)
1212+
return _asarray_with_order(xp.reshape(y, (-1,)), order="C", xp=xp)
12131213
if len(shape) == 2 and shape[1] == 1:
12141214
if warn:
12151215
warnings.warn(
@@ -1219,7 +1219,7 @@ def column_or_1d(y, *, dtype=None, warn=False):
12191219
DataConversionWarning,
12201220
stacklevel=2,
12211221
)
1222-
return _asarray_with_order(xp.reshape(y, -1), order="C", xp=xp)
1222+
return _asarray_with_order(xp.reshape(y, (-1,)), order="C", xp=xp)
12231223

12241224
raise ValueError(
12251225
"y should be a 1d array, got an array of shape {} instead.".format(shape)

0 commit comments

Comments
 (0)
0