-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
MAINT: Simplify block implementation #9667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
a5c6f0d
2dcc9aa
e787a9f
95adb77
07a3f43
19fc68c
997ac2c
7eb1044
ff7f726
6ecd2b4
1211b70
ffa6cf6
3ed6936
5f9f1fa
8a83a5f
5a0557a
c2b5be5
bd6729d
ad278f3
2c1734b
eaddf39
a5cbc93
a691f2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -365,78 +365,93 @@ def stack(arrays, axis=0, out=None): | |
return _nx.concatenate(expanded_arrays, axis=axis, out=out) | ||
|
||
|
||
class _Recurser(object): | ||
def _block_check_depths_match(arrays, parent_index=[]): | ||
""" | ||
Utility class for recursing over nested iterables | ||
Recursive function checking that the depths of nested lists in `arrays` | ||
all match. Mismatch raises a ValueError as described in the block | ||
docstring below. | ||
|
||
The entire index (rather than just the depth) needs to be calculated | ||
for each innermost list, in case an error needs to be raised, so that | ||
the index of the offending list can be printed as part of the error. | ||
|
||
The parameter `parent_index` is the full index of `arrays` within the | ||
nested lists passed to _block_check_depths_match at the top of the | ||
recursion. | ||
The return value is a pair. The first item returned is the full index | ||
of an element (specifically the first element) from the bottom of the | ||
nesting in `arrays`. An empty list at the bottom of the nesting is | ||
represented by a `None` index. | ||
The second item is the maximum of the ndims of the arrays nested in | ||
`arrays`. | ||
""" | ||
def __init__(self, recurse_if): | ||
self.recurse_if = recurse_if | ||
|
||
def map_reduce(self, x, f_map=lambda x, **kwargs: x, | ||
f_reduce=lambda x, **kwargs: x, | ||
f_kwargs=lambda **kwargs: kwargs, | ||
**kwargs): | ||
""" | ||
Iterate over the nested list, applying: | ||
* ``f_map`` (T -> U) to items | ||
* ``f_reduce`` (Iterable[U] -> U) to mapped items | ||
|
||
For instance, ``map_reduce([[1, 2], 3, 4])`` is:: | ||
|
||
f_reduce([ | ||
f_reduce([ | ||
f_map(1), | ||
f_map(2) | ||
]), | ||
f_map(3), | ||
f_map(4) | ||
]]) | ||
|
||
|
||
State can be passed down through the calls with `f_kwargs`, | ||
to iterables of mapped items. When kwargs are passed, as in | ||
``map_reduce([[1, 2], 3, 4], **kw)``, this becomes:: | ||
|
||
kw1 = f_kwargs(**kw) | ||
kw2 = f_kwargs(**kw1) | ||
f_reduce([ | ||
f_reduce([ | ||
f_map(1), **kw2) | ||
f_map(2, **kw2) | ||
], **kw1), | ||
f_map(3, **kw1), | ||
f_map(4, **kw1) | ||
]], **kw) | ||
""" | ||
def f(x, **kwargs): | ||
if not self.recurse_if(x): | ||
return f_map(x, **kwargs) | ||
else: | ||
next_kwargs = f_kwargs(**kwargs) | ||
return f_reduce(( | ||
f(xi, **next_kwargs) | ||
for xi in x | ||
), **kwargs) | ||
return f(x, **kwargs) | ||
|
||
def walk(self, x, index=()): | ||
""" | ||
Iterate over x, yielding (index, value, entering), where | ||
|
||
* ``index``: a tuple of indices up to this point | ||
* ``value``: equal to ``x[index[0]][...][index[-1]]``. On the first iteration, is | ||
``x`` itself | ||
* ``entering``: bool. The result of ``recurse_if(value)`` | ||
""" | ||
do_recurse = self.recurse_if(x) | ||
yield index, x, do_recurse | ||
|
||
if not do_recurse: | ||
return | ||
for i, xi in enumerate(x): | ||
# yield from ... | ||
for v in self.walk(xi, index + (i,)): | ||
yield v | ||
def format_index(index): | ||
idx_str = ''.join('[{}]'.format(i) for i in index if i is not None) | ||
return 'arrays' + idx_str | ||
if type(arrays) is tuple: | ||
# not strictly necessary, but saves us from: | ||
# - more than one way to do things - no point treating tuples like | ||
# lists | ||
# - horribly confusing behaviour that results when tuples are | ||
# treated like ndarray | ||
raise TypeError( | ||
'{} is a tuple. ' | ||
'Only lists can be used to arrange blocks, and np.block does ' | ||
'not allow implicit conversion from tuple to ndarray.'.format( | ||
format_index(parent_index) | ||
) | ||
) | ||
elif type(arrays) is list and len(arrays) > 0: | ||
idxs_ndims = (_block_check_depths_match(arr, parent_index + [i]) | ||
for i, arr in enumerate(arrays) 8000 ) | ||
|
||
first_index, max_arr_ndim = next(idxs_ndims) | ||
for i, (index, ndim) in enumerate(idxs_ndims, 1): | ||
if ndim > max_arr_ndim: | ||
max_arr_ndim = ndim | ||
if len(index) != len(first_index): | ||
raise ValueError( | ||
"List depths are mismatched. First element was at depth " | ||
"{}, but there is an element at depth {} ({})".format( | ||
len(first_index), | ||
len(index), | ||
format_index(index) | ||
) | ||
) | ||
return first_index, max_arr_ndim | ||
elif type(arrays) is list and len(arrays) == 0: | ||
# We've 'bottomed out' on an empty list | ||
return parent_index + [None], 0 | ||
else: | ||
# We've 'bottomed out' - arrays is either a scalar or an array | ||
return parent_index, _nx.ndim(arrays) | ||
|
||
|
||
def _block(arrays, max_depth, result_ndim): | ||
""" | ||
Internal implementation of block. `arrays` is the argument passed to | ||
block. `max_depth` is the depth of nested lists within `arrays` and | ||
`result_ndim` is the greatest of the dimensions of the arrays in | ||
`arrays` and the depth of the lists in `arrays` (see block docstring | ||
for details). | ||
""" | ||
def atleast_nd(a, ndim): | ||
# Ensures `a` has at least `ndim` dimensions by prepending | 8000 tr>||
# ones to `a.shape` as necessary | ||
return array(a, ndmin=ndim, copy=False, subok=True) | ||
|
||
def block_recursion(arrays, depth=0): | ||
if depth < max_depth: | ||
if len(arrays) == 0: | ||
raise ValueError('Lists cannot be empty') | ||
arrs = [block_recursion(arr, depth+1) for arr in arrays] | ||
return _nx.concatenate(arrs, axis=-(max_depth-depth)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This now reflects the wording of the docstring quite closely:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also slight personal preference to having spaces around binary operators for readability. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It would, yeah. The reason I wrote it that way round was because I wanted to match the docstring as closely as possible (my original motivation for this pr was to make it a bit clearer to people reading the code how block actually worked). The docstring says:
To me the correspondence between the docstring and the code is ever-so-slightly clearer with the expression the way round that it currently is, and I think the effect on performance is probably negligable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I often instinctively shorten things when they're being fed to a keyword argument, because you don't normally put spaces around the axis=-(max_depth - depth) and axis=depth - max_depth look a bit weird? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The style point is minor. Though there are many tools (e.g. |
||
else: | ||
# We've 'bottomed out' - arrays is either a scalar or an array | ||
# type(arrays) is not list | ||
return atleast_nd(arrays, result_ndim) | ||
|
||
return block_recursion(arrays) | ||
|
||
|
||
def block(arrays): | ||
|
@@ -587,81 +602,6 @@ def block(arrays): | |
|
||
|
||
""" | ||
def atleast_nd(x, ndim): | ||
x = asanyarray(x) | ||
diff = max(ndim - x.ndim, 0) | ||
return x[(None,)*diff + (Ellipsis,)] | ||
|
||
def format_index(index): | ||
return 'arrays' + ''.join('[{}]'.format(i) for i in index) | ||
|
||
rec = _Recurser(recurse_if=lambda x: type(x) is list) | ||
|
||
# ensure that the lists are all matched in depth | ||
list_ndim = None | ||
any_empty = False | ||
for index, value, entering in rec.walk(arrays): | ||
if type(value) is tuple: | ||
# not strictly necessary, but saves us from: | ||
# - more than one way to do things - no point treating tuples like | ||
# lists | ||
# - horribly confusing behaviour that results when tuples are | ||
# treated like ndarray | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason for removing this comment? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, I actually hadn't meant to delete that, will re-include it. |
||
raise TypeError( | ||
'{} is a tuple. ' | ||
'Only lists can be used to arrange blocks, and np.block does ' | ||
'not allow implicit conversion from tuple to ndarray.'.format( | ||
format_index(index) | ||
) | ||
) | ||
if not entering: | ||
curr_depth = len(index) | ||
elif len(value) == 0: | ||
curr_depth = len(index) + 1 | ||
any_empty = True | ||
else: | ||
continue | ||
|
||
if list_ndim is not None and list_ndim != curr_depth: | ||
raise ValueError( | ||
"List depths are mismatched. First element was at depth {}, " | ||
"but there is an element at depth {} ({})".format( | ||
list_ndim, | ||
curr_depth, | ||
format_index(index) | ||
) | ||
) | ||
list_ndim = curr_depth | ||
|
||
# do this here so we catch depth mismatches first | ||
if any_empty: | ||
raise ValueError('Lists cannot be empty') | ||
|
||
# convert all the arrays to ndarrays | ||
arrays = rec.map_reduce(arrays, | ||
f_map=asanyarray, | ||
f_reduce=list | ||
) | ||
|
||
# determine the maximum dimension of the elements | ||
elem_ndim = rec.map_reduce(arrays, | ||
f_map=lambda xi: xi.ndim, | ||
f_reduce=max | ||
) | ||
ndim = max(list_ndim, elem_ndim) | ||
|
||
# first axis to concatenate along | ||
first_axis = ndim - list_ndim | ||
|
||
# Make all the elements the same dimension | ||
arrays = rec.map_reduce(arrays, | ||
f_map=lambda xi: atleast_nd(xi, ndim), | ||
f_reduce=list | ||
) | ||
|
||
# concatenate innermost lists on the right, outermost on the left | ||
return rec.map_reduce(arrays, | ||
f_reduce=lambda xs, axis: _nx.concatenate(list(xs), axis=axis), | ||
f_kwargs=lambda axis: dict(axis=axis+1), | ||
axis=first_axis | ||
) | ||
bottom_index, arr_ndim = _block_check_depths_match(arrays) | ||
list_ndim = len(bottom_index) | ||
return _block(arrays, list_ndim, max(arr_ndim, list_ndim)) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This `i` and the `enumerate` are not needed here — the loop counter is unused.