8000 ENH: add np.stack by shoyer · Pull Request #5605 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: add np.stack #5605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading

Uh oh!

There was an error while loading. Please reload this page.

Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/release/1.10.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Highlights
* Addition of *np.linalg.multi_dot*: compute the dot product of two or more
arrays in a single function call, while automatically selecting the fastest
evaluation order.
* The new function `np.stack` provides a general interface for joining a
sequence of arrays along a new axis, complementing `np.concatenate` for
joining along an existing axis.
* Addition of `nanprod` to the set of nanfunctions.


Expand Down
5 changes: 3 additions & 2 deletions doc/source/reference/routines.array-manipulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ Joining arrays
.. autosummary::
:toctree: generated/

column_stack
concatenate
stack
column_stack
dstack
hstack
vstack
Expand All @@ -75,10 +76,10 @@ Splitting arrays
.. autosummary::
:toctree: generated/

split
array_split
dsplit
hsplit
split
vsplit

Tiling arrays
Expand Down
3 changes: 2 additions & 1 deletion numpy/add_newdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ def luf(lamdaexpr, *args, **kwargs):
"""
concatenate((a1, a2, ...), axis=0)

Join a sequence of arrays together.
Join a sequence of arrays along an existing axis.

Parameters
----------
Expand All @@ -1166,6 +1166,7 @@ def luf(lamdaexpr, *args, **kwargs):
hsplit : Split array into multiple sub-arrays horizontally (column wise)
vsplit : Split array into multiple sub-arrays vertically (row wise)
dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
stack : Stack a sequence of arrays along a new axis.
hstack : Stack arrays in sequence horizontally (column wise)
vstack : Stack arrays in sequence vertically (row wise)
dstack : Stack arrays in sequence depth wise (along third dimension)
Expand Down
79 changes: 76 additions & 3 deletions numpy/core/shape_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function

__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack']
__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack',
'stack']

from . import numeric as _nx
from .numeric import array, asanyarray, newaxis
Expand Down Expand Up @@ -196,9 +197,10 @@ def vstack(tup):

See Also
--------
stack : Join a sequence of arrays along a new axis.
hstack : Stack arrays in sequence horizontally (column wise).
dstack : Stack arrays in sequence depth wise (along third dimension).
concatenate : Join a sequence of arrays together.
concatenate : Join a sequence of arrays along an existing axis.
vsplit : Split array into a list of multiple sub-arrays vertically.

Notes
Expand Down Expand Up @@ -246,9 +248,10 @@ def hstack(tup):

See Also
--------
stack : Join a sequence of arrays along a new axis.
vstack : Stack arrays in sequence vertically (row wise).
dstack : Stack arrays in sequence depth wise (along third axis).
concatenate : Join a sequence of arrays together.
concatenate : Join a sequence of arrays along an existing axis.
hsplit : Split array along second axis.

Notes
Expand All @@ -275,3 +278,73 @@ def hstack(tup):
return _nx.concatenate(arrs, 0)
else:
return _nx.concatenate(arrs, 1)

def stack(arrays, axis=0):
"""
Join a sequence of arrays along a new axis.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need more explanation here. For instance, in the examples, it isn't clear why np.stack((a, b)) is not the same as np.stack((a, b), axis=-1).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the missing explanation is that the axis argument refers to the axis position in the result array, not in the input arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still needs a better explanation of where the new axis is located. It might also be better to follow the list insertion protocol.

In [7]: a = [1]*4

In [8]: a.insert(4, 2)

In [9]: a
Out[9]: [1, 1, 1, 1, 2]

In [10]: a = [1]*4

In [11]: a.insert(-1, 2)

In [12]: a
Out[12]: [1, 1, 1, 2, 1]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it is nice to indicate appending the axis with a simple -1. In the current case, an explanation of how negative axis values are handled would help. For the list version you could do something like

In [26]: a = ones((2,2,3))

In [27]: newshape = list(a.shape)

In [28]: newshape.insert(-1, 1)

In [29]: a.reshape(newshape).shape
Out[29]: (2, 2, 1, 3)

8000

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about

def stack(arrays, axis=0):
    """
    Join a sequence of arrays along a new axis.

    The `axis` parameter specifies the index of the new axis in the
    dimensions of the result. For instance, if ``axis=0`` it will be the
    first dimension and if ``axis=-1`` it will be the last dimension.

    etc...


The `axis` parameter specifies the index of the new axis in the dimensions
of the result. For example, if ``axis=0`` it will be the first dimension
and if ``axis=-1`` it will be the last dimension.

.. versionadded:: 1.10.0

Parameters
----------
arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.

Returns
-------
stacked : ndarray
The stacked array has one more dimension than the input arrays.

See Also
--------
concatenate : Join a sequence of arrays along an existing axis.
split : Split array into a list of multiple sub-arrays of equal size.

Examples
--------
>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
>>> np.stack(arrays, axis=0).shape
(10, 3, 4)

>>> np.stack(arrays, axis=1).shape
(3, 10, 4)

>>> np.stack(arrays, axis=2).shape
(3, 4, 10)

>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 3, 4])
>>> np.stack((a, b))
array([[1, 2, 3],
[2, 3, 4]])

>>> np.stack((a, b), axis=-1)
array([[1, 2],
[2, 3],
[3, 4]])

"""
arrays = [asanyarray(arr) for arr in arrays]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with mixed subtypes?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could maybe check that all types are the same.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could do that by checking that set(type(a) for a in arrays) has one member.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens for subtypes is mostly dictated by the behavior of np.concatenate. I don't see much advantage in explicitly checking for consistent types here when none of the logic in this function relies on that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type checking should be left for concatenate (which does not currently do this all that well, but could be rewritten, e.g., using insert methods if present on the first member or so).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to ot D7AF hers. Learn more.

OK.

if not arrays:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

len(arrays) == 0 might be clearer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this is idiomatic way to write this in Python (but I don't care too much either way)

raise ValueError('need at least one array to stack')

shapes = set(arr.shape for arr in arrays)
if len(shapes) != 1:
raise ValueError('all input arrays must have the same shape')

result_ndim = arrays[0].ndim + 1
if not -result_ndim <= axis < result_ndim:
msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
raise IndexError(msg)
if axis < 0:
axis += result_ndim

sl = (slice(None),) * axis + (_nx.newaxis,)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative method, once you have the shape of the arrays, is

newshape = shape[:axis] + (1,) + shape[axis:]
expanded_arrays = [a.reshape(newshape) for a in arrays]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, getting rid of expanded_arrays

_nx.concatenate([a.reshape(newshape) for a in arrays], axis=axis)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like using slicing for this operation rather than reshape because I know that slicing will always using a view rather than a copy. Though I suppose reshape is probably also safe when used in this way.

F438
expanded_arrays = [arr[sl] for arr in arrays]
return _nx.concatenate(expanded_arrays, axis=axis)
55 changes: 53 additions & 2 deletions numpy/core/tests/test_shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import warnings
import numpy as np
from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal,
assert_equal, run_module_suite)
assert_equal, run_module_suite, assert_raises_regex)
from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d,
vstack, hstack, newaxis, concatenate)
vstack, hstack, newaxis, concatenate, stack)
from numpy.compat import long

class TestAtleast1d(TestCase):
Expand Down Expand Up @@ -246,5 +246,56 @@ def test_concatenate_sloppy0():
assert_raises(DeprecationWarning, concatenate, (r4, r3), 10)


def test_stack():
# 0d input
for input_ in [(1, 2, 3),
[np.int32(1), np.int32(2), np.int32(3)],
[np.array(1), np.array(2), np.array(3)]]:
assert_array_equal(stack(input_), [1, 2, 3])
# 1d input examples
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
r1 = array([[1, 2, 3], [4, 5, 6]])
assert_array_equal(np.stack((a, b)), r1)
assert_array_equal(np.stack((a, b), axis=1), r1.T)
# all input types
assert_array_equal(np.stack(list([a, b])), r1)
assert_array_equal(np.stack(array([a, b])), r1)
# all shapes for 1d input
arrays = [np.random.randn(3) for _ in range(10)]
axes = [0, 1, -1, -2]
expected_shapes = [(10, 3), (3, 10), (3, 10), (10, 3)]
for axis, expected_shape in zip(axes, expected_shapes):
assert_equal(np.stack(arrays, axis).shape, expected_shape)
assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=2)
assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=-3)
# all shapes for 2d input
arrays = [np.random.randn(3, 4) for _ in range(10)]
axes = [0, 1, 2, -1, -2, -3]
expected_shapes = [(10, 3, 4), (3, 10, 4), (3, 4, 10),
(3, 4, 10), (3, 10, 4), (10, 3, 4)]
for axis, expected_shape in zip(axes, expected_shapes):
assert_equal(np.stack(arrays, axis).shape, expected_shape)
# empty arrays
assert stack([[], [], []]).shape == (3, 0)
assert stack([[], [], []], axis=1).shape == (0, 3)
# edge cases
assert_raises_regex(ValueError, 'need at least one array', stack, [])
assert_raises_regex(ValueError, 'must have the same shape',
stack, [1, np.arange(3)])
assert_raises_regex(ValueError, 'must have the same shape',
stack, [np.arange(3), 1])
assert_raises_regex(ValueError, 'must have the same shape',
stack, [np.arange(3), 1], axis=1)
assert_raises_regex(ValueError, 'must have the same shape',
stack, [np.zeros((3, 3)), np.zeros(3)], axis=1)
assert_raises_regex(ValueError, 'must have the same shape',
stack, [np.arange(2), np.arange(3)])
# np.matrix
m = np.matrix([[1, 2], [3, 4]])
assert_raises_regex(ValueError, 'shape too large to be a matrix',
stack, [m, m])


if __name__ == "__main__":
run_module_suite()
2 changes: 1 addition & 1 deletion numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3706,7 +3706,7 @@ def insert(arr, obj, values, axis=None):
See Also
--------
append : Append elements at the end of an array.
concatenate : Join a sequence of arrays together.
concatenate : Join a sequence of arrays along an existing axis.
delete : Delete elements from an array.

Notes
Expand Down
2 changes: 1 addition & 1 deletion numpy/lib/index_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ class RClass(AxisConcatenator):
See Also
--------
concatenate : Join a sequence of arrays together.
concatenate : Join a sequence of arrays along an existing axis.
c_ : Translates slice objects to concatenation along the second axis.
Examples
Expand Down
6 changes: 4 additions & 2 deletions numpy/lib/shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,10 @@ def dstack(tup):

See Also
--------
stack : Join a sequence of arrays along a new axis.
vstack : Stack along first axis.
hstack : Stack along second axis.
concatenate : Join arrays.
concatenate : Join a sequence of arrays along an existing axis.
dsplit : Split array along third axis.

Notes
Expand Down Expand Up @@ -477,7 +478,8 @@ def split(ary,indices_or_sections,axis=0):
hsplit : Split array into multiple sub-arrays horizontally (column-wise).
vsplit : Split array into multiple sub-arrays vertically (row wise).
dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
concatenate : Join arrays together.
concatenate : Join a sequence of arrays along an existing axis.
stack : Join a sequence of arrays along a new axis.
hstack : Stack arrays in sequence horizontally (column wise).
vstack : Stack arrays in sequence vertically (row wise).
dstack : Stack arrays in sequence depth wise (along third dimension).
Expand Down
0