8000 REV: Undo bad rebase in gh-8981 (7fdfdd6a52fc0761c0d45931247c5ed24802… · numpy/numpy@ae338e4 · GitHub
[go: up one dir, main page]

Skip to content

Commit ae338e4

Browse files
committed
REV: Undo bad rebase in gh-8981 (7fdfdd6)
This restores the changes in gh-9667 that were overwritten.
1 parent 6789c25 commit ae338e4

File tree

2 files changed

+110
-148
lines changed

2 files changed

+110
-148
lines changed

numpy/core/shape_base.py

Lines changed: 88 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -365,78 +365,93 @@ def stack(arrays, axis=0, out=None):
365365
return _nx.concatenate(expanded_arrays, axis=axis, out=out)
366366

367367

368-
class _Recurser(object):
368+
def _block_check_depths_match(arrays, parent_index=[]):
369369
"""
370-
Utility class for recursing over nested iterables
370+
Recursive function checking that the depths of nested lists in `arrays`
371+
all match. Mismatch raises a ValueError as described in the block
372+
docstring below.
373+
374+
The entire index (rather than just the depth) needs to be calculated
375+
for each innermost list, in case an error needs to be raised, so that
376+
the index of the offending list can be printed as part of the error.
377+
378+
The parameter `parent_index` is the full index of `arrays` within the
379+
nested lists passed to _block_check_depths_match at the top of the
380+
recursion.
381+
The return value is a pair. The first item returned is the full index
382+
of an element (specifically the first element) from the bottom of the
383+
nesting in `arrays`. An empty list at the bottom of the nesting is
384+
represented by a `None` index.
385+
The second item is the maximum of the ndims of the arrays nested in
386+
`arrays`.
371387
"""
372-
def __init__(self, recurse_if):
373-
self.recurse_if = recurse_if
374-
375-
def map_reduce(self, x, f_map=lambda x, **kwargs: x,
376-
f_reduce=lambda x, **kwargs: x,
377-
f_kwargs=lambda **kwargs: kwargs,
378-
**kwargs):
379-
"""
380-
Iterate over the nested list, applying:
381-
* ``f_map`` (T -> U) to items
382-
* ``f_reduce`` (Iterable[U] -> U) to mapped items
383-
384-
For instance, ``map_reduce([[1, 2], 3, 4])`` is::
385-
386-
f_reduce([
387-
f_reduce([
388-
f_map(1),
389-
f_map(2)
390-
]),
391-
f_map(3),
392-
f_map(4)
393-
]])
394-
395-
396-
State can be passed down through the calls with `f_kwargs`,
397-
to iterables of mapped items. When kwargs are passed, as in
398-
``map_reduce([[1, 2], 3, 4], **kw)``, this becomes::
399-
400-
kw1 = f_kwargs(**kw)
401-
kw2 = f_kwargs(**kw1)
402-
f_reduce([
403-
f_reduce([
404-
f_map(1), **kw2)
405-
f_map(2, **kw2)
406-
], **kw1),
407-
f_map(3, **kw1),
408-
f_map(4, **kw1)
409-
]], **kw)
410-
"""
411-
def f(x, **kwargs):
412-
if not self.recurse_if(x):
413-
return f_map(x, **kwargs)
414-
else:
415-
next_kwargs = f_kwargs(**kwargs)
416-
return f_reduce((
417-
f(xi, **next_kwargs)
418-
for xi in x
419-
), **kwargs)
420-
return f(x, **kwargs)
421-
422-
def walk(self, x, index=()):
423-
"""
424-
Iterate over x, yielding (index, value, entering), where
425-
426-
* ``index``: a tuple of indices up to this point
427-
* ``value``: equal to ``x[index[0]][...][index[-1]]``. On the first iteration, is
428-
``x`` itself
429-
* ``entering``: bool. The result of ``recurse_if(value)``
430-
"""
431-
do_recurse = self.recurse_if(x)
432-
yield index, x, do_recurse
433-
434-
if not do_recurse:
435-
return
436-
for i, xi in enumerate(x):
437-
# yield from ...
438-
for v in self.walk(xi, index + (i,)):
439-
yield v
388+
def format_index(index):
389+
idx_str = ''.join('[{}]'.format(i) for i in index if i is not None)
390+
return 'arrays' + idx_str
391+
if type(arrays) is tuple:
392+
# not strictly necessary, but saves us from:
393+
# - more than one way to do things - no point treating tuples like
394+
# lists
395+
# - horribly confusing behaviour that results when tuples are
396+
# treated like ndarray
397+
raise TypeError(
398+
'{} is a tuple. '
399+
'Only lists can be used to arrange blocks, and np.block does '
400+
'not allow implicit conversion from tuple to ndarray.'.format(
401+
format_index(parent_index)
402+
)
403+
)
404+
elif type(arrays) is list and len(arrays) > 0:
405+
idxs_ndims = (_block_check_depths_match(arr, parent_index + [i])
406+
for i, arr in enumerate(arrays))
407+
408+
first_index, max_arr_ndim = next(idxs_ndims)
409+
for index, ndim in idxs_ndims:
410+
if ndim > max_arr_ndim:
411+
max_arr_ndim = ndim
412+
if len(index) != len(first_index):
413+
raise ValueError(
414+
"List depths are mismatched. First element was at depth "
415+
"{}, but there is an element at depth {} ({})".format(
416+
len(first_index),
417+
len(index),
418+
format_index(index)
419+
)
420+
)
421+
return first_index, max_arr_ndim
422+
elif type(arrays) is list and len(arrays) == 0:
423+
# We've 'bottomed out' on an empty list
424+
return parent_index + [None], 0
425+
else:
426+
# We've 'bottomed out' - arrays is either a scalar or an array
427+
return parent_index, _nx.ndim(arrays)
428+
429+
430+
def _block(arrays, max_depth, result_ndim):
431+
"""
432+
Internal implementation of block. `arrays` is the argument passed to
433+
block. `max_depth` is the depth of nested lists within `arrays` and
434+
`result_ndim` is the greatest of the dimensions of the arrays in
435+
`arrays` and the depth of the lists in `arrays` (see block docstring
436+
for details).
437+
"""
438+
def atleast_nd(a, ndim):
439+
# Ensures `a` has at least `ndim` dimensions by prepending
440+
# ones to `a.shape` as necessary
441+
return array(a, ndmin=ndim, copy=False, subok=True)
442+
443+
def block_recursion(arrays, depth=0):
444+
if depth < max_depth:
445+
if len(arrays) == 0:
446+
raise ValueError('Lists cannot be empty')
447+
arrs = [block_recursion(arr, depth+1) for arr in arrays]
448+
return _nx.concatenate(arrs, axis=-(max_depth-depth))
449+
else:
450+
# We've 'bottomed out' - arrays is either a scalar or an array
451+
# type(arrays) is not list
452+
return atleast_nd(arrays, result_ndim)
453+
454+
return block_recursion(arrays)
440455

441456

442457
def block(arrays):
@@ -587,81 +602,6 @@ def block(arrays):
587602
588603
589604
"""
590-
def atleast_nd(x, ndim):
591-
x = asanyarray(x)
592-
diff = max(ndim - x.ndim, 0)
593-
return x[(None,)*diff + (Ellipsis,)]
594-
595-
def format_index(index):
596-
return 'arrays' + ''.join('[{}]'.format(i) for i in index)
597-
598-
rec = _Recurser(recurse_if=lambda x: type(x) is list)
599-
600-
# ensure that the lists are all matched in depth
601-
list_ndim = None
602-
any_empty = False
603-
for index, value, entering in rec.walk(arrays):
604-
if type(value) is tuple:
605-
# not strictly necessary, but saves us from:
606-
# - more than one way to do things - no point treating tuples like
607-
# lists
608-
# - horribly confusing behaviour that results when tuples are
609-
# treated like ndarray
610-
raise TypeError(
611-
'{} is a tuple. '
612-
'Only lists can be used to arrange blocks, and np.block does '
613-
'not allow implicit conversion from tuple to ndarray.'.format(
614-
format_index(index)
615-
)
616-
)
617-
if not entering:
618-
curr_depth = len(index)
619-
elif len(value) == 0:
620-
curr_depth = len(index) + 1
621-
any_empty = True
622-
else:
623-
continue
624-
625-
if list_ndim is not None and list_ndim != curr_depth:
626-
raise ValueError(
627-
"List depths are mismatched. First element was at depth {}, "
628-
"but there is an element at depth {} ({})".format(
629-
list_ndim,
630-
curr_depth,
631-
format_index(index)
632-
)
633-
)
634-
list_ndim = curr_depth
635-
636-
# do this here so we catch depth mismatches first
637-
if any_empty:
638-
raise ValueError('Lists cannot be empty')
639-
640-
# convert all the arrays to ndarrays
641-
arrays = rec.map_reduce(arrays,
642-
f_map=asanyarray,
643-
f_reduce=list
644-
)
645-
646-
# determine the maximum dimension of the elements
647-
elem_ndim = rec.map_reduce(arrays,
648-
f_map=lambda xi: xi.ndim,
649-
f_reduce=max
650-
)
651-
ndim = max(list_ndim, elem_ndim)
652-
653-
# first axis to concatenate along
654-
first_axis = ndim - list_ndim
655-
656-
# Make all the elements the same dimension
657-
arrays = rec.map_reduce(arrays,
658-
f_map=lambda xi: atleast_nd(xi, ndim),
659-
f_reduce=list
660-
)
661-
662-
# concatenate innermost lists on the right, outermost on the left
663-
return rec.map_reduce(arrays,
664-
f_reduce=lambda xs, axis: _nx.concatenate(list(xs), axis=axis),
665-
f_kwargs=lambda axis: dict(axis=axis+1),
666-
axis=first_axis
667-
)
605+
bottom_index, arr_ndim = _block_check_depths_match(arrays)
606+
list_ndim = len(bottom_index)
607+
return _block(arrays, list_ndim, max(arr_ndim, list_ndim))

numpy/core/tests/test_shape_base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,28 @@ def test_tuple(self):
560560
assert_raises_regex(TypeError, 'tuple', np.block, ([1, 2], [3, 4]))
561561
assert_raises_regex(TypeError, 'tuple', np.block, [(1, 2), (3, 4)])
562562

563+
def test_different_ndims(self):
564+
a = 1.
565+
b = 2 * np.ones((1, 2))
566+
c = 3 * np.ones((1, 1, 3))
567+
568+
result = np.block([a, b, c])
569+
expected = np.array([[[1., 2., 2., 3., 3., 3.]]])
570+
571+
assert_equal(result, expected)
572+
573+
def test_different_ndims_depths(self):
574+
a = 1.
575+
b = 2 * np.ones((1, 2))
576+
c = 3 * np.ones((1, 2, 3))
577+
578+
result = np.block([[a, b], [c]])
579+
expected = np.array([[[1., 2., 2.],
580+
[3., 3., 3.],
581+
[3., 3., 3.]]])
582+
583+
assert_equal(result, expected)
584+
563585

564586
if __name__ == "__main__":
565587
run_module_suite()

0 commit comments

Comments (0)