-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
MAINT: Simplify block implementation #9667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
a5c6f0d
2dcc9aa
e787a9f
95adb77
07a3f43
19fc68c
997ac2c
7eb1044
ff7f726
6ecd2b4
1211b70
ffa6cf6
3ed6936
5f9f1fa
8a83a5f
5a0557a
c2b5be5
bd6729d
ad278f3
2c1734b
eaddf39
a5cbc93
a691f2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -365,78 +365,71 @@ def stack(arrays, axis=0, out=None): | |
return _nx.concatenate(expanded_arrays, axis=axis, out=out) | ||
|
||
|
||
class _Recurser(object):
    """
    Helper for traversing arbitrarily nested iterables.

    ``recurse_if`` is a predicate applied to every value met during a
    traversal; a value is descended into only when the predicate is true
    for it, otherwise it is treated as a leaf.
    """
    def __init__(self, recurse_if):
        self.recurse_if = recurse_if

    def map_reduce(self, x, f_map=lambda x, **kwargs: x,
                   f_reduce=lambda x, **kwargs: x,
                   f_kwargs=lambda **kwargs: kwargs,
                   **kwargs):
        """
        Traverse the nested structure ``x``, applying:

        * ``f_map`` (T -> U) to every leaf
        * ``f_reduce`` (Iterable[U] -> U) to the mapped children of every
          node that is recursed into

        For instance, ``map_reduce([[1, 2], 3, 4])`` is::

            f_reduce([
              f_reduce([
                f_map(1),
                f_map(2)
              ]),
              f_map(3),
              f_map(4)
            ])

        State is threaded down the recursion through ``f_kwargs``, which
        turns the keyword arguments of a node into those used for its
        children. With ``map_reduce([[1, 2], 3, 4], **kw)`` this becomes::

            kw1 = f_kwargs(**kw)
            kw2 = f_kwargs(**kw1)
            f_reduce([
              f_reduce([
                f_map(1, **kw2),
                f_map(2, **kw2)
              ], **kw1),
              f_map(3, **kw1),
              f_map(4, **kw1)
            ], **kw)

        Note that ``f_reduce`` receives a lazy generator of mapped children,
        not a list.
        """
        # The first parameter keeps the name ``x`` on purpose: it defines
        # which keyword name would collide with the forwarded state.
        def _visit(x, **state):
            if self.recurse_if(x):
                child_state = f_kwargs(**state)
                mapped_children = (
                    _visit(child, **child_state)
                    for child in x
                )
                return f_reduce(mapped_children, **state)
            return f_map(x, **state)

        return _visit(x, **kwargs)

    def walk(self, x, index=()):
        """
        Iterate over ``x`` depth-first, yielding ``(index, value, entering)``:

        * ``index``: tuple of the indices followed to reach ``value``
        * ``value``: equal to ``x[index[0]][...][index[-1]]``; on the first
          iteration it is ``x`` itself
        * ``entering``: the result of ``recurse_if(value)``, i.e. whether the
          children of ``value`` are visited next
        """
        entering = self.recurse_if(x)
        yield index, x, entering

        if entering:
            for position, child in enumerate(x):
                # explicit loop instead of ``yield from``
                for item in self.walk(child, index + (position,)):
                    yield item
def _block_check_depths_match(arrays, index=None):
    """
    Recursively check that the nested lists in `arrays` all have the same
    depth.

    The entire index (rather than just the depth) is tracked for each
    innermost element, so that if an error needs to be raised the offending
    element can be named in the message.

    Parameters
    ----------
    arrays : nested list of objects, or a single object
        The (sub-)structure to check. Only objects whose type is exactly
        ``list`` are recursed into.
    index : list of int, optional
        The index of `arrays` within the structure given to the outermost
        call. Defaults to an empty list, meaning `arrays` is that top-level
        structure. The list is never modified.

    Returns
    -------
    first_index : list
        The index of the first innermost element at or below `arrays`; its
        length is the nesting depth of that element. When that element is an
        empty list, a trailing ``None`` is appended so that its length counts
        the empty list as one further level of nesting.

    Raises
    ------
    TypeError
        If a tuple is found anywhere in the structure.
    ValueError
        If two elements of the same list are nested to different depths.
    """
    if index is None:
        # A fresh list per call; a shared mutable default would be handed
        # back to the caller by the final branch below.
        index = []

    def format_index(idx):
        # ``None`` placeholders (see the empty-list branch) are not printed.
        idx_str = ''.join('[{}]'.format(i) for i in idx if i is not None)
        return 'arrays' + idx_str

    if type(arrays) is tuple:
        # not strictly necessary, but saves us from:
        #  - more than one way to do things - no point treating tuples like
        #    lists
        #  - horribly confusing behaviour that results when tuples are
        #    treated like ndarray
        raise TypeError(
            '{} is a tuple. '
            'Only lists can be used to arrange blocks, and np.block does '
            'not allow implicit conversion from tuple to ndarray.'.format(
                format_index(index)
            )
        )
    elif type(arrays) is list and len(arrays) > 0:
        child_indexes = [_block_check_depths_match(arr, index + [i])
                         for i, arr in enumerate(arrays)]

        first_index = child_indexes[0]
        # A distinct loop name keeps the `index` parameter intact.
        for child_index in child_indexes:
            if len(child_index) != len(first_index):
                raise ValueError(
                    "List depths are mismatched. First element was at depth "
                    "{}, but there is an element at depth {} ({})".format(
                        len(first_index),
                        len(child_index),
                        format_index(child_index)
                    )
                )
        return first_index
    elif type(arrays) is list and len(arrays) == 0:
        # We've 'bottomed out' on an empty list. The extra ``None`` makes the
        # returned length one more than the list's own index, so the depth
        # comparison in the caller treats it as a nested level, while
        # format_index still prints only the index of the empty list itself
        # (e.g. ``[1, []]`` reports depth 2 at ``arrays[1]``).
        return index + [None]
    else:
        # We've 'bottomed out'
        return index
|
||
|
||
def _block(arrays, depth=0):
    """
    Recursively assemble the nested list `arrays` into a single array.

    Parameters
    ----------
    arrays : nested list of array_like, or array_like
        The structure to assemble. Only objects whose type is exactly
        ``list`` are treated as nesting; anything else is a leaf.
    depth : int, optional
        Number of lists enclosing `arrays` (0 for the outermost call).

    Returns
    -------
    result : ndarray
        For a leaf, `arrays` as an array with at least `depth` dimensions;
        for a list, the concatenation of its assembled elements.
    list_ndim : int
        The depth at which the first leaf under `arrays` was found.

    Raises
    ------
    ValueError
        If an empty list is encountered.
    """
    def atleast_nd(a, ndim):
        # Ensures `a` has at least `ndim` dimensions by prepending
        # ones to `a.shape` as necessary
        return array(a, ndmin=ndim, copy=False, subok=True)

    if type(arrays) is not list:
        # We've 'bottomed out'
        return atleast_nd(arrays, depth), depth

    if len(arrays) == 0:
        raise ValueError('Lists cannot be empty')

    results = [_block(item, depth + 1) for item in arrays]
    sub_arrays = [sub for sub, _ in results]
    list_ndim = results[0][1]

    # Promote every piece to a common dimensionality at this level of the
    # recursion: the larger of the list depth and the widest piece.
    widest = max(sub.ndim for sub in sub_arrays)
    ndim = max(list_ndim, widest)
    padded = [atleast_nd(sub, ndim) for sub in sub_arrays]

    # Innermost lists join along the last axis, outer lists along
    # progressively earlier ones.
    concat_axis = depth + ndim - list_ndim
    return _nx.concatenate(padded, axis=concat_axis), list_ndim
|
||
|
||
def block(arrays): | ||
|
@@ -587,81 +580,5 @@ def block(arrays): | |
|
||
|
||
""" | ||
def atleast_nd(x, ndim): | ||
x = asanyarray(x) | ||
diff = max(ndim - x.ndim, 0) | ||
return x[(None,)*diff + (Ellipsis,)] | ||
|
||
def format_index(index): | ||
return 'arrays' + ''.join('[{}]'.format(i) for i in index) | ||
|
||
rec = _Recurser(recurse_if=lambda x: type(x) is list) | ||
|
||
# ensure that the lists are all matched in depth | ||
list_ndim = None | ||
any_empty = False | ||
for index, value, entering in rec.walk(arrays): | ||
if type(value) is tuple: | ||
# not strictly necessary, but saves us from: | ||
# - more than one way to do things - no point treating tuples like | ||
# lists | ||
# - horribly confusing behaviour that results when tuples are | ||
# treated like ndarray | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any reason for removing this comment? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, I actually hadn't meant to delete that; I will re-include it. |
||
raise TypeError( | ||
'{} is a tuple. ' | ||
'Only lists can be used to arrange blocks, and np.block does ' | ||
'not allow implicit conversion from tuple to ndarray.'.format( | ||
format_index(index) | ||
) | ||
) | ||
if not entering: | ||
curr_depth = len(index) | ||
elif len(value) == 0: | ||
curr_depth = len(index) + 1 | ||
any_empty = True | ||
else: | ||
continue | ||
|
||
if list_ndim is not None and list_ndim != curr_depth: | ||
raise ValueError( | ||
"List depths are mismatched. First element was at depth {}, " | ||
"but there is an element at depth {} ({})".format( | ||
list_ndim, | ||
curr_depth, | ||
format_index(index) | ||
) | ||
) | ||
list_ndim = curr_depth | ||
|
||
# do this here so we catch depth mismatches first | ||
if any_empty: | ||
raise ValueError('Lists cannot be empty') | ||
|
||
# convert all the arrays to ndarrays | ||
arrays = rec.map_reduce(arrays, | ||
f_map=asanyarray, | ||
f_reduce=list | ||
) | ||
|
||
# determine the maximum dimension of the elements | ||
elem_ndim = rec.map_reduce(arrays, | ||
f_map=lambda xi: xi.ndim, | ||
f_reduce=max | ||
) | ||
ndim = max(list_ndim, elem_ndim) | ||
|
||
# first axis to concatenate along | ||
first_axis = ndim - list_ndim | ||
|
||
# Make all the elements the same dimension | ||
arrays = rec.map_reduce(arrays, | ||
f_map=lambda xi: atleast_nd(xi, ndim), | ||
f_reduce=list | ||
) | ||
|
||
# concatenate innermost lists on the right, outermost on the left | ||
return rec.map_reduce(arrays, | ||
f_reduce=lambda xs, axis: _nx.concatenate(list(xs), axis=axis), | ||
f_kwargs=lambda axis: dict(axis=axis+1), | ||
axis=first_axis | ||
) | ||
_block_check_depths_match(arrays) | ||
return _block(arrays)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could do with a comment explaining what it returns, especially since it's recursive.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be clear, what I'm looking for is a docstring explaining what the `index` argument and return values are.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps this should be called `parent_index`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again for helpful review comments. Do you reckon the docstring I've written is now sufficient?