8000 ENH: add copy parameter for api.reshape function by bwalshe · Pull Request #23789 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: add copy parameter for api.reshape function #23789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
BUG: add copy parameter for api.reshape function
This adds a parameter to api.reshape to specify if data
should be copied. This parameter is required so that
api.reshape conforms to the standard. See #23410
  • Loading branch information
bwalshe committed May 22, 2023
commit c19e84e012da65828f62f54f721337d873fd04fe
15 changes: 13 additions & 2 deletions numpy/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:


# Note: the optional argument is called 'shape', not 'newshape'
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
def reshape(x: Array,
/,
shape: Tuple[int, ...],
*,
copy: Optional[Bool] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.

See its docstring for more information.
"""
return Array._new(np.reshape(x._array, shape))
if copy is False:
raise NotImplementedError("copy=False is not yet implemented")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be implemented by calling reshape and raising an exception if the result is not a view. Or maybe there's even a more direct way to tell if reshape will copy without actually trying it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we need to implement it directly in np.reshape. Do others have issues with the above simple workaround here for now though?

Copy link
Contributor Author
@bwalshe bwalshe May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we need to implement it directly in np.reshape.

I think that in order to do this correctly, I might have to change the signature of the PyArray_Newshape(PyArrayObject *self, PyArray_Dims *newdims, NPY_ORDER order) function in numpy/core/src/multiarray/shape.c, but this has been marked as NUMPY_API in the comments.

I can make the change you suggested above, but if I want to go ahead with the change to PyArray_Newshape, what do I have to do to get approval for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@asmeurer Do these changes look OK?


data = x._array
if copy:
data = np.copy(data)

return Array._new(np.reshape(data, shape))


def roll(
Expand Down
30 changes: 30 additions & 0 deletions numpy/array_api/tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from numpy.testing import assert_raises
import numpy as np

from .. import all
from .._creation_functions import asarray
from .._dtypes import float64, int8
from .._manipulation_functions import (
concat,
reshape,
stack
)


def test_concat_errors():
assert_raises(TypeError, lambda: concat((1, 1), axis=None))
assert_raises(TypeError, lambda: concat([asarray([1], dtype=int8), asarray([1], dtype=float64)]))


def test_stack_errors():
assert_raises(TypeError, lambda: stack([asarray([1, 1], dtype=int8), asarray([2, 2], dtype=float64)]))


def test_reshape_copy():
a = asarray([1])
b = reshape(a, (1, 1), copy=True)
a[0] = 0
assert all(b[0, 0] == 1)
assert all(a[0] == 0)
assert_raises(NotImplementedError, lambda: reshape(a, (1, 1), copy=False))

0