MAINT: Simplify block implementation by j-towns · Pull Request #9667 · numpy/numpy

Merged · 23 commits · Nov 12, 2017
Changes from 8 commits
217 changes: 67 additions & 150 deletions numpy/core/shape_base.py
@@ -365,78 +365,71 @@ def stack(arrays, axis=0, out=None):
    return _nx.concatenate(expanded_arrays, axis=axis, out=out)


class _Recurser(object):
    """
    Utility class for recursing over nested iterables
    """
    def __init__(self, recurse_if):
        self.recurse_if = recurse_if

    def map_reduce(self, x, f_map=lambda x, **kwargs: x,
                   f_reduce=lambda x, **kwargs: x,
                   f_kwargs=lambda **kwargs: kwargs,
                   **kwargs):
        """
        Iterate over the nested list, applying:
        * ``f_map`` (T -> U) to items
        * ``f_reduce`` (Iterable[U] -> U) to mapped items

        For instance, ``map_reduce([[1, 2], 3, 4])`` is::

            f_reduce([
              f_reduce([
                f_map(1),
                f_map(2)
              ]),
              f_map(3),
              f_map(4)
            ])

        State can be passed down through the calls with ``f_kwargs``,
        which transforms the keyword arguments at each level of the
        recursion. When kwargs are passed, as in
        ``map_reduce([[1, 2], 3, 4], **kw)``, this becomes::

            kw1 = f_kwargs(**kw)
            kw2 = f_kwargs(**kw1)
            f_reduce([
              f_reduce([
                f_map(1, **kw2),
                f_map(2, **kw2)
              ], **kw1),
              f_map(3, **kw1),
              f_map(4, **kw1)
            ], **kw)
        """
        def f(x, **kwargs):
            if not self.recurse_if(x):
                return f_map(x, **kwargs)
            else:
                next_kwargs = f_kwargs(**kwargs)
                return f_reduce((
                    f(xi, **next_kwargs)
                    for xi in x
                ), **kwargs)
        return f(x, **kwargs)

    def walk(self, x, index=()):
        """
        Iterate over x, yielding (index, value, entering), where

        * ``index``: a tuple of indices up to this point
        * ``value``: equal to ``x[index[0]][...][index[-1]]``. On the first
          iteration, this is ``x`` itself
        * ``entering``: bool. The result of ``recurse_if(value)``
        """
        do_recurse = self.recurse_if(x)
        yield index, x, do_recurse

        if not do_recurse:
            return
        for i, xi in enumerate(x):
            # yield from ...
            for v in self.walk(xi, index + (i,)):
                yield v
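For reference, a minimal sketch of how the removed _Recurser behaves, assuming the class as defined above (the inputs are illustrative, not from the numpy test suite):

rec = _Recurser(recurse_if=lambda x: type(x) is list)

# map_reduce applies f_map to leaves and f_reduce to each nested list:
total = rec.map_reduce([[1, 2], 3, 4],
                       f_map=lambda x: x,
                       f_reduce=sum)
# total == 10, i.e. sum([sum([1, 2]), 3, 4])

# walk yields (index, value, entering) tuples depth-first:
list(rec.walk([[1, 2], 3]))
# [((), [[1, 2], 3], True),
#  ((0,), [1, 2], True),
#  ((0, 0), 1, False),
#  ((0, 1), 2, False),
#  ((1,), 3, False)]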
def _block_check_depths_match(arrays, index=[]):
Member: This could do with a comment explaining what it returns, especially since it's recursive.

Member: To be clear, what I'm looking for is a docstring explaining what the index argument and return values are.

Member: Perhaps this should be called parent_index?

j-towns (Contributor, Author): Thanks again for the helpful review comments. Do you reckon the docstring I've written is now sufficient?

    # Recursive function checking that the depths of nested lists in `arrays`
    # all match. Mismatch raises a ValueError as described in the block
    # docstring below.
    # The entire index (rather than just the depth) is calculated for each
    # innermost list, in case an error needs to be raised, so that the index
    # of the offending list can be printed as part of the error.
    def format_index(index):
        idx_str = ''.join('[{}]'.format(i) for i in index if i is not None)
        return 'arrays' + idx_str

    if type(arrays) is tuple:
        # not strictly necessary, but saves us from:
        # - more than one way to do things - no point treating tuples like
        #   lists
        # - horribly confusing behaviour that results when tuples are
        #   treated like ndarray
        raise TypeError(
            '{} is a tuple. '
            'Only lists can be used to arrange blocks, and np.block does '
            'not allow implicit conversion from tuple to ndarray.'.format(
                format_index(index)
            )
        )
    elif type(arrays) is list and len(arrays) > 0:
        indexes = [_block_check_depths_match(arr, index + [i])
                   for i, arr in enumerate(arrays)]

        first_index = indexes[0]
        for i, index in enumerate(indexes):
Member: Having this overwrite the index parameter is confusing.
            if len(index) != len(first_index):
                raise ValueError(
                    "List depths are mismatched. First element was at depth "
                    "{}, but there is an element at depth {} ({})".format(
                        len(first_index),
                        len(index),
                        format_index(index)
                    )
                )
        return first_index
    elif type(arrays) is list and len(arrays) == 0:
        # We've 'bottomed out' on an empty list
        return index + [None]
Member: Can you explain what's going on here?

j-towns (Contributor, Author): If an empty list is encountered, the recursion needs to back up, and an error message may need to be generated if the depths don't match.

I used the length of the index list for the depth info in the error message. On this line I'm making sure that the length of index indeed reflects the depth of nested lists, while using None to flag that this particular index shouldn't be included in the error message; the None entries are filtered out in format_index above. I could have used some other value (such as -1) to flag an empty list.

j-towns (Contributor, Author) commented on Sep 10, 2017: As an example (from the tests), if you do

np.block([1, []])

the error that you get should be

ValueError: List depths are mismatched. First element was at depth 1, but there is an element at depth 2 (arrays[1])

At the end of the message only one index is printed even though the depth is two; that's the kind of situation I'm preparing for in the above line. (A sketch illustrating this appears after the function below.)

    else:
        # We've 'bottomed out'
        return index
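A hedged illustration of the depth bookkeeping above; the error texts follow the messages in this diff, though the session itself is a sketch rather than verbatim test output:

import numpy as np

np.block([[1, 2], [3, 4]])   # matched depths -> array([[1, 2], [3, 4]])

np.block((1, 2))             # tuple input
# TypeError: arrays is a tuple. Only lists can be used to arrange blocks, ...

np.block([1, []])            # empty list bottoms out at depth 2
# ValueError: List depths are mismatched. First element was at depth 1,
# but there is an element at depth 2 (arrays[1])

# Internally, the empty list contributes index + [None], so the reported
# depth is len([1, None]) == 2 while format_index([1, None]) skips the
# None and prints only 'arrays[1]'.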


def _block(arrays, depth=0):
    def atleast_nd(a, ndim):
        # Ensures `a` has at least `ndim` dimensions by prepending
        # ones to `a.shape` as necessary
        return array(a, ndmin=ndim, copy=False, subok=True)

    if type(arrays) is list:
        if len(arrays) == 0:
            raise ValueError('Lists cannot be empty')
        arrs, list_ndims = zip(*(_block(arr, depth+1) for arr in arrays))
        list_ndim = list_ndims[0]
        arr_ndim = max(arr.ndim for arr in arrs)
        ndim = max(list_ndim, arr_ndim)
        arrs = [atleast_nd(a, ndim) for a in arrs]
Member: It makes me a little uneasy that this happens inside the recursion. I think there aren't many tests for block with differently-dimensioned arrays.

j-towns (Contributor, Author): I've simplified the logic a little and added two tests for cases where arr_ndim > list_ndim and the inner arrays have different ndims. (See the sketch after this function for an illustration.)

        return _nx.concatenate(arrs, axis=depth+ndim-list_ndim), list_ndim
    else:
        # We've 'bottomed out'
        return atleast_nd(arrays, depth), depth
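To make the axis arithmetic in the concatenate call concrete, here is a sketch using np.block's documented behavior (the array names are illustrative):

import numpy as np

A = np.ones((2, 2))
B = np.zeros((2, 2))

# list_ndim = 2 and ndim = 2, so inner lists (depth 1) concatenate along
# axis 1 + 2 - 2 = 1 and the outer list (depth 0) along axis 0 + 2 - 2 = 0.
np.block([[A, B], [B, A]]).shape       # -> (4, 4)

# Mixed ndims: the scalar is promoted to shape (1,) by atleast_nd before
# the depth-1 concatenation along axis 0.
np.block([1.0, np.array([2.0, 3.0])])  # -> array([1., 2., 3.])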


def block(arrays):
@@ -587,81 +580,5 @@ def block(arrays):


"""
    def atleast_nd(x, ndim):
        x = asanyarray(x)
        diff = max(ndim - x.ndim, 0)
        return x[(None,)*diff + (Ellipsis,)]
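A hedged sketch of the indexing trick above (x here is illustrative): (None,)*diff prepends diff np.newaxis entries, producing a view rather than a copy.

import numpy as np

x = np.arange(6).reshape(2, 3)
# For diff == 1 this is x[None, ...], giving shape (1, 2, 3).
assert x[(None,) * 1 + (Ellipsis,)].shape == (1, 2, 3)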

    def format_index(index):
        return 'arrays' + ''.join('[{}]'.format(i) for i in index)

    rec = _Recurser(recurse_if=lambda x: type(x) is list)

    # ensure that the lists are all matched in depth
    list_ndim = None
    any_empty = False
    for index, value, entering in rec.walk(arrays):
        if type(value) is tuple:
            # not strictly necessary, but saves us from:
            # - more than one way to do things - no point treating tuples like
            #   lists
            # - horribly confusing behaviour that results when tuples are
            #   treated like ndarray
Member: Any reason for removing this comment?

j-towns (Contributor, Author): No, I actually hadn't meant to delete that; I'll re-include it.

            raise TypeError(
                '{} is a tuple. '
                'Only lists can be used to arrange blocks, and np.block does '
                'not allow implicit conversion from tuple to ndarray.'.format(
                    format_index(index)
                )
            )

        if not entering:
            curr_depth = len(index)
        elif len(value) == 0:
            curr_depth = len(index) + 1
            any_empty = True
        else:
            continue

        if list_ndim is not None and list_ndim != curr_depth:
            raise ValueError(
                "List depths are mismatched. First element was at depth {}, "
                "but there is an element at depth {} ({})".format(
                    list_ndim,
                    curr_depth,
                    format_index(index)
                )
            )
        list_ndim = curr_depth

    # do this here so we catch depth mismatches first
    if any_empty:
        raise ValueError('Lists cannot be empty')

    # convert all the arrays to ndarrays
    arrays = rec.map_reduce(arrays,
                            f_map=asanyarray,
                            f_reduce=list)

    # determine the maximum dimension of the elements
    elem_ndim = rec.map_reduce(arrays,
                               f_map=lambda xi: xi.ndim,
                               f_reduce=max)
    ndim = max(list_ndim, elem_ndim)

    # first axis to concatenate along
    first_axis = ndim - list_ndim

    # Make all the elements the same dimension
    arrays = rec.map_reduce(arrays,
                            f_map=lambda xi: atleast_nd(xi, ndim),
                            f_reduce=list)

    # concatenate innermost lists on the right, outermost on the left
    return rec.map_reduce(arrays,
                          f_reduce=lambda xs, axis: _nx.concatenate(list(xs),
                                                                    axis=axis),
                          f_kwargs=lambda axis: dict(axis=axis+1),
                          axis=first_axis)
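For comparison, a sketch of how the final map_reduce drove the old concatenation, assuming _Recurser as defined earlier (the names are illustrative):

import numpy as np

rec = _Recurser(recurse_if=lambda x: type(x) is list)
a, b = np.ones((2, 2)), np.zeros((2, 2))

# axis starts at first_axis (here 0) and f_kwargs bumps it by one per
# nesting level, so inner lists join on axis 1 and the outer on axis 0.
out = rec.map_reduce([[a, b], [b, a]],
                     f_reduce=lambda xs, axis: np.concatenate(list(xs),
                                                              axis=axis),
                     f_kwargs=lambda axis: dict(axis=axis + 1),
                     axis=0)
assert out.shape == (4, 4)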
    _block_check_depths_match(arrays)
    return _block(arrays)[0]
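The block-matrix example from np.block's own docstring exercises the new code path end to end:

import numpy as np

A = np.eye(2) * 2
B = np.eye(3) * 3
np.block([[A,               np.zeros((2, 3))],
          [np.ones((3, 2)), B               ]])
# array([[2., 0., 0., 0., 0.],
#        [0., 2., 0., 0., 0.],
#        [1., 1., 3., 0., 0.],
#        [1., 1., 0., 3., 0.],
#        [1., 1., 0., 0., 3.]])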