numpy/numpy@a9c7d8b

Commit a9c7d8b

ENH: add np.stack

The motivation here is to present a uniform and N-dimensional interface for
joining arrays along a new axis, similarly to how `concatenate` provides a
uniform and N-dimensional interface for joining arrays along an existing axis.

Background
~~~~~~~~~~

Currently, users can choose between `hstack`, `vstack`, `column_stack` and
`dstack`, but none of these functions handle N-dimensional input. In my
opinion, it's also difficult to keep track of the differences between these
methods and to predict how they will handle input with different dimensions.

In the past, my preferred approach has been to either construct the result
array explicitly and use indexing for assignment, or to use `np.array` to
stack along the first dimension and then use `transpose` (or a similar method)
to reorder dimensions if necessary. This is pretty awkward.

I brought this proposal up a few weeks ago on the numpy-discussion list:
http://mail.scipy.org/pipermail/numpy-discussion/2015-February/072199.html

I also received positive feedback on Twitter:
https://twitter.com/shoyer/status/565937244599377920

Implementation notes
~~~~~~~~~~~~~~~~~~~~

The one-line summaries for `concatenate` and `stack` have been (re)written to
mirror each other, and to make clear that the distinction between these
functions is whether they join over an existing or new axis.

In general, I've tweaked the documentation and docstrings with an eye toward
pointing users to `concatenate`/`stack`/`split` as a fundamental set of basic
array manipulation routines, and away from
`array_split`/`{h,v,d}split`/`{h,v,d,column_}stack`.

I put this implementation in `numpy.core.shape_base` alongside
`hstack`/`vstack`, but it appears that there is also a `numpy.lib.shape_base`
module that contains another larger set of functions, including `dstack`. I'm
not really sure where this belongs (or if it even matters).

Finally, it might be a good idea to write a masked array version of `stack`.
But I don't use masked arrays, so I'm not well motivated to do that.
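
For illustration (this is not part of the commit), the two workarounds described in the Background section and the equivalent `np.stack` call might look like the sketch below. The array contents, shapes, and variable names are assumptions chosen for the example, and NumPy 1.10 or later is assumed.

```python
import numpy as np

# Ten hypothetical (3, 4) arrays to be joined along a new axis at position 1.
arrays = [np.full((3, 4), i) for i in range(10)]

# Workaround 1: allocate the result explicitly and assign by index.
explicit = np.empty((3, 10, 4), dtype=arrays[0].dtype)
for i, arr in enumerate(arrays):
    explicit[:, i, :] = arr

# Workaround 2: stack along the first dimension with np.array, then reorder.
transposed = np.array(arrays).transpose(1, 0, 2)

# With np.stack: a single call that accepts any valid position for the new axis.
stacked = np.stack(arrays, axis=1)

print(stacked.shape)  # (3, 10, 4)
```

All three results hold the same values; only the `np.stack` call states the target axis directly.
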
1 parent 2e016ac commit a9c7d8b

8 files changed: +143 -12 lines changed


doc/release/1.10.0-notes.rst

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ Highlights
 * Addition of *np.linalg.multi_dot*: compute the dot product of two or more
   arrays in a single function call, while automatically selecting the fastest
   evaluation order.
+* The new function `np.stack` provides a general interface for joining a
+  sequence of arrays along a new axis, complementing `np.concatenate` for
+  joining along an existing axis.
 * Addition of `nanprod` to the set of nanfunctions.
 
 
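
The distinction drawn in this release note (new axis versus existing axis) can be seen from the result shapes. The following is a small sketch rather than text from the release notes; the input arrays are arbitrary examples.

```python
import numpy as np

a = np.zeros((3, 4))
b = np.ones((3, 4))

# concatenate joins along an existing axis: the number of dimensions is unchanged.
joined = np.concatenate((a, b), axis=0)

# stack joins along a new axis: the result has one more dimension than the inputs.
stacked = np.stack((a, b), axis=0)

print(joined.shape)   # (6, 4)
print(stacked.shape)  # (2, 3, 4)
```
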

doc/source/reference/routines.array-manipulation.rst

Lines changed: 3 additions & 2 deletions
@@ -64,8 +64,9 @@ Joining arrays
 .. autosummary::
    :toctree: generated/
 
-   column_stack
    concatenate
+   stack
+   column_stack
    dstack
    hstack
    vstack
@@ -75,10 +76,10 @@ Splitting arrays
 .. autosummary::
    :toctree: generated/
 
+   split
    array_split
    dsplit
    hsplit
-   split
    vsplit
 
 Tiling arrays

numpy/add_newdocs.py

Lines changed: 2 additions & 1 deletion
@@ -1142,7 +1142,7 @@ def luf(lamdaexpr, *args, **kwargs):
     """
     concatenate((a1, a2, ...), axis=0)
 
-    Join a sequence of arrays together.
+    Join a sequence of arrays along an existing axis.
 
     Parameters
     ----------
@@ -1166,6 +1166,7 @@ def luf(lamdaexpr, *args, **kwargs):
     hsplit : Split array into multiple sub-arrays horizontally (column wise)
     vsplit : Split array into multiple sub-arrays vertically (row wise)
     dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
+    stack : Stack a sequence of arrays along a new axis.
     hstack : Stack arrays in sequence horizontally (column wise)
     vstack : Stack arrays in sequence vertically (row wise)
     dstack : Stack arrays in sequence depth wise (along third dimension)

numpy/core/shape_base.py

Lines changed: 76 additions & 3 deletions
@@ -1,6 +1,7 @@
 from __future__ import division, absolute_import, print_function
 
-__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack']
+__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack',
+           'stack']
 
 from . import numeric as _nx
 from .numeric import array, asanyarray, newaxis
@@ -196,9 +197,10 @@ def vstack(tup):
 
     See Also
     --------
+    stack : Join a sequence of arrays along a new axis.
     hstack : Stack arrays in sequence horizontally (column wise).
     dstack : Stack arrays in sequence depth wise (along third dimension).
-    concatenate : Join a sequence of arrays together.
+    concatenate : Join a sequence of arrays along an existing axis.
     vsplit : Split array into a list of multiple sub-arrays vertically.
 
     Notes
@@ -246,9 +248,10 @@ def hstack(tup):
 
     See Also
     --------
+    stack : Join a sequence of arrays along a new axis.
     vstack : Stack arrays in sequence vertically (row wise).
     dstack : Stack arrays in sequence depth wise (along third axis).
-    concatenate : Join a sequence of arrays together.
+    concatenate : Join a sequence of arrays along an existing axis.
     hsplit : Split array along second axis.
 
     Notes
@@ -275,3 +278,73 @@ def hstack(tup):
         return _nx.concatenate(arrs, 0)
     else:
         return _nx.concatenate(arrs, 1)
+
+def stack(arrays, axis=0):
+    """
+    Join a sequence of arrays along a new axis.
+
+    .. versionadded:: 1.10.0
+
+    The `axis` parameter specifies the index of the new axis in the dimensions
+    of the result. For example, if ``axis=0`` it will be the first dimension
+    and if ``axis=-1`` it will be the last dimension.
+
+    Parameters
+    ----------
+    arrays : sequence of array_like
+        Each array must have the same shape.
+    axis : int, optional
+        The axis in the result array along which the input arrays are stacked.
+
+    Returns
+    -------
+    stacked : ndarray
+        The stacked array has one more dimension than the input arrays.
+
+    See Also
+    --------
+    concatenate : Join a sequence of arrays along an existing axis.
+    split : Split array into a list of multiple sub-arrays of equal size.
+
+    Examples
+    --------
+    >>> arrays = [np.random.randn(3, 4) for _ in range(10)]
+    >>> np.stack(arrays, axis=0).shape
+    (10, 3, 4)
+
+    >>> np.stack(arrays, axis=1).shape
+    (3, 10, 4)
+
+    >>> np.stack(arrays, axis=2).shape
+    (3, 4, 10)
+
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.stack((a, b))
+    array([[1, 2, 3],
+           [2, 3, 4]])
+
+    >>> np.stack((a, b), axis=-1)
+    array([[1, 2],
+           [2, 3],
+           [3, 4]])
+
+    """
+    arrays = [asanyarray(arr) for arr in arrays]
+    if not arrays:
+        raise ValueError('need at least one array to stack')
+
+    shapes = set(arr.shape for arr in arrays)
+    if len(shapes) != 1:
+        raise ValueError('all input arrays must have the same shape')
+
+    result_ndim = arrays[0].ndim + 1
+    if not -result_ndim <= axis < result_ndim:
+        msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
+        raise IndexError(msg)
+    if axis < 0:
+        axis += result_ndim
+
+    sl = (slice(None),) * axis + (_nx.newaxis,)
+    expanded_arrays = [arr[sl] for arr in arrays]
+    return _nx.concatenate(expanded_arrays, axis=axis)
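
The last three lines of `stack` give each input a length-one axis at the target position and then concatenate along it. The sketch below restates that idea using only public NumPy names so the intermediate shapes can be inspected. `stack_sketch` is a hypothetical name used for illustration, and the shape and bounds checks of the committed function are deliberately omitted.

```python
import numpy as np

def stack_sketch(arrays, axis=0):
    """Illustrative restatement of the approach; not the committed code."""
    arrays = [np.asanyarray(arr) for arr in arrays]
    result_ndim = arrays[0].ndim + 1
    if axis < 0:
        axis += result_ndim  # map a negative axis onto the result's dimensions
    # For axis=1 this is (slice(None), np.newaxis), i.e. arr[:, np.newaxis].
    sl = (slice(None),) * axis + (np.newaxis,)
    expanded = [arr[sl] for arr in arrays]
    return np.concatenate(expanded, axis=axis)

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

# Each input gains a length-one axis before concatenation.
print(a[(slice(None), np.newaxis)].shape)   # (3, 1)
print(stack_sketch((a, b)).shape)           # (2, 3)
print(stack_sketch((a, b), axis=-1).shape)  # (3, 2)
```
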

numpy/core/tests/test_shape_base.py

Lines changed: 53 additions & 2 deletions
@@ -3,9 +3,9 @@
 import warnings
 import numpy as np
 from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal,
-                           assert_equal, run_module_suite)
+                           assert_equal, run_module_suite, assert_raises_regex)
 from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d,
-                        vstack, hstack, newaxis, concatenate)
+                        vstack, hstack, newaxis, concatenate, stack)
 from numpy.compat import long
 
 class TestAtleast1d(TestCase):
@@ -246,5 +246,56 @@ def test_concatenate_sloppy0():
         assert_raises(DeprecationWarning, concatenate, (r4, r3), 10)
 
 
+def test_stack():
+    # 0d input
+    for input_ in [(1, 2, 3),
+                   [np.int32(1), np.int32(2), np.int32(3)],
+                   [np.array(1), np.array(2), np.array(3)]]:
+        assert_array_equal(stack(input_), [1, 2, 3])
+    # 1d input examples
+    a = np.array([1, 2, 3])
+    b = np.array([4, 5, 6])
+    r1 = array([[1, 2, 3], [4, 5, 6]])
+    assert_array_equal(np.stack((a, b)), r1)
+    assert_array_equal(np.stack((a, b), axis=1), r1.T)
+    # all input types
+    assert_array_equal(np.stack(list([a, b])), r1)
+    assert_array_equal(np.stack(array([a, b])), r1)
+    # all shapes for 1d input
+    arrays = [np.random.randn(3) for _ in range(10)]
+    axes = [0, 1, -1, -2]
+    expected_shapes = [(10, 3), (3, 10), (3, 10), (10, 3)]
+    for axis, expected_shape in zip(axes, expected_shapes):
+        assert_equal(np.stack(arrays, axis).shape, expected_shape)
+    assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=2)
+    assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=-3)
+    # all shapes for 2d input
+    arrays = [np.random.randn(3, 4) for _ in range(10)]
+    axes = [0, 1, 2, -1, -2, -3]
+    expected_shapes = [(10, 3, 4), (3, 10, 4), (3, 4, 10),
+                       (3, 4, 10), (3, 10, 4), (10, 3, 4)]
+    for axis, expected_shape in zip(axes, expected_shapes):
+        assert_equal(np.stack(arrays, axis).shape, expected_shape)
+    # empty arrays
+    assert stack([[], [], []]).shape == (3, 0)
+    assert stack([[], [], []], axis=1).shape == (0, 3)
+    # edge cases
+    assert_raises_regex(ValueError, 'need at least one array', stack, [])
+    assert_raises_regex(ValueError, 'must have the same shape',
+                        stack, [1, np.arange(3)])
+    assert_raises_regex(ValueError, 'must have the same shape',
+                        stack, [np.arange(3), 1])
+    assert_raises_regex(ValueError, 'must have the same shape',
+                        stack, [np.arange(3), 1], axis=1)
+    assert_raises_regex(ValueError, 'must have the same shape',
+                        stack, [np.zeros((3, 3)), np.zeros(3)], axis=1)
+    assert_raises_regex(ValueError, 'must have the same shape',
+                        stack, [np.arange(2), np.arange(3)])
+    # np.matrix
+    m = np.matrix([[1, 2], [3, 4]])
+    assert_raises_regex(ValueError, 'shape too large to be a matrix',
+                        stack, [m, m])
+
+
 if __name__ == "__main__":
     run_module_suite()
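
The edge cases covered by `test_stack` can also be tried interactively. The sketch below is an assumption-laden illustration rather than part of the test suite: `try_stack` is a hypothetical helper, and because the exact exception class and message for an out-of-bounds axis may differ between NumPy versions, it catches both `ValueError` and `IndexError` instead of matching message text.

```python
import numpy as np

def try_stack(arrays, axis=0):
    """Return the name of the exception np.stack raises, or None on success."""
    try:
        np.stack(arrays, axis=axis)
    except (ValueError, IndexError) as exc:
        return type(exc).__name__
    return None

no_input = try_stack([])                                    # no arrays to stack
mismatch = try_stack([np.arange(2), np.arange(3)])          # shapes differ
bad_axis = try_stack([np.arange(3), np.arange(3)], axis=2)  # axis out of bounds
ok = try_stack([np.arange(3), np.arange(3)], axis=1)        # valid call

print(no_input, mismatch, bad_axis, ok)
```
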

numpy/lib/function_base.py

Lines changed: 1 addition & 1 deletion
@@ -3706,7 +3706,7 @@ def insert(arr, obj, values, axis=None):
     See Also
     --------
     append : Append elements at the end of an array.
-    concatenate : Join a sequence of arrays together.
+    concatenate : Join a sequence of arrays along an existing axis.
     delete : Delete elements from an array.
 
     Notes

numpy/lib/index_tricks.py

Lines changed: 1 addition & 1 deletion
@@ -404,7 +404,7 @@ class RClass(AxisConcatenator):
 
     See Also
     --------
-    concatenate : Join a sequence of arrays together.
+    concatenate : Join a sequence of arrays along an existing axis.
     c_ : Translates slice objects to concatenation along the second axis.
 
     Examples

numpy/lib/shape_base.py

Lines changed: 4 additions & 2 deletions
@@ -338,9 +338,10 @@ def dstack(tup):
 
     See Also
     --------
+    stack : Join a sequence of arrays along a new axis.
     vstack : Stack along first axis.
     hstack : Stack along second axis.
-    concatenate : Join arrays.
+    concatenate : Join a sequence of arrays along an existing axis.
     dsplit : Split array along third axis.
 
     Notes
@@ -477,7 +478,8 @@ def split(ary,indices_or_sections,axis=0):
     hsplit : Split array into multiple sub-arrays horizontally (column-wise).
     vsplit : Split array into multiple sub-arrays vertically (row wise).
     dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
-    concatenate : Join arrays together.
+    concatenate : Join a sequence of arrays along an existing axis.
+    stack : Join a sequence of arrays along a new axis.
     hstack : Stack arrays in sequence horizontally (column wise).
     vstack : Stack arrays in sequence vertically (row wise).
     dstack : Stack arrays in sequence depth wise (along third dimension).
