8000 Merge pull request #12163 from shoyer/einsum-dispatch · numpy/numpy@5946502 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5946502

Browse files
authored
Merge pull request #12163 from shoyer/einsum-dispatch
ENH: __array_function__ for np.einsum and np.block
2 parents 872372b + b9a13b5 commit 5946502

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

numpy/core/einsumfunc.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from numpy.compat import basestring
1010
from numpy.core.multiarray import c_einsum
1111
from numpy.core.numeric import asanyarray, tensordot
12+
from numpy.core.overrides import array_function_dispatch
1213

1314
__all__ = ['einsum', 'einsum_path']
1415

@@ -689,6 +690,17 @@ def _parse_einsum_input(operands):
689690
return (input_subscripts, output_subscript, operands)
690691

691692

693+
def _einsum_path_dispatcher(*operands, **kwargs):
694+
# NOTE: technically, we should only dispatch on array-like arguments, not
695+
# subscripts (given as strings). But separating operands into
696+
# arrays/subscripts is a little tricky/slow (given einsum's two supported
697+
# signatures), so as a practical shortcut we dispatch on everything.
698+
# Strings will be ignored for dispatching since they don't define
699+
# __array_function__.
700+
return operands
701+
702+
703+
@array_function_dispatch(_einsum_path_dispatcher)
692704
def einsum_path(*operands, **kwargs):
693705
"""
694706
einsum_path(subscripts, *operands, optimize='greedy')
@@ -980,7 +992,16 @@ def einsum_path(*operands, **kwargs):
980992
return (path, path_print)
981993

982994

995+
def _einsum_dispatcher(*operands, **kwargs):
996+
# Arguably we dispatch on more arguments that we really should; see note in
997+
# _einsum_path_dispatcher for why.
998+
for op in operands:
999+
yield op
1000+
yield kwargs.get('out')
1001+
1002+
9831003
# Rewrite einsum to handle different cases
1004+
@array_function_dispatch(_einsum_dispatcher)
9841005
def einsum(*operands, **kwargs):
9851006
"""
9861007
einsum(subscripts, *operands, out=None, dtype=None, order='K',

numpy/core/shape_base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,19 @@ def _block(arrays, max_depth, result_ndim, depth=0):
631631
return _atleast_nd(arrays, result_ndim)
632632

633633

634-
# TODO: support array_function_dispatch
634+
def _block_dispatcher(arrays):
635+
# Use type(...) is list to match the behavior of np.block(), which special
636+
# cases list specifically rather than allowing for generic iterables or
637+
# tuple. Also, we know that list.__array_function__ will never exist.
638+
if type(arrays) is list:
639+
for subarrays in arrays:
640+
for subarray in _block_dispatcher(subarrays):
641+
yield subarray
642+
else:
643+
yield arrays
644+
645+
646+
@array_function_dispatch(_block_dispatcher)
635647
def block(arrays):
636648
"""
637649
Assemble an nd-array from nested lists of blocks.

numpy/core/tests/test_shape_base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
array, arange, atleast_1d, atleast_2d, atleast_3d, block, vstack, hstack,
77
newaxis, concatenate, stack
88
)
9-
10-
from numpy.core.shape_base import (_block_setup,
9+
from numpy.core.shape_base import (_block_dispatcher, _block_setup,
1110
_block_concatenate, _block_slicing)
1211
from numpy.testing import (
1312
assert_, assert_raises, assert_array_equal, assert_equal,
@@ -677,3 +676,17 @@ def test_block_memory_order(self, block):
677676

678677
assert block(b_c).flags['C_CONTIGUOUS']
679678
assert block(b_f).flags['F_CONTIGUOUS']
679+
680+
681+
def test_block_dispatcher():
682+
class ArrayLike(object):
683+
pass
684+
a = ArrayLike()
685+
b = ArrayLike()
686+
c = ArrayLike()
687+
assert_equal(list(_block_dispatcher(a)), [a])
688+
assert_equal(list(_block_dispatcher([a])), [a])
689+
assert_equal(list(_block_dispatcher([a, b])), [a, b])
690+
assert_equal(list(_block_dispatcher([[a], [b, [c]]])), [a, b, c])
691+
# don't recurse into non-lists
692+
assert_equal(list(_block_dispatcher((a, b))), [(a, b)])

0 commit comments

Comments
 (0)
0