8000 Merge pull request #23789 from bwalshe/array-api-reshape-copy · numpy/numpy@6087bcd · GitHub
[go: up one dir, main page]

Skip to content

Commit 6087bcd

Browse files
authored
Merge pull request #23789 from bwalshe/array-api-reshape-copy
ENH: add copy parameter for api.reshape function
2 parents 8692e24 + 8c184b5 commit 6087bcd

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

numpy/array_api/_manipulation_functions.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,27 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
5353

5454

5555
# Note: the optional argument is called 'shape', not 'newshape'
56-
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
56+
def reshape(x: Array,
57+
/,
58+
shape: Tuple[int, ...],
59+
*,
60+
copy: Optional[Bool] = None) -> Array:
5761
"""
5862
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
5963
6064
See its docstring for more information.
6165
"""
62-
return Array._new(np.reshape(x._array, shape))
66+
67+
data = x._array
68+
if copy:
69+
data = np.copy(data)
70+
71+
reshaped = np.reshape(data, shape)
72+
73+
if copy is False and not np.shares_memory(data, reshaped):
74+
raise AttributeError("Incompatible shape for in-place modification.")
75+
76+
return Array._new(reshaped)
6377

6478

6579
def roll(
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from numpy.testing import assert_raises
2+
import numpy as np
3+
4+
from .. import all
5+
from .._creation_functions import asarray
6+
from .._dtypes import float64, int8
7+
from .._manipulation_functions import (
8+
concat,
9+
reshape,
10+
stack
11+
)
12+
13+
14+
def test_concat_errors():
15+
assert_raises(TypeError, lambda: concat((1, 1), axis=None))
16+
assert_raises(TypeError, lambda: concat([asarray([1], dtype=int8),
17+
asarray([1], dtype=float64)]))
18+
19+
20+
def test_stack_errors():
21+
assert_raises(TypeError, lambda: stack([asarray([1, 1], dtype=int8),
22+
asarray([2, 2], dtype=float64)]))
23+
24+
25+
def test_reshape_copy():
26+
a = asarray(np.ones((2, 3)))
27+
b = reshape(a, (3, 2), copy=True)
28+
assert not np.shares_memory(a._array, b._array)
29+
30+
a = asarray(np.ones((2, 3)))
31+
b = reshape(a, (3, 2), copy=False)
32+
assert np.shares_memory(a._array, b._array)
33+
34+
a = asarray(np.ones((2, 3)).T)
35+
b = reshape(a, (3, 2), copy=True)
36+
assert_raises(AttributeError, lambda: reshape(a, (2, 3), copy=False))
37+

0 commit comments

Comments
 (0)
0