From abcdd9a62a1f83fa5d233477442cf0a34bde2143 Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Mon, 31 Jan 2011 09:13:57 -0800 Subject: [PATCH 1/7] ENH: einsum: Disable broadcasting by default, allow spaces in subscripts string --- numpy/add_newdocs.py | 21 ++-- numpy/core/src/multiarray/einsum.c.src | 58 +++++++--- numpy/core/tests/test_numeric.py | 140 ++++++++++++++----------- 3 files changed, 138 insertions(+), 81 deletions(-) diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 51826c5ff149..f518602407b6 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -1540,13 +1540,16 @@ ``np.einsum('ji', a)`` takes its transpose. The output can be controlled by specifying output subscript labels - as well. This specifies the label order, and allows summing to be - disallowed or forced when desired. The call ``np.einsum('i->', a)`` - is equivalent to ``np.sum(a, axis=-1)``, and - ``np.einsum('ii->i', a)`` is equivalent to ``np.diag(a)``. - - It is also possible to control how broadcasting occurs using - an ellipsis. To take the trace along the first and last axes, + as well. This specifies the label order, and allows summing to + be disallowed or forced when desired. The call ``np.einsum('i->', a)`` + is like ``np.sum(a, axis=-1)``, and ``np.einsum('ii->i', a)`` + is like ``np.diag(a)``. The difference is that ``einsum`` does not + allow broadcasting by default. + + To enable and control broadcasting, use an ellipsis. Default + NumPy-style broadcasting is done by adding an ellipsis + to the left of each term, like ``np.einsum('...ii->...i', a)``. + To take the trace along the first and last axes, you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix product with the left-most indices instead of rightmost, you can do ``np.einsum('ij...,jk...->ik...', a, b)``. 
@@ -1624,7 +1627,7 @@ [1, 4], [2, 5]]) - >>> np.einsum(',', 3, c) + >>> np.einsum('..., ...', 3, c) array([[ 0, 3, 6], [ 9, 12, 15]]) >>> np.multiply(3, c) @@ -1643,7 +1646,7 @@ array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]) - >>> np.einsum('i...->', a) + >>> np.einsum('i...->...', a) array([50, 55, 60, 65, 70]) >>> np.sum(a, axis=0) array([50, 55, 60, 65, 70]) diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index f5c91b170312..1c3b07bdecbf 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -46,6 +46,7 @@ #define EINSUM_IS_SSE_ALIGNED(x) ((((npy_intp)x)&0xf) == 0) typedef enum { + BROADCAST_NONE, BROADCAST_LEFT, BROADCAST_RIGHT, BROADCAST_MIDDLE @@ -1382,7 +1383,7 @@ parse_operand_subscripts(char *subscripts, int length, EINSUM_BROADCAST *out_broadcast) { int i, idim, ndim_left, label; - int left_labels = 0, right_labels = 0; + int left_labels = 0, right_labels = 0, ellipsis = 0; /* Process the labels from the end until the ellipsis */ idim = ndim-1; @@ -1417,6 +1418,7 @@ parse_operand_subscripts(char *subscripts, int length, else if (label == '.') { /* A valid ellipsis */ if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') { + ellipsis = 1; length = i-2; break; } @@ -1428,7 +1430,7 @@ parse_operand_subscripts(char *subscripts, int length, } } - else { + else if (label != ' ') { PyErr_Format(PyExc_ValueError, "invalid subscript '%c' in einstein sum " "subscripts string, subscripts must " @@ -1436,6 +1438,15 @@ parse_operand_subscripts(char *subscripts, int length, return 0; } } + + if (!ellipsis && idim != -1) { + PyErr_Format(PyExc_ValueError, + "operand has more dimensions than subscripts " + "given in einstein sum, but no '...' 
ellipsis " + "provided to broadcast the extra dimensions."); + return 0; + } + /* Reduce ndim to just the dimensions left to fill at the beginning */ ndim_left = idim+1; idim = 0; @@ -1472,7 +1483,7 @@ parse_operand_subscripts(char *subscripts, int length, return 0; } } - else { + else if (label != ' ') { PyErr_Format(PyExc_ValueError, "invalid subscript '%c' in einstein sum " "subscripts string, subscripts must " @@ -1509,7 +1520,10 @@ parse_operand_subscripts(char *subscripts, int length, } } - if (left_labels && right_labels) { + if (!ellipsis) { + *out_broadcast = BROADCAST_NONE; + } + else if (left_labels && right_labels) { *out_broadcast = BROADCAST_MIDDLE; } else if (!left_labels) { @@ -1535,7 +1549,7 @@ parse_output_subscripts(char *subscripts, int length, EINSUM_BROADCAST *out_broadcast) { int i, nlabels, label, idim, ndim, ndim_left; - int left_labels = 0, right_labels = 0; + int left_labels = 0, right_labels = 0, ellipsis = 0; /* Count the labels, making sure they're all unique and valid */ nlabels = 0; @@ -1563,7 +1577,7 @@ parse_output_subscripts(char *subscripts, int length, return -1; } } - else if (label != '.') { + else if (label != '.' && label != ' ') { PyErr_Format(PyExc_ValueError, "invalid subscript '%c' in einstein sum " "subscripts string, subscripts must " @@ -1580,7 +1594,7 @@ parse_output_subscripts(char *subscripts, int length, for (i = length-1; i >= 0; --i) { label = subscripts[i]; /* A label for an axis */ - if (label != '.') { + if (label != '.' && label != ' ') { if (idim >= 0) { out_labels[idim--] = label; } @@ -1593,9 +1607,10 @@ parse_output_subscripts(char *subscripts, int length, right_labels = 1; } /* The end of the ellipsis */ - else { + else if (label == '.') { /* A valid ellipsis */ if (i >= 2 && subscripts[i-1] == '.' 
&& subscripts[i-2] == '.') { + ellipsis = 1; length = i-2; break; } @@ -1608,6 +1623,15 @@ parse_output_subscripts(char *subscripts, int length, } } } + + if (!ellipsis && idim != -1) { + PyErr_SetString(PyExc_ValueError, + "output has more dimensions than subscripts " + "given in einstein sum, but no '...' ellipsis " + "provided to broadcast the extra dimensions."); + return 0; + } + /* Reduce ndim to just the dimensions left to fill at the beginning */ ndim_left = idim+1; idim = 0; @@ -1620,7 +1644,7 @@ parse_output_subscripts(char *subscripts, int length, for (i = 0; i < length; ++i) { label = subscripts[i]; /* A label for an axis */ - if (label != '.') { + if (label != '.' && label != ' ') { if (idim < ndim_left) { out_labels[idim++] = label; } @@ -1646,7 +1670,10 @@ parse_output_subscripts(char *subscripts, int length, out_labels[idim++] = 0; } - if (left_labels && right_labels) { + if (!ellipsis) { + *out_broadcast = BROADCAST_NONE; + } + else if (left_labels && right_labels) { *out_broadcast = BROADCAST_MIDDLE; } else if (!left_labels) { @@ -1941,7 +1968,7 @@ prepare_op_axes(int ndim, int iop, char *labels, npy_intp *axes, } } } - /* Middle broadcasting */ + /* Middle or None broadcasting */ else { /* broadcast dimensions get placed in leftmost position */ ibroadcast = 0; @@ -2133,8 +2160,13 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, * that appeared once, in alphabetical order */ if (subscripts[0] == '\0') { - char outsubscripts[NPY_MAXDIMS]; - int length = 0; + char outsubscripts[NPY_MAXDIMS + 3]; + int length; + /* If no output was specified, always broadcast left (like normal) */ + outsubscripts[0] = '.'; + outsubscripts[1] = '.'; + outsubscripts[2] = '.'; + length = 3; for (label = min_label; label <= max_label; ++label) { if (label_counts[label] == 1) { if (length < NPY_MAXDIMS-1) { diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index f1cf7c2c40d0..68b40ff9b887 100644 --- a/numpy/core/tests/test_numeric.py 
+++ b/numpy/core/tests/test_numeric.py @@ -211,11 +211,16 @@ def test_einsum_errors(self): assert_raises(ValueError, np.einsum, "ii", np.arange(6).reshape(2,3)) assert_raises(ValueError, np.einsum, "ii->i", np.arange(6).reshape(2,3)) + # broadcasting to new dimensions must be enabled explicitly + assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2,3)) + assert_raises(ValueError, np.einsum, "i->i", [[0,1],[0,1]], + out=np.arange(4).reshape(2,2)) + def test_einsum_views(self): # pass-through a = np.arange(6).reshape(2,3) - b = np.einsum("", a) + b = np.einsum("...", a) assert_(b.base is a) b = np.einsum("ij", a) @@ -239,16 +244,16 @@ def test_einsum_views(self): # diagonal with various ways of broadcasting an additional dimension a = np.arange(27).reshape(3,3,3) - b = np.einsum("ii->i", a) + b = np.einsum("...ii->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) - b = np.einsum("ii...->i", a) + b = np.einsum("ii...->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a.transpose(2,0,1)]) - b = np.einsum("ii->i...", a) + b = np.einsum("...ii->i...", a) assert_(b.base is a) assert_equal(b, [a[:,i,i] for i in range(3)]) @@ -264,7 +269,7 @@ def test_einsum_views(self): assert_(b.base is a) assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) - b = np.einsum("i...i->i", a) + b = np.einsum("i...i->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a.transpose(1,0,2)]) @@ -288,33 +293,42 @@ def check_einsum_sums(self, dtype): a = np.arange(10, dtype=dtype) assert_equal(np.einsum("i->", a), np.sum(a, axis=-1)) - a = np.arange(24, dtype=dtype).reshape(2,3,4) - assert_equal(np.einsum("i->", a), np.sum(a, axis=-1)) + for n in range(1,17): + a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + assert_equal(np.einsum("...i->...", a), + np.sum(a, axis=-1).astype(dtype)) # sum(a, axis=0) - a = np.arange(10, dtype=dtype) - assert_equal(np.einsum("i...->", a), 
np.sum(a, axis=0)) + for n in range(1,17): + a = np.arange(2*n, dtype=dtype).reshape(2,n) + assert_equal(np.einsum("i...->...", a), + np.sum(a, axis=0).astype(dtype)) - a = np.arange(24, dtype=dtype).reshape(2,3,4) - assert_equal(np.einsum("i...->", a), np.sum(a, axis=0)) + for n in range(1,17): + a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + assert_equal(np.einsum("i...->...", a), + np.sum(a, axis=0).astype(dtype)) # trace(a) a = np.arange(25, dtype=dtype).reshape(5,5) assert_equal(np.einsum("ii", a), np.trace(a)) # multiply(a, b) - a = np.arange(12, dtype=dtype).reshape(3,4) - b = np.arange(24, dtype=dtype).reshape(2,3,4) - assert_equal(np.einsum(",", a, b), np.multiply(a, b)) + for n in range(1,17): + a = np.arange(3*n, dtype=dtype).reshape(3,n) + b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b)) # inner(a,b) - a = np.arange(24, dtype=dtype).reshape(2,3,4) - b = np.arange(4, dtype=dtype) - assert_equal(np.einsum("i,i", a, b), np.inner(a, b)) + for n in range(1,17): + a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b)) - a = np.arange(24, dtype=dtype).reshape(2,3,4) - b = np.arange(2, dtype=dtype) - assert_equal(np.einsum("i...,i...", a, b), np.inner(a.T, b.T).T) + for n in range(1,11): + a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T) # outer(a,b) a = np.arange(3, dtype=dtype)+1 @@ -328,28 +342,33 @@ def check_einsum_sums(self, dtype): warnings.simplefilter('ignore', np.ComplexWarning) # matvec(a,b) / a.dot(b) where a is matrix, b is vector - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(5, dtype=dtype) - assert_equal(np.einsum("ij,j", a, b), np.dot(a, b)) - - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(5, dtype=dtype) - c = np.arange(4, dtype=dtype) - np.einsum("ij,j", a, b, 
out=c, - dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), b.astype('f8')).astype(dtype)) - - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(5, dtype=dtype) - assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) - - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(5, dtype=dtype) - c = np.arange(4, dtype=dtype) - np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(b.T.astype('f8'), a.T.astype('f8')).astype(dtype)) + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("ij, j", a, b), np.dot(a, b)) + + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n, dtype=dtype) + c = np.arange(4, dtype=dtype) + np.einsum("ij,j", a, b, out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) + + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n, dtype=dtype) + c = np.arange(4, dtype=dtype) + np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) # matmat(a,b) / a.dot(b) where a is matrix, b is matrix a = np.arange(20, dtype=dtype).reshape(4,5) @@ -363,7 +382,7 @@ def check_einsum_sums(self, dtype): assert_equal(c, np.dot(a.astype('f8'), b.astype('f8')).astype(dtype)) - # matrix triple product (note this is not an efficient + # matrix triple product (note this is not currently an efficient # way to multiply 3 matrices) a = np.arange(12, dtype=dtype).reshape(3,4) b = np.arange(20, dtype=dtype).reshape(4,5) @@ -385,7 +404,7 @@ def check_einsum_sums(self, dtype): if np.dtype(dtype) != np.dtype('f2'): a = np.arange(60, dtype=dtype).reshape(3,4,5) b = 
np.arange(24, dtype=dtype).reshape(4,3,2) - assert_equal(np.einsum("ijk,jil->kl", a, b), + assert_equal(np.einsum("ijk, jil -> kl", a, b), np.tensordot(a,b, axes=([1,0],[0,1]))) a = np.arange(60, dtype=dtype).reshape(3,4,5) @@ -411,21 +430,24 @@ def check_einsum_sums(self, dtype): assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a)) # Various stride0, contiguous, and SSE aligned variants - a = np.arange(64, dtype=dtype) - if np.dtype(dtype).itemsize > 1: - assert_equal(np.einsum(",",a,a), np.multiply(a,a)) - assert_equal(np.einsum("i,i", a, a), np.dot(a,a)) - assert_equal(np.einsum("i,->i", a, 2), 2*a) - assert_equal(np.einsum(",i->i", 2, a), 2*a) - assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a)) - assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a)) - - assert_equal(np.einsum(",",a[1:],a[:-1]), np.multiply(a[1:],a[:-1])) - assert_equal(np.einsum("i,i", a[1:], a[:-1]), np.dot(a[1:],a[:-1])) - assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:]) - assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:]) - assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:])) - assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:])) + for n in range(1,25): + a = np.arange(n, dtype=dtype) + if np.dtype(dtype).itemsize > 1: + assert_equal(np.einsum("...,...",a,a), np.multiply(a,a)) + assert_equal(np.einsum("i,i", a, a), np.dot(a,a)) + assert_equal(np.einsum("i,->i", a, 2), 2*a) + assert_equal(np.einsum(",i->i", 2, a), 2*a) + assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a)) + assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a)) + + assert_equal(np.einsum("...,...",a[1:],a[:-1]), + np.multiply(a[1:],a[:-1])) + assert_equal(np.einsum("i,i", a[1:], a[:-1]), + np.dot(a[1:],a[:-1])) + assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:]) + assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:]) + assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:])) + assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:])) # An object array, summed as the data type a = np.arange(9, dtype=object) From 
cdb0a56c8551182e566f0308fd9f4515d5e95d89 Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Mon, 31 Jan 2011 12:22:39 -0800 Subject: [PATCH 2/7] ENH: einsum: Add alternative einsum parameter method This makes the following equivalent: einsum('ii', a) einsum(a, [0,0]) einsum('ii->i', a) einsum(a, [0,0], [0]) einsum('...i,...i->...', a, b) einsum(a, [Ellipsis,0], b, [Ellipsis,0], [Ellipsis]) --- numpy/add_newdocs.py | 30 ++ numpy/core/src/multiarray/multiarraymodule.c | 276 ++++++++++++++++--- numpy/core/tests/test_numeric.py | 163 +++++++++-- 3 files changed, 411 insertions(+), 58 deletions(-) diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index f518602407b6..e749784d58fd 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -1517,6 +1517,10 @@ Evaluates the Einstein summation convention on the operands. + An alternative way to provide the subscripts and operands is as + einsum(op0, sublist0, op1, sublist1, ..., [sublistout]). The examples + below have corresponding einsum calls with the two parameter methods. + Using the Einstein summation convention, many common multi-dimensional array operations can be represented in a simple fashion. This function provides a way compute such summations. 
@@ -1605,20 +1609,30 @@ >>> np.einsum('ii', a) 60 + >>> np.einsum(a, [0,0]) + 60 >>> np.trace(a) 60 >>> np.einsum('ii->i', a) array([ 0, 6, 12, 18, 24]) + >>> np.einsum(a, [0,0], [0]) + array([ 0, 6, 12, 18, 24]) >>> np.diag(a) array([ 0, 6, 12, 18, 24]) >>> np.einsum('ij,j', a, b) array([ 30, 80, 130, 180, 230]) + >>> np.einsum(a, [0,1], b, [1]) + array([ 30, 80, 130, 180, 230]) >>> np.dot(a, b) array([ 30, 80, 130, 180, 230]) >>> np.einsum('ji', c) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> np.einsum(c, [1,0]) array([[0, 3], [1, 4], [2, 5]]) @@ -1628,6 +1642,9 @@ [2, 5]]) >>> np.einsum('..., ...', 3, c) + array([[ 0, 3, 6], + [ 9, 12, 15]]) + >>> np.einsum(3, [Ellipsis], c, [Ellipsis]) array([[ 0, 3, 6], [ 9, 12, 15]]) >>> np.multiply(3, c) @@ -1636,10 +1653,15 @@ >>> np.einsum('i,i', b, b) 30 + >>> np.einsum(b, [0], b, [0]) + 30 >>> np.inner(b,b) 30 >>> np.einsum('i,j', np.arange(2)+1, b) + array([[0, 1, 2, 3, 4], + [0, 2, 4, 6, 8]]) + >>> np.einsum(np.arange(2)+1, [0], b, [1]) array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]) >>> np.outer(np.arange(2)+1, b) @@ -1648,12 +1670,20 @@ >>> np.einsum('i...->...', a) array([50, 55, 60, 65, 70]) + >>> np.einsum(a, [0,Ellipsis], [Ellipsis]) + array([50, 55, 60, 65, 70]) >>> np.sum(a, axis=0) array([50, 55, 60, 65, 70]) >>> a = np.arange(60.).reshape(3,4,5) >>> b = np.arange(24.).reshape(4,3,2) >>> np.einsum('ijk,jil->kl', a, b) + array([[ 4400., 4730.], + [ 4532., 4874.], + [ 4664., 5018.], + [ 4796., 5162.], + [ 4928., 5306.]]) + >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3]) array([[ 4400., 4730.], [ 4532., 4874.], [ 4664., 5018.], diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index bbc3e8f23419..9d9510a1f577 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1953,47 +1953,40 @@ array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) return _ARET(PyArray_MatrixProduct(a, v)); } -static PyObject * 
-array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
+static int /* returns the number of operands stored in op, or -1 on error */
+einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
+                       PyArrayObject **op)
 {
-    char *subscripts;
     int i, nop;
-    PyArrayObject *op[NPY_MAXARGS];
-    NPY_ORDER order = NPY_KEEPORDER;
-    NPY_CASTING casting = NPY_SAFE_CASTING;
-    PyArrayObject *out = NULL;
-    PyArray_Descr *dtype = NULL;
-    PyObject *ret = NULL;
     PyObject *subscripts_str;
-    PyObject *str_obj = NULL;
-    PyObject *str_key_obj = NULL;
 
     nop = PyTuple_GET_SIZE(args) - 1;
     if (nop <= 0) {
         PyErr_SetString(PyExc_ValueError,
                         "must specify the einstein sum subscripts string "
                         "and at least one operand");
-        return NULL;
+        return -1;
     }
-    else if (nop > NPY_MAXARGS) {
+    else if (nop >= NPY_MAXARGS) {
         PyErr_SetString(PyExc_ValueError, "too many operands");
-        return NULL;
+        return -1;
     }
 
     /* Get the subscripts string */
     subscripts_str = PyTuple_GET_ITEM(args, 0);
     if (PyUnicode_Check(subscripts_str)) {
-        str_obj = PyUnicode_AsASCIIString(subscripts_str);
-        if (str_obj == NULL) {
-            return NULL;
+        *str_obj = PyUnicode_AsASCIIString(subscripts_str);
+        if (*str_obj == NULL) {
+            return -1;
         }
-        subscripts_str = str_obj;
+        subscripts_str = *str_obj; /* *str_obj is handed back to the caller */
     }
 
-    subscripts = PyBytes_AsString(subscripts_str);
-    if (subscripts == NULL) {
-        Py_XDECREF(str_obj);
-        return NULL;
+    *subscripts = PyBytes_AsString(subscripts_str);
+    if (*subscripts == NULL) { /* test the returned string, not the out-pointer */
+        Py_XDECREF(*str_obj);
+        *str_obj = NULL;
+        return -1;
     }
 
     /* Set the operands to NULL */
@@ -2004,17 +1997,235 @@ array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
     /* Get the operands */
     for (i = 0; i < nop; ++i) {
         PyObject *obj = PyTuple_GET_ITEM(args, i+1);
-        if (PyArray_Check(obj)) {
-            Py_INCREF(obj);
-            op[i] = (PyArrayObject *)obj;
+
+        op[i] = (PyArrayObject *)PyArray_FromAny(obj,
+                                NULL, 0, 0, NPY_ENSUREARRAY, NULL);
+        if (op[i] == NULL) {
+            goto fail;
+        }
+    }
+
+    return nop;
+
+fail:
+    for (i = 0; i < nop; ++i) {
+        Py_XDECREF(op[i]);
+        op[i] = NULL;
+    }
+
+    return -1;
+}
+
+/*
+ * Converts
a list of subscripts (integers and at most one Ellipsis) to a string.
+ *
+ * Returns -1 on error, the number of characters placed in subscripts
+ * otherwise; the characters written here are not NUL-terminated.
+ */
+static int
+einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
+{
+    int ellipsis = 0, subindex = 0;
+    npy_intp i, size;
+    PyObject *item;
+
+    obj = PySequence_Fast(obj, "the subscripts for each operand must "
+                               "be a list or a tuple");
+    if (obj == NULL) {
+        return -1;
+    }
+    size = PySequence_Size(obj);
+
+
+    for (i = 0; i < size; ++i) {
+        item = PySequence_Fast_GET_ITEM(obj, i);
+        /* Ellipsis, written as the three characters "..." */
+        if (item == Py_Ellipsis) {
+            if (ellipsis) {
+                PyErr_SetString(PyExc_ValueError,
+                        "each subscripts list may have only one ellipsis");
+                Py_DECREF(obj);
+                return -1;
+            }
+            if (subindex + 3 >= subsize) {
+                PyErr_SetString(PyExc_ValueError,
+                        "subscripts list is too long");
+                Py_DECREF(obj);
+                return -1;
+            }
+            subscripts[subindex++] = '.';
+            subscripts[subindex++] = '.';
+            subscripts[subindex++] = '.';
+            ellipsis = 1;
+        }
+        /* Subscript: 0..25 map to 'A'..'Z', 26..51 map to 'a'..'z' */
+        else if (PyInt_Check(item) || PyLong_Check(item)) {
+            long s = PyInt_AsLong(item);
+            if ( s < 0 || s >= 2*26) {
+                PyErr_SetString(PyExc_ValueError,
+                        "subscript is not within the valid range [0, 52)");
+                Py_DECREF(obj);
+                return -1;
+            }
+            if (s < 26) {
+                subscripts[subindex++] = 'A' + s;
+            }
+            else {
+                subscripts[subindex++] = 'a' + (s - 26);
+            }
+            if (subindex >= subsize) {
+                PyErr_SetString(PyExc_ValueError,
+                        "subscripts list is too long");
+                Py_DECREF(obj);
+                return -1;
+            }
         }
+        /* Invalid */
         else {
-            op[i] = (PyArrayObject *)PyArray_FromAny(obj,
-                                NULL, 0, 0, NPY_ENSUREARRAY, NULL);
-            if (op[i] == NULL) {
-                goto finish;
+            PyErr_SetString(PyExc_ValueError,
+                    "each subscript must be either an integer "
+                    "or an ellipsis");
+            Py_DECREF(obj);
+            return -1;
+        }
+    }
+
+    Py_DECREF(obj);
+
+    return subindex;
+}
+
+/*
+ * Fills in the subscripts, with maximum size subsize, and op,
+ * with the values in the tuple 'args'.
+ *
+ * Returns -1 on error, number of operands placed in op otherwise.
+ */
+static int
+einsum_sub_op_from_lists(PyObject *args,
+                char *subscripts, int subsize, PyArrayObject **op)
+{
+    int subindex = 0;
+    npy_intp i, nop;
+
+    nop = PyTuple_Size(args)/2; /* args alternates operand, subscripts list */
+
+    if (nop == 0) {
+        PyErr_SetString(PyExc_ValueError, "must provide at least an "
+                        "operand and a subscripts list to einsum");
+        return -1;
+    }
+    else if(nop >= NPY_MAXARGS) {
+        PyErr_SetString(PyExc_ValueError, "too many operands");
+        return -1;
+    }
+
+    /* Set the operands to NULL so the fail path can Py_XDECREF them safely */
+    for (i = 0; i < nop; ++i) {
+        op[i] = NULL;
+    }
+
+    /* Get the operands and build the subscript string */
+    for (i = 0; i < nop; ++i) {
+        PyObject *obj = PyTuple_GET_ITEM(args, 2*i);
+        int n;
+
+        /* Comma between the subscripts for each operand */
+        if (i != 0) {
+            subscripts[subindex++] = ',';
+            if (subindex >= subsize) {
+                PyErr_SetString(PyExc_ValueError,
+                        "subscripts list is too long");
+                goto fail;
             }
         }
+
+        op[i] = (PyArrayObject *)PyArray_FromAny(obj,
+                                NULL, 0, 0, NPY_ENSUREARRAY, NULL);
+        if (op[i] == NULL) {
+            goto fail;
+        }
+
+        obj = PyTuple_GET_ITEM(args, 2*i+1);
+        n = einsum_list_to_subscripts(obj, subscripts+subindex,
+                                      subsize-subindex);
+        if (n < 0) {
+            goto fail;
+        }
+        subindex += n;
+    }
+
+    /* Add the '->' and output subscripts if a trailing output list was given */
+    if (PyTuple_Size(args) == 2*nop+1) {
+        PyObject *obj;
+        int n;
+
+        if (subindex + 2 >= subsize) {
+            PyErr_SetString(PyExc_ValueError,
+                    "subscripts list is too long");
+            goto fail;
+        }
+        subscripts[subindex++] = '-';
+        subscripts[subindex++] = '>';
+
+        obj = PyTuple_GET_ITEM(args, 2*nop);
+        n = einsum_list_to_subscripts(obj, subscripts+subindex,
+                                      subsize-subindex);
+        if (n < 0) {
+            goto fail;
+        }
+        subindex += n;
+    }
+
+    /* NULL-terminate the subscripts string */
+    subscripts[subindex] = '\0';
+
+    return nop;
+
+fail:
+    for (i = 0; i < nop; ++i) {
+        Py_XDECREF(op[i]);
+        op[i] = NULL;
+    }
+
+    return -1;
+}
+
+static PyObject *
+array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
+{
+    char *subscripts = NULL,
subscripts_buffer[256]; + PyObject *str_obj = NULL, *str_key_obj = NULL; + PyObject *arg0; + int i, nop; + PyArrayObject *op[NPY_MAXARGS]; + NPY_ORDER order = NPY_KEEPORDER; + NPY_CASTING casting = NPY_SAFE_CASTING; + PyArrayObject *out = NULL; + PyArray_Descr *dtype = NULL; + PyObject *ret = NULL; + + if (PyTuple_GET_SIZE(args) < 1) { + PyErr_SetString(PyExc_ValueError, + "must specify the einstein sum subscripts string " + "and at least one operand, or at least one operand " + "and its corresponding subscripts list"); + return NULL; + } + arg0 = PyTuple_GET_ITEM(args, 0); + + /* einsum('i,j', a, b), einsum('i,j->ij', a, b) */ + if (PyString_Check(arg0) || PyUnicode_Check(arg0)) { + nop = einsum_sub_op_from_str(args, &str_obj, &subscripts, op); + } + /* einsum(a, [0], b, [1]), einsum(a, [0], b, [1], [0,1]) */ + else { + nop = einsum_sub_op_from_lists(args, subscripts_buffer, + sizeof(subscripts_buffer), op); + subscripts = subscripts_buffer; + } + if (nop <= 0) { + goto finish; } /* Get the keyword arguments */ @@ -2090,6 +2301,7 @@ array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) Py_XDECREF(dtype); Py_XDECREF(str_obj); Py_XDECREF(str_key_obj); + /* out is a borrowed reference */ return ret; } @@ -2722,20 +2934,20 @@ compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) int cmp_op; Bool rstrip; char *cmp_str; - Py_ssize_t strlen; + Py_ssize_t strlength; PyObject *res = NULL; static char msg[] = "comparision must be '==', '!=', '<', '>', '<=', '>='"; static char *kwlist[] = {"a1", "a2", "cmp", "rstrip", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOs#O&", kwlist, - &array, &other, &cmp_str, &strlen, + &array, &other, &cmp_str, &strlength, PyArray_BoolConverter, &rstrip)) { return NULL; } - if (strlen < 1 || strlen > 2) { + if (strlength < 1 || strlength > 2) { goto err; } - if (strlen > 1) { + if (strlength > 1) { if (cmp_str[1] != '=') { goto err; } diff --git a/numpy/core/tests/test_numeric.py 
b/numpy/core/tests/test_numeric.py index 68b40ff9b887..34d295a8b995 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -223,10 +223,17 @@ def test_einsum_views(self): b = np.einsum("...", a) assert_(b.base is a) + b = np.einsum(a, [Ellipsis]) + assert_(b.base is a) + b = np.einsum("ij", a) assert_(b.base is a) assert_equal(b, a) + b = np.einsum(a, [0,1]) + assert_(b.base is a) + assert_equal(b, a) + # transpose a = np.arange(6).reshape(2,3) @@ -234,6 +241,10 @@ def test_einsum_views(self): assert_(b.base is a) assert_equal(b, a.T) + b = np.einsum(a, [1,0]) + assert_(b.base is a) + assert_equal(b, a.T) + # diagonal a = np.arange(9).reshape(3,3) @@ -241,6 +252,10 @@ def test_einsum_views(self): assert_(b.base is a) assert_equal(b, [a[i,i] for i in range(3)]) + b = np.einsum(a, [0,0], [0]) + assert_(b.base is a) + assert_equal(b, [a[i,i] for i in range(3)]) + # diagonal with various ways of broadcasting an additional dimension a = np.arange(27).reshape(3,3,3) @@ -248,32 +263,62 @@ def test_einsum_views(self): assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) + b = np.einsum(a, [Ellipsis,0,0], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) + b = np.einsum("ii...->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a.transpose(2,0,1)]) + b = np.einsum(a, [0,0,Ellipsis], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(2,0,1)]) + b = np.einsum("...ii->i...", a) assert_(b.base is a) assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum(a, [Ellipsis,0,0], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum("jii->ij", a) assert_(b.base is a) assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum(a, [1,0,0], [0,1]) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + b = 
np.einsum("ii...->i...", a) assert_(b.base is a) assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) + b = np.einsum(a, [0,0,Ellipsis], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) + b = np.einsum("i...i->i...", a) assert_(b.base is a) assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) + b = np.einsum(a, [0,Ellipsis,0], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) + b = np.einsum("i...i->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a.transpose(1,0,2)]) + b = np.einsum(a, [0,Ellipsis,0], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(1,0,2)]) + # triple diagonal a = np.arange(27).reshape(3,3,3) @@ -281,6 +326,10 @@ def test_einsum_views(self): assert_(b.base is a) assert_equal(b, [a[i,i,i] for i in range(3)]) + b = np.einsum(a, [0,0,0], [0]) + assert_(b.base is a) + assert_equal(b, [a[i,i,i] for i in range(3)]) + # swap axes a = np.arange(24).reshape(2,3,4) @@ -288,52 +337,75 @@ def test_einsum_views(self): assert_(b.base is a) assert_equal(b, a.swapaxes(0,1)) + b = np.einsum(a, [0,1,2], [1,0,2]) + assert_(b.base is a) + assert_equal(b, a.swapaxes(0,1)) + def check_einsum_sums(self, dtype): # sum(a, axis=-1) - a = np.arange(10, dtype=dtype) - assert_equal(np.einsum("i->", a), np.sum(a, axis=-1)) + for n in range(1,17): + a = np.arange(n, dtype=dtype) + assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [0], []), + np.sum(a, axis=-1).astype(dtype)) for n in range(1,17): a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) assert_equal(np.einsum("...i->...", a), np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [Ellipsis,0], [Ellipsis]), + np.sum(a, axis=-1).astype(dtype)) # sum(a, axis=0) for n in range(1,17): a = np.arange(2*n, dtype=dtype).reshape(2,n) 
assert_equal(np.einsum("i...->...", a), np.sum(a, axis=0).astype(dtype)) + assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), + np.sum(a, axis=0).astype(dtype)) for n in range(1,17): a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) assert_equal(np.einsum("i...->...", a), np.sum(a, axis=0).astype(dtype)) + assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), + np.sum(a, axis=0).astype(dtype)) # trace(a) - a = np.arange(25, dtype=dtype).reshape(5,5) - assert_equal(np.einsum("ii", a), np.trace(a)) + for n in range(1,17): + a = np.arange(n*n, dtype=dtype).reshape(n,n) + assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype)) + assert_equal(np.einsum(a, [0,0]), np.trace(a).astype(dtype)) # multiply(a, b) for n in range(1,17): a = np.arange(3*n, dtype=dtype).reshape(3,n) b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b)) + assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]), + np.multiply(a, b)) # inner(a,b) for n in range(1,17): a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b)) + assert_equal(np.einsum(a, [Ellipsis,0], b, [Ellipsis,0]), + np.inner(a, b)) for n in range(1,11): a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T) + assert_equal(np.einsum(a, [0,Ellipsis], b, [0,Ellipsis]), + np.inner(a.T, b.T).T) # outer(a,b) - a = np.arange(3, dtype=dtype)+1 - b = np.arange(4, dtype=dtype)+1 - assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) + for n in range(1,17): + a = np.arange(3, dtype=dtype)+1 + b = np.arange(n, dtype=dtype)+1 + assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) + assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b)) # Suppress the complex warnings for the 'as f8' tests ctx = WarningManager() @@ -346,41 +418,61 @@ def check_einsum_sums(self, dtype): a = np.arange(4*n, dtype=dtype).reshape(4,n) b = 
np.arange(n, dtype=dtype) assert_equal(np.einsum("ij, j", a, b), np.dot(a, b)) + assert_equal(np.einsum(a, [0,1], b, [1]), np.dot(a, b)) - for n in range(1,17): - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n, dtype=dtype) c = np.arange(4, dtype=dtype) np.einsum("ij,j", a, b, out=c, dtype='f8', casting='unsafe') assert_equal(c, np.dot(a.astype('f8'), b.astype('f8')).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1], b, [1], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) for n in range(1,17): a = np.arange(4*n, dtype=dtype).reshape(4,n) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) + assert_equal(np.einsum(a.T, [1,0], b.T, [1]), np.dot(b.T, a.T)) - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n, dtype=dtype) c = np.arange(4, dtype=dtype) np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') assert_equal(c, np.dot(b.T.astype('f8'), a.T.astype('f8')).astype(dtype)) + c[...] 
= 0 + np.einsum(a.T, [1,0], b.T, [1], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) # matmat(a,b) / a.dot(b) where a is matrix, b is matrix - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(30, dtype=dtype).reshape(5,6) - assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) + for n in range(1,17): + if n < 8 or dtype != 'f2': + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n*6, dtype=dtype).reshape(n,6) + assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) + assert_equal(np.einsum(a, [0,1], b, [1,2]), np.dot(a, b)) - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(30, dtype=dtype).reshape(5,6) - c = np.arange(24, dtype=dtype).reshape(4,6) - np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), b.astype('f8')).astype(dtype)) + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n*6, dtype=dtype).reshape(n,6) + c = np.arange(24, dtype=dtype).reshape(4,6) + np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + c[...] 
= 0 + np.einsum(a, [0,1], b, [1,2], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) # matrix triple product (note this is not currently an efficient # way to multiply 3 matrices) @@ -390,15 +482,19 @@ def check_einsum_sums(self, dtype): if dtype != 'f2': assert_equal(np.einsum("ij,jk,kl", a, b, c), a.dot(b).dot(c)) + assert_equal(np.einsum(a, [0,1], b, [1,2], c, [2,3]), + a.dot(b).dot(c)) - a = np.arange(12, dtype=dtype).reshape(3,4) - b = np.arange(20, dtype=dtype).reshape(4,5) - c = np.arange(30, dtype=dtype).reshape(5,6) d = np.arange(18, dtype=dtype).reshape(3,6) np.einsum("ij,jk,kl", a, b, c, out=d, dtype='f8', casting='unsafe') assert_equal(d, a.astype('f8').dot(b.astype('f8') ).dot(c.astype('f8')).astype(dtype)) + d[...] = 0 + np.einsum(a, [0,1], b, [1,2], c, [2,3], out=d, + dtype='f8', casting='unsafe') + assert_equal(d, a.astype('f8').dot(b.astype('f8') + ).dot(c.astype('f8')).astype(dtype)) # tensordot(a, b) if np.dtype(dtype) != np.dtype('f2'): @@ -406,14 +502,19 @@ def check_einsum_sums(self, dtype): b = np.arange(24, dtype=dtype).reshape(4,3,2) assert_equal(np.einsum("ijk, jil -> kl", a, b), np.tensordot(a,b, axes=([1,0],[0,1]))) + assert_equal(np.einsum(a, [0,1,2], b, [1,0,3], [2,3]), + np.tensordot(a,b, axes=([1,0],[0,1]))) - a = np.arange(60, dtype=dtype).reshape(3,4,5) - b = np.arange(24, dtype=dtype).reshape(4,3,2) c = np.arange(10, dtype=dtype).reshape(5,2) np.einsum("ijk,jil->kl", a, b, out=c, dtype='f8', casting='unsafe') assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), axes=([1,0],[0,1])).astype(dtype)) + c[...] 
= 0 + np.einsum(a, [0,1,2], b, [1,0,3], [2,3], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), + axes=([1,0],[0,1])).astype(dtype)) finally: ctx.__exit__() @@ -424,10 +525,15 @@ def check_einsum_sums(self, dtype): assert_equal(np.einsum("i,i,i->i", a, b, c, dtype='?', casting='unsafe'), logical_and(logical_and(a!=0, b!=0), c!=0)) + assert_equal(np.einsum(a, [0], b, [0], c, [0], [0], + dtype='?', casting='unsafe'), + logical_and(logical_and(a!=0, b!=0), c!=0)) a = np.arange(9, dtype=dtype) assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a)) + assert_equal(np.einsum(3, [], a, [0], []), 3*np.sum(a)) assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a)) + assert_equal(np.einsum(a, [0], 3, [], []), 3*np.sum(a)) # Various stride0, contiguous, and SSE aligned variants for n in range(1,25): @@ -451,10 +557,15 @@ def check_einsum_sums(self, dtype): # An object array, summed as the data type a = np.arange(9, dtype=object) + b = np.einsum("i->", a, dtype=dtype, casting='unsafe') assert_equal(b, np.sum(a)) assert_equal(b.dtype, np.dtype(dtype)) + b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') + assert_equal(b, np.sum(a)) + assert_equal(b.dtype, np.dtype(dtype)) + def test_einsum_sums_int8(self): self.check_einsum_sums('i1'); From 3b6b801551def0076d4d3f81b11c313c91e277b0 Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Fri, 28 Jan 2011 23:40:40 -0800 Subject: [PATCH 3/7] ENH: einsum: Change loop unrolling to be better for small loops --- numpy/core/src/multiarray/einsum.c.src | 557 ++++++++++++++----------- numpy/core/tests/test_numeric.py | 2 + 2 files changed, 317 insertions(+), 242 deletions(-) diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 1c3b07bdecbf..91cb558b8f16 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -187,12 +187,35 @@ static void npy_@name@ *data0 = (npy_@name@ *)dataptr[0]; npy_@name@ *data_out = 
(npy_@name@ *)dataptr[1]; - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: +#if !@complex@ + data_out[@i@] = @to@(@from@(data0[@i@]) + + @from@(data_out[@i@])); +#else + ((npy_@temp@ *)data_out + 2*@i@)[0] = + ((npy_@temp@ *)data0 + 2*@i@)[0] + + ((npy_@temp@ *)data_out + 2*@i@)[0]; + ((npy_@temp@ *)data_out + 2*@i@)[1] = + ((npy_@temp@ *)data0 + 2*@i@)[1] + + ((npy_@temp@ *)data_out + 2*@i@)[1]; +#endif +/**end repeat2**/ + case 0: + return; + } + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ #if !@complex@ data_out[@i@] = @to@(@from@(data0[@i@]) + @@ -205,31 +228,13 @@ static void ((npy_@temp@ *)data0 + 2*@i@)[1] + ((npy_@temp@ *)data_out + 2*@i@)[1]; #endif - data0 += 16; - data_out += 16; /**end repeat2**/ + data0 += 8; + data_out += 8; } /* Finish off the loop */ - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } -#if !@complex@ - data_out[@i@] = @to@(@from@(data0[@i@]) + - @from@(data_out[@i@])); -#else - ((npy_@temp@ *)data_out + 2*@i@)[0] = - ((npy_@temp@ *)data0 + 2*@i@)[0] + - ((npy_@temp@ *)data_out + 2*@i@)[0]; - ((npy_@temp@ *)data_out + 2*@i@)[1] = - ((npy_@temp@ *)data0 + 2*@i@)[1] + - ((npy_@temp@ *)data_out + 2*@i@)[1]; -#endif -/**end repeat2**/ + goto finish_after_unrolled_loop; } #elif @nop@ == 2 && !@complex@ @@ -245,36 +250,54 @@ static void #if EINSUM_USE_SSE1 && @float32@ __m128 a, b; #endif + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + data_out[@i@] = @to@(@from@(data0[@i@]) * + 
@from@(data1[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ + case 0: + return; + } #if EINSUM_USE_SSE1 && @float32@ /* Use aligned instructions if possible */ if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@)); b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); _mm_store_ps(data_out+@i@, b); /**end repeat2**/ - data0 += 16; - data1 += 16; - data_out += 16; + data0 += 8; + data1 += 8; + data_out += 8; } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #endif - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; #if EINSUM_USE_SSE1 && @float32@ /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@)); b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); @@ -282,30 +305,20 @@ static void /**end repeat2**/ #else /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ data_out[@i@] = @to@(@from@(data0[@i@]) * @from@(data1[@i@]) + @from@(data_out[@i@])); /**end repeat2**/ #endif - data0 += 16; - data1 += 16; - data_out += 16; + data0 += 8; + data1 += 8; + data_out += 8; } /* Finish off the loop */ - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } - data_out[@i@] = @to@(@from@(data0[@i@]) * - @from@(data1[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ + goto finish_after_unrolled_loop; } /* Some extra specializations for the two operand case */ @@ -319,37 +332,55 @@ static void #if EINSUM_USE_SSE1 && @float32@ __m128 a, b, value0_sse; - - value0_sse = _mm_set_ps1(value0); 
#endif +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + data_out[@i@] = @to@(value0 * + @from@(data1[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ + case 0: + return; + } + #if EINSUM_USE_SSE1 && @float32@ + value0_sse = _mm_set_ps1(value0); + /* Use aligned instructions if possible */ if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ a = _mm_mul_ps(value0_sse, _mm_load_ps(data1+@i@)); b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); _mm_store_ps(data_out+@i@, b); /**end repeat2**/ - data1 += 16; - data_out += 16; + data1 += 8; + data_out += 8; } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #endif - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; #if EINSUM_USE_SSE1 && @float32@ /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ a = _mm_mul_ps(value0_sse, _mm_loadu_ps(data1+@i@)); b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); @@ -357,29 +388,19 @@ static void /**end repeat2**/ #else /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ data_out[@i@] = @to@(value0 * @from@(data1[@i@]) + @from@(data_out[@i@])); /**end repeat2**/ #endif - data1 += 16; - data_out += 16; + data1 += 8; + data_out += 8; } /* Finish off the loop */ - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } - data_out[@i@] = @to@(value0 * - @from@(data1[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ + goto finish_after_unrolled_loop; } static void @@ -392,37 +413,55 @@ static void #if 
EINSUM_USE_SSE1 && @float32@ __m128 a, b, value1_sse; - - value1_sse = _mm_set_ps1(value1); #endif +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + data_out[@i@] = @to@(@from@(data0[@i@])* + value1 + + @from@(data_out[@i@])); +/**end repeat2**/ + case 0: + return; + } + #if EINSUM_USE_SSE1 && @float32@ + value1_sse = _mm_set_ps1(value1); + /* Use aligned instructions if possible */ if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ a = _mm_mul_ps(_mm_load_ps(data0+@i@), value1_sse); b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); _mm_store_ps(data_out+@i@, b); /**end repeat2**/ - data0 += 16; - data_out += 16; + data0 += 8; + data_out += 8; } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #endif - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; #if EINSUM_USE_SSE1 && @float32@ /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), value1_sse); b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); @@ -430,29 +469,19 @@ static void /**end repeat2**/ #else /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ data_out[@i@] = @to@(@from@(data0[@i@])* value1 + @from@(data_out[@i@])); /**end repeat2**/ #endif - data0 += 16; - data_out += 16; + data0 += 8; + data_out += 8; } /* Finish off the loop */ - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } - data_out[@i@] = @to@(@from@(data0[@i@])* - value1 + - @from@(data_out[@i@])); -/**end repeat2**/ + goto 
finish_after_unrolled_loop; } static void @@ -467,15 +496,29 @@ static void __m128 a, accum_sse = _mm_setzero_ps(); #endif +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + accum += @from@(data0[@i@]) * @from@(data1[@i@]); +/**end repeat2**/ + case 0: + *(npy_@name@ *)dataptr[2] += @to@(accum); + return; + } + #if EINSUM_USE_SSE1 && @float32@ /* Use aligned instructions if possible */ if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) { - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ /* * NOTE: This accumulation changes the order, so will likely @@ -484,19 +527,31 @@ static void a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@)); accum_sse = _mm_add_ps(accum_sse, a); /**end repeat2**/ - data0 += 16; - data1 += 16; + data0 += 8; + data1 += 8; } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #endif - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; #if EINSUM_USE_SSE1 && @float32@ /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ /* * NOTE: This accumulation changes the order, so will likely @@ -507,13 +562,13 @@ static void /**end repeat2**/ #else /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ accum += @from@(data0[@i@]) * @from@(data1[@i@]); /**end 
repeat2**/ #endif - data0 += 16; - data1 += 16; + data0 += 8; + data1 += 8; } #if EINSUM_USE_SSE1 && @float32@ @@ -524,19 +579,9 @@ static void accum_sse = _mm_add_ps(a, accum_sse); _mm_store_ss(&accum, accum_sse); #endif - /* Finish off the loop */ -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - *(npy_@name@ *)dataptr[2] += @to@(accum); - return; - } - accum += @from@(data0[@i@]) * @from@(data1[@i@]); -/**end repeat2**/ - - *(npy_@name@ *)dataptr[2] += @to@(accum); + /* Finish off the loop */ + goto finish_after_unrolled_loop; } static void @@ -551,15 +596,29 @@ static void __m128 a, accum_sse = _mm_setzero_ps(); #endif +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + accum += @from@(data1[@i@]); +/**end repeat2**/ + case 0: + *(npy_@name@ *)dataptr[2] += @to@(value0 * accum); + return; + } + #if EINSUM_USE_SSE1 && @float32@ /* Use aligned instructions if possible */ if (EINSUM_IS_SSE_ALIGNED(data1)) { - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ /* * NOTE: This accumulation changes the order, so will likely @@ -567,18 +626,30 @@ static void */ accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data1+@i@)); /**end repeat2**/ - data1 += 16; + data1 += 8; } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #endif - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 
16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; #if EINSUM_USE_SSE1 && @float32@ /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ /* * NOTE: This accumulation changes the order, so will likely @@ -588,12 +659,12 @@ static void /**end repeat2**/ #else /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ accum += @from@(data1[@i@]); /**end repeat2**/ #endif - data1 += 16; + data1 += 8; } #if EINSUM_USE_SSE1 && @float32@ @@ -604,19 +675,9 @@ static void accum_sse = _mm_add_ps(a, accum_sse); _mm_store_ss(&accum, accum_sse); #endif - /* Finish off the loop */ -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - *(npy_@name@ *)dataptr[2] += @to@(value0 * accum); - return; - } - accum += @from@(data1[@i@]); -/**end repeat2**/ - - *(npy_@name@ *)dataptr[2] += @to@(value0 * accum); + /* Finish off the loop */ + goto finish_after_unrolled_loop; } static void @@ -631,15 +692,29 @@ static void __m128 a, accum_sse = _mm_setzero_ps(); #endif +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + accum += @from@(data0[@i@]); +/**end repeat2**/ + case 0: + *(npy_@name@ *)dataptr[2] += @to@(accum * value1); + return; + } + #if EINSUM_USE_SSE1 && @float32@ /* Use aligned instructions if possible */ if (EINSUM_IS_SSE_ALIGNED(data0)) { - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ /* * NOTE: This accumulation changes the order, so will likely @@ -647,18 +722,30 @@ static void */ accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@)); /**end repeat2**/ - data0 += 16; + data0 += 8; } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in 
accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #endif - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; #if EINSUM_USE_SSE1 && @float32@ /**begin repeat2 - * #i = 0, 4, 8, 12# + * #i = 0, 4# */ /* * NOTE: This accumulation changes the order, so will likely @@ -668,12 +755,12 @@ static void /**end repeat2**/ #else /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ accum += @from@(data0[@i@]); /**end repeat2**/ #endif - data0 += 16; + data0 += 8; } #if EINSUM_USE_SSE1 && @float32@ @@ -684,19 +771,9 @@ static void accum_sse = _mm_add_ps(a, accum_sse); _mm_store_ss(&accum, accum_sse); #endif - /* Finish off the loop */ - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - *(npy_@name@ *)dataptr[2] += @to@(accum * value1); - return; - } - accum += @from@(data0[@i@]); -/**end repeat2**/ - *(npy_@name@ *)dataptr[2] += @to@(accum * value1); + /* Finish off the loop */ + goto finish_after_unrolled_loop; } #elif @nop@ == 3 && !@complex@ @@ -710,27 +787,27 @@ static void npy_@name@ *data2 = (npy_@name@ *)dataptr[2]; npy_@name@ *data_out = (npy_@name@ *)dataptr[3]; - /* Unroll the loop by 16 */ - while (count >= 16) { - count -= 16; + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ data_out[@i@] = @to@(@from@(data0[@i@]) * @from@(data1[@i@]) * @from@(data2[@i@]) + @from@(data_out[@i@])); /**end repeat2**/ - data0 += 16; - data1 += 16; - data_out += 16; + data0 += 
8; + data1 += 8; + data_out += 8; } /* Finish off the loop */ /**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ if (count-- == 0) { return; @@ -981,38 +1058,76 @@ bool_sum_of_products_contig_@noplabel@(int nop, char **dataptr, char *data_out = dataptr[@nop@]; #endif -/* Unroll the loop by 16 for fixed-size nop */ #if (@nop@ <= 3) - while (count >= 16) { - count -= 16; +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat1 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: +# if @nop@ == 1 + *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) || + (*((npy_bool *)data_out + @i@)); + data0 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); +# elif @nop@ == 2 + *((npy_bool *)data_out + @i@) = + ((*((npy_bool *)data0 + @i@)) && + (*((npy_bool *)data1 + @i@))) || + (*((npy_bool *)data_out + @i@)); + data0 += 8*sizeof(npy_bool); + data1 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); +# elif @nop@ == 3 + *((npy_bool *)data_out + @i@) = + ((*((npy_bool *)data0 + @i@)) && + (*((npy_bool *)data1 + @i@)) && + (*((npy_bool *)data2 + @i@))) || + (*((npy_bool *)data_out + @i@)); + data0 += 8*sizeof(npy_bool); + data1 += 8*sizeof(npy_bool); + data2 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); +# endif +/**end repeat1**/ + case 0: + return; + } +#endif + +/* Unroll the loop by 8 for fixed-size nop */ +#if (@nop@ <= 3) + while (count >= 8) { + count -= 8; #else while (count--) { #endif # if @nop@ == 1 /**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) || (*((npy_bool *)data_out + @i@)); /**end repeat1**/ - data0 += 16*sizeof(npy_bool); - data_out += 16*sizeof(npy_bool); + data0 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); # elif @nop@ == 2 /**begin repeat1 - * #i = 0, 
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ *((npy_bool *)data_out + @i@) = ((*((npy_bool *)data0 + @i@)) && (*((npy_bool *)data1 + @i@))) || (*((npy_bool *)data_out + @i@)); /**end repeat1**/ - data0 += 16*sizeof(npy_bool); - data1 += 16*sizeof(npy_bool); - data_out += 16*sizeof(npy_bool); + data0 += 8*sizeof(npy_bool); + data1 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); # elif @nop@ == 3 /**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# + * #i = 0, 1, 2, 3, 4, 5, 6, 7# */ *((npy_bool *)data_out + @i@) = ((*((npy_bool *)data0 + @i@)) && @@ -1020,10 +1135,10 @@ bool_sum_of_products_contig_@noplabel@(int nop, char **dataptr, (*((npy_bool *)data2 + @i@))) || (*((npy_bool *)data_out + @i@)); /**end repeat1**/ - data0 += 16*sizeof(npy_bool); - data1 += 16*sizeof(npy_bool); - data2 += 16*sizeof(npy_bool); - data_out += 16*sizeof(npy_bool); + data0 += 8*sizeof(npy_bool); + data1 += 8*sizeof(npy_bool); + data2 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); # else npy_bool temp = *(npy_bool *)dataptr[0]; int i; @@ -1039,51 +1154,7 @@ bool_sum_of_products_contig_@noplabel@(int nop, char **dataptr, /* If the loop was unrolled, we need to finish it off */ #if (@nop@ <= 3) -# if @nop@ == 1 -/**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } - *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) || - (*((npy_bool *)data_out + @i@)); -/**end repeat1**/ - data0 += 16*sizeof(npy_bool); - data_out += 16*sizeof(npy_bool); -# elif @nop@ == 2 -/**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } - *((npy_bool *)data_out + @i@) = - ((*((npy_bool *)data0 + @i@)) && - (*((npy_bool *)data1 + @i@))) || - (*((npy_bool *)data_out + @i@)); -/**end repeat1**/ - data0 += 16*sizeof(npy_bool); - data1 += 16*sizeof(npy_bool); - data_out += 16*sizeof(npy_bool); 
-# elif @nop@ == 3 -/**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15# - */ - if (count-- == 0) { - return; - } - *((npy_bool *)data_out + @i@) = - ((*((npy_bool *)data0 + @i@)) && - (*((npy_bool *)data1 + @i@)) && - (*((npy_bool *)data2 + @i@))) || - (*((npy_bool *)data_out + @i@)); -/**end repeat1**/ - data0 += 16*sizeof(npy_bool); - data1 += 16*sizeof(npy_bool); - data2 += 16*sizeof(npy_bool); - data_out += 16*sizeof(npy_bool); -# endif + goto finish_after_unrolled_loop; #endif } @@ -2405,7 +2476,9 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, if (!needs_api) { NPY_BEGIN_THREADS; } + /*printf("Einsum loop\n");*/ do { + /*printf("Einsum inner loop count %d\n", (int)*countptr);*/ sop(nop, dataptr, stride, *countptr); } while(iternext(iter)); if (!needs_api) { diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 34d295a8b995..384f9745919e 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -342,6 +342,8 @@ def test_einsum_views(self): assert_equal(b, a.swapaxes(0,1)) def check_einsum_sums(self, dtype): + # Check various sums. Does many sizes to exercise unrolled loops. 
+ # sum(a, axis=-1) for n in range(1,17): a = np.arange(n, dtype=dtype) From 909e30a9c21e6f1d0c1a83759b42ec5efd3f2054 Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Mon, 31 Jan 2011 14:04:44 -0800 Subject: [PATCH 4/7] ENH: einsum: Change function selection function to use tables --- numpy/core/src/multiarray/einsum.c.src | 352 ++++++++++++------------- 1 file changed, 163 insertions(+), 189 deletions(-) diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 91cb558b8f16..d681016b7ff3 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1216,153 +1216,185 @@ bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr, typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp); -static sum_of_products_fn -get_sum_of_products_function(int nop, int type_num, - npy_intp itemsize, npy_intp *fixed_strides) -{ - int iop; +/* These tables need to match up with the type enum */ - /* nop of 2 has more specializations */ - if (nop == 2) { - if (fixed_strides[0] == itemsize) { - if (fixed_strides[1] == itemsize) { - if (fixed_strides[2] == itemsize) { - /* contig, contig, contig */ - switch (type_num) { -/**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble# - * #NAME = BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE# - */ - case NPY_@NAME@: - return &@name@_sum_of_products_contig_two; -/**end repeat**/ - } - } - else if (fixed_strides[2] == 0) { - /* contig, contig, stride0 */ - switch (type_num) { +static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = { /**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble# - * #NAME = BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, 
DOUBLE, LONGDOUBLE# - */ - case NPY_@NAME@: - return &@name@_sum_of_products_contig_contig_outstride0_two; -/**end repeat**/ - } - } - } - else if (fixed_strides[1] == 0) { - if (fixed_strides[2] == itemsize) { - /* contig, stride0, contig */ - switch (type_num) { -/**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble# - * #NAME = BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE# + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 0, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 0, 0, 0, + * 0, 0, 0, 0, + * 0, 0, 1# */ - case NPY_@NAME@: - return &@name@_sum_of_products_contig_stride0_outcontig_two; +#if @use@ +{ + &@name@_sum_of_products_stride0_contig_outstride0_two, + &@name@_sum_of_products_stride0_contig_outcontig_two, + &@name@_sum_of_products_contig_stride0_outstride0_two, + &@name@_sum_of_products_contig_stride0_outcontig_two, + &@name@_sum_of_products_contig_contig_outstride0_two, +}, +#else + {NULL, NULL, NULL, NULL, NULL}, +#endif /**end repeat**/ - } - } - else if (fixed_strides[2] == 0) { - /* contig, stride0, stride0 */ - switch (type_num) { +}; /* End of _binary_specialization_table */ + +static sum_of_products_fn _outstride0_specialized_table[NPY_NTYPES][4] = { /**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble# - * #NAME = BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE# + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, 
clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# */ - case NPY_@NAME@: - return &@name@_sum_of_products_contig_stride0_outstride0_two; +#if @use@ +{ + &@name@_sum_of_products_outstride0_any, + &@name@_sum_of_products_outstride0_one, + &@name@_sum_of_products_outstride0_two, + &@name@_sum_of_products_outstride0_three +}, +#else + {NULL, NULL, NULL, NULL}, +#endif /**end repeat**/ - } - } - } - } - else if (fixed_strides[0] == 0) { - if (fixed_strides[1] == itemsize) { - if (fixed_strides[2] == itemsize) { - /* stride0, contig, contig */ - switch (type_num) { +}; /* End of _outstride0_specialized_table */ + +static sum_of_products_fn _allcontig_specialized_table[NPY_NTYPES][4] = { /**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble# - * #NAME = BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE# + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# */ - case NPY_@NAME@: - return &@name@_sum_of_products_stride0_contig_outcontig_two; +#if @use@ +{ + &@name@_sum_of_products_contig_any, + &@name@_sum_of_products_contig_one, + &@name@_sum_of_products_contig_two, + &@name@_sum_of_products_contig_three +}, +#else + {NULL, NULL, NULL, NULL}, +#endif /**end repeat**/ - } - } - else if (fixed_strides[2] == 0) { - /* stride0, contig, stride0 */ - switch (type_num) { +}; /* End of _allcontig_specialized_table */ + +static sum_of_products_fn _unspecialized_table[NPY_NTYPES][4] = { 
/**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble# - * #NAME = BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE# + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# */ - case NPY_@NAME@: - return &@name@_sum_of_products_stride0_contig_outstride0_two; +#if @use@ +{ + &@name@_sum_of_products_any, + &@name@_sum_of_products_one, + &@name@_sum_of_products_two, + &@name@_sum_of_products_three +}, +#else + {NULL, NULL, NULL, NULL}, +#endif /**end repeat**/ - } - } +}; /* End of _unspecialized_table */ + +static sum_of_products_fn +get_sum_of_products_function(int nop, int type_num, + npy_intp itemsize, npy_intp *fixed_strides) +{ + int iop; + + if (type_num >= NPY_NTYPES) { + return NULL; + } + + /* nop of 2 has more specializations */ + if (nop == 2) { + /* Encode the zero/contiguous strides */ + int code; + code = (fixed_strides[0] == 0) ? 0 : + (fixed_strides[0] == itemsize) ? 2*2*1 : 8; + code += (fixed_strides[1] == 0) ? 0 : + (fixed_strides[1] == itemsize) ? 2*1 : 8; + code += (fixed_strides[2] == 0) ? 0 : + (fixed_strides[2] == itemsize) ? 
1 : 8; + if (code >= 2 && code < 7) { + sum_of_products_fn ret = + _binary_specialization_table[type_num][code-2]; + if (ret != NULL) { + return ret; } } } /* Inner loop with an output stride of 0 */ if (fixed_strides[nop] == 0) { - switch (type_num) { -/**begin repeat - * #name = bool, - * byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble, - * cfloat, cdouble, clongdouble# - * #NAME = BOOL, - * BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE, - * CFLOAT, CDOUBLE, CLONGDOUBLE# - */ - case NPY_@NAME@: - switch (nop) { -/**begin repeat1 - * #nop = 1, 2, 3, 1000# - * #noplabel = one, two, three, any# - */ -#if @nop@ <= 3 - case @nop@: -#else - default: -#endif - return &@name@_sum_of_products_outstride0_@noplabel@; -/**end repeat1**/ - } -/**end repeat**/ - } + return _outstride0_specialized_table[type_num][nop <= 3 ? nop : 0]; } /* Check for all contiguous */ @@ -1374,69 +1406,11 @@ get_sum_of_products_function(int nop, int type_num, /* Contiguous loop */ if (iop == nop) { - switch (type_num) { -/**begin repeat - * #name = bool, - * byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble, - * cfloat, cdouble, clongdouble# - * #NAME = BOOL, - * BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE, - * CFLOAT, CDOUBLE, CLONGDOUBLE# - */ - case NPY_@NAME@: - switch (nop) { -/**begin repeat1 - * #nop = 1, 2, 3, 1000# - * #noplabel = one, two, three, any# - */ -#if @nop@ <= 3 - case @nop@: -#else - default: -#endif - return &@name@_sum_of_products_contig_@noplabel@; -/**end repeat1**/ - } -/**end repeat**/ - } - } - - /* Regular inner loop */ - switch (type_num) { -/**begin repeat - * #name = bool, - * byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble, - * cfloat, 
cdouble, clongdouble# - * #NAME = BOOL, - * BYTE, SHORT, INT, LONG, LONGLONG, - * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * HALF, FLOAT, DOUBLE, LONGDOUBLE, - * CFLOAT, CDOUBLE, CLONGDOUBLE# - */ - case NPY_@NAME@: - switch (nop) { -/**begin repeat1 - * #nop = 1, 2, 3, 1000# - * #noplabel = one, two, three, any# - */ -#if @nop@ <= 3 - case @nop@: -#else - default: -#endif - return &@name@_sum_of_products_@noplabel@; -/**end repeat1**/ - } -/**end repeat**/ + return _allcontig_specialized_table[type_num][nop <= 3 ? nop : 0]; } - return NULL; + /* None of the above specializations caught it, general loops */ + return _unspecialized_table[type_num][nop <= 3 ? nop : 0]; } /* From f8fccd8a4d77789a41413a98e5486d2f2bad0f02 Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Tue, 1 Feb 2011 09:24:17 -0800 Subject: [PATCH 5/7] TST: einsum: Move einsum tests to a different file --- numpy/core/tests/test_einsum.py | 470 +++++++++++++++++++++++++++++++ numpy/core/tests/test_numeric.py | 460 ------------------------------ 2 files changed, 470 insertions(+), 460 deletions(-) create mode 100644 numpy/core/tests/test_einsum.py diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py new file mode 100644 index 000000000000..7f29889c2a70 --- /dev/null +++ b/numpy/core/tests/test_einsum.py @@ -0,0 +1,470 @@ +import sys +from decimal import Decimal + +import numpy as np +from numpy.testing import * +from numpy.testing.utils import WarningManager +import warnings + +class TestEinSum(TestCase): + def test_einsum_errors(self): + # Need enough arguments + assert_raises(ValueError, np.einsum) + assert_raises(ValueError, np.einsum, "") + + # subscripts must be a string + assert_raises(TypeError, np.einsum, 0, 0) + + # out parameter must be an array + assert_raises(TypeError, np.einsum, "", 0, out='test') + + # order parameter must be a valid order + assert_raises(TypeError, np.einsum, "", 0, order='W') + + # casting parameter must be a valid casting + 
assert_raises(ValueError, np.einsum, "", 0, casting='blah') + + # dtype parameter must be a valid dtype + assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type') + + # other keyword arguments are rejected + assert_raises(TypeError, np.einsum, "", 0, bad_arg=0) + + # number of operands must match count in subscripts string + assert_raises(ValueError, np.einsum, "", 0, 0) + assert_raises(ValueError, np.einsum, ",", 0, [0], [0]) + assert_raises(ValueError, np.einsum, ",", [0]) + + # can't have more subscripts than dimensions in the operand + assert_raises(ValueError, np.einsum, "i", 0) + assert_raises(ValueError, np.einsum, "ij", [0,0]) + assert_raises(ValueError, np.einsum, "...i", 0) + assert_raises(ValueError, np.einsum, "i...j", [0,0]) + assert_raises(ValueError, np.einsum, "i...", 0) + assert_raises(ValueError, np.einsum, "ij...", [0,0]) + + # invalid ellipsis + assert_raises(ValueError, np.einsum, "i..", [0,0]) + assert_raises(ValueError, np.einsum, ".i...", [0,0]) + assert_raises(ValueError, np.einsum, "j->..j", [0,0]) + assert_raises(ValueError, np.einsum, "j->.j...", [0,0]) + + # invalid subscript character + assert_raises(ValueError, np.einsum, "i%...", [0,0]) + assert_raises(ValueError, np.einsum, "...j$", [0,0]) + assert_raises(ValueError, np.einsum, "i->&", [0,0]) + + # output subscripts must appear in input + assert_raises(ValueError, np.einsum, "i->ij", [0,0]) + + # output subscripts may only be specified once + assert_raises(ValueError, np.einsum, "ij->jij", [[0,0],[0,0]]) + + # dimensions must match when being collapsed + assert_raises(ValueError, np.einsum, "ii", np.arange(6).reshape(2,3)) + assert_raises(ValueError, np.einsum, "ii->i", np.arange(6).reshape(2,3)) + + # broadcasting to new dimensions must be enabled explicitly + assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2,3)) + assert_raises(ValueError, np.einsum, "i->i", [[0,1],[0,1]], + out=np.arange(4).reshape(2,2)) + + def test_einsum_views(self): + # pass-through + 
a = np.arange(6).reshape(2,3) + + b = np.einsum("...", a) + assert_(b.base is a) + + b = np.einsum(a, [Ellipsis]) + assert_(b.base is a) + + b = np.einsum("ij", a) + assert_(b.base is a) + assert_equal(b, a) + + b = np.einsum(a, [0,1]) + assert_(b.base is a) + assert_equal(b, a) + + # transpose + a = np.arange(6).reshape(2,3) + + b = np.einsum("ji", a) + assert_(b.base is a) + assert_equal(b, a.T) + + b = np.einsum(a, [1,0]) + assert_(b.base is a) + assert_equal(b, a.T) + + # diagonal + a = np.arange(9).reshape(3,3) + + b = np.einsum("ii->i", a) + assert_(b.base is a) + assert_equal(b, [a[i,i] for i in range(3)]) + + b = np.einsum(a, [0,0], [0]) + assert_(b.base is a) + assert_equal(b, [a[i,i] for i in range(3)]) + + # diagonal with various ways of broadcasting an additional dimension + a = np.arange(27).reshape(3,3,3) + + b = np.einsum("...ii->...i", a) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) + + b = np.einsum(a, [Ellipsis,0,0], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) + + b = np.einsum("ii...->...i", a) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(2,0,1)]) + + b = np.einsum(a, [0,0,Ellipsis], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(2,0,1)]) + + b = np.einsum("...ii->i...", a) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + + b = np.einsum(a, [Ellipsis,0,0], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + + b = np.einsum("jii->ij", a) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + + b = np.einsum(a, [1,0,0], [0,1]) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + + b = np.einsum("ii...->i...", a) + assert_(b.base is a) + assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) + + b = np.einsum(a, [0,0,Ellipsis], [0,Ellipsis]) + 
assert_(b.base is a) + assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) + + b = np.einsum("i...i->i...", a) + assert_(b.base is a) + assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) + + b = np.einsum(a, [0,Ellipsis,0], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) + + b = np.einsum("i...i->...i", a) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(1,0,2)]) + + b = np.einsum(a, [0,Ellipsis,0], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(1,0,2)]) + + # triple diagonal + a = np.arange(27).reshape(3,3,3) + + b = np.einsum("iii->i", a) + assert_(b.base is a) + assert_equal(b, [a[i,i,i] for i in range(3)]) + + b = np.einsum(a, [0,0,0], [0]) + assert_(b.base is a) + assert_equal(b, [a[i,i,i] for i in range(3)]) + + # swap axes + a = np.arange(24).reshape(2,3,4) + + b = np.einsum("ijk->jik", a) + assert_(b.base is a) + assert_equal(b, a.swapaxes(0,1)) + + b = np.einsum(a, [0,1,2], [1,0,2]) + assert_(b.base is a) + assert_equal(b, a.swapaxes(0,1)) + + def check_einsum_sums(self, dtype): + # Check various sums. Does many sizes to exercise unrolled loops. 
+ + # sum(a, axis=-1) + for n in range(1,17): + a = np.arange(n, dtype=dtype) + assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [0], []), + np.sum(a, axis=-1).astype(dtype)) + + for n in range(1,17): + a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + assert_equal(np.einsum("...i->...", a), + np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [Ellipsis,0], [Ellipsis]), + np.sum(a, axis=-1).astype(dtype)) + + # sum(a, axis=0) + for n in range(1,17): + a = np.arange(2*n, dtype=dtype).reshape(2,n) + assert_equal(np.einsum("i...->...", a), + np.sum(a, axis=0).astype(dtype)) + assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), + np.sum(a, axis=0).astype(dtype)) + + for n in range(1,17): + a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + assert_equal(np.einsum("i...->...", a), + np.sum(a, axis=0).astype(dtype)) + assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), + np.sum(a, axis=0).astype(dtype)) + + # trace(a) + for n in range(1,17): + a = np.arange(n*n, dtype=dtype).reshape(n,n) + assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype)) + assert_equal(np.einsum(a, [0,0]), np.trace(a).astype(dtype)) + + # multiply(a, b) + for n in range(1,17): + a = np.arange(3*n, dtype=dtype).reshape(3,n) + b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b)) + assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]), + np.multiply(a, b)) + + # inner(a,b) + for n in range(1,17): + a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b)) + assert_equal(np.einsum(a, [Ellipsis,0], b, [Ellipsis,0]), + np.inner(a, b)) + + for n in range(1,11): + a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T) + assert_equal(np.einsum(a, [0,Ellipsis], b, [0,Ellipsis]), + np.inner(a.T, b.T).T) + + # 
outer(a,b) + for n in range(1,17): + a = np.arange(3, dtype=dtype)+1 + b = np.arange(n, dtype=dtype)+1 + assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) + assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b)) + + # Suppress the complex warnings for the 'as f8' tests + ctx = WarningManager() + ctx.__enter__() + try: + warnings.simplefilter('ignore', np.ComplexWarning) + + # matvec(a,b) / a.dot(b) where a is matrix, b is vector + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("ij, j", a, b), np.dot(a, b)) + assert_equal(np.einsum(a, [0,1], b, [1]), np.dot(a, b)) + + c = np.arange(4, dtype=dtype) + np.einsum("ij,j", a, b, out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1], b, [1], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n, dtype=dtype) + assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) + assert_equal(np.einsum(a.T, [1,0], b.T, [1]), np.dot(b.T, a.T)) + + c = np.arange(4, dtype=dtype) + np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) + c[...] 
= 0 + np.einsum(a.T, [1,0], b.T, [1], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) + + # matmat(a,b) / a.dot(b) where a is matrix, b is matrix + for n in range(1,17): + if n < 8 or dtype != 'f2': + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n*6, dtype=dtype).reshape(n,6) + assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) + assert_equal(np.einsum(a, [0,1], b, [1,2]), np.dot(a, b)) + + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n*6, dtype=dtype).reshape(n,6) + c = np.arange(24, dtype=dtype).reshape(4,6) + np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1], b, [1,2], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + + # matrix triple product (note this is not currently an efficient + # way to multiply 3 matrices) + a = np.arange(12, dtype=dtype).reshape(3,4) + b = np.arange(20, dtype=dtype).reshape(4,5) + c = np.arange(30, dtype=dtype).reshape(5,6) + if dtype != 'f2': + assert_equal(np.einsum("ij,jk,kl", a, b, c), + a.dot(b).dot(c)) + assert_equal(np.einsum(a, [0,1], b, [1,2], c, [2,3]), + a.dot(b).dot(c)) + + d = np.arange(18, dtype=dtype).reshape(3,6) + np.einsum("ij,jk,kl", a, b, c, out=d, + dtype='f8', casting='unsafe') + assert_equal(d, a.astype('f8').dot(b.astype('f8') + ).dot(c.astype('f8')).astype(dtype)) + d[...] 
= 0 + np.einsum(a, [0,1], b, [1,2], c, [2,3], out=d, + dtype='f8', casting='unsafe') + assert_equal(d, a.astype('f8').dot(b.astype('f8') + ).dot(c.astype('f8')).astype(dtype)) + + # tensordot(a, b) + if np.dtype(dtype) != np.dtype('f2'): + a = np.arange(60, dtype=dtype).reshape(3,4,5) + b = np.arange(24, dtype=dtype).reshape(4,3,2) + assert_equal(np.einsum("ijk, jil -> kl", a, b), + np.tensordot(a,b, axes=([1,0],[0,1]))) + assert_equal(np.einsum(a, [0,1,2], b, [1,0,3], [2,3]), + np.tensordot(a,b, axes=([1,0],[0,1]))) + + c = np.arange(10, dtype=dtype).reshape(5,2) + np.einsum("ijk,jil->kl", a, b, out=c, + dtype='f8', casting='unsafe') + assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), + axes=([1,0],[0,1])).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1,2], b, [1,0,3], [2,3], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), + axes=([1,0],[0,1])).astype(dtype)) + finally: + ctx.__exit__() + + # logical_and(logical_and(a!=0, b!=0), c!=0) + a = np.array([1, 3, -2, 0, 12, 13, 0, 1], dtype=dtype) + b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12], dtype=dtype) + c = np.array([True,True,False,True,True,False,True,True]) + assert_equal(np.einsum("i,i,i->i", a, b, c, + dtype='?', casting='unsafe'), + np.logical_and(np.logical_and(a!=0, b!=0), c!=0)) + assert_equal(np.einsum(a, [0], b, [0], c, [0], [0], + dtype='?', casting='unsafe'), + np.logical_and(np.logical_and(a!=0, b!=0), c!=0)) + + a = np.arange(9, dtype=dtype) + assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a)) + assert_equal(np.einsum(3, [], a, [0], []), 3*np.sum(a)) + assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a)) + assert_equal(np.einsum(a, [0], 3, [], []), 3*np.sum(a)) + + # Various stride0, contiguous, and SSE aligned variants + for n in range(1,25): + a = np.arange(n, dtype=dtype) + if np.dtype(dtype).itemsize > 1: + assert_equal(np.einsum("...,...",a,a), np.multiply(a,a)) + assert_equal(np.einsum("i,i", a, a), np.dot(a,a)) + 
assert_equal(np.einsum("i,->i", a, 2), 2*a) + assert_equal(np.einsum(",i->i", 2, a), 2*a) + assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a)) + assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a)) + + assert_equal(np.einsum("...,...",a[1:],a[:-1]), + np.multiply(a[1:],a[:-1])) + assert_equal(np.einsum("i,i", a[1:], a[:-1]), + np.dot(a[1:],a[:-1])) + assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:]) + assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:]) + assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:])) + assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:])) + + # An object array, summed as the data type + a = np.arange(9, dtype=object) + + b = np.einsum("i->", a, dtype=dtype, casting='unsafe') + assert_equal(b, np.sum(a)) + assert_equal(b.dtype, np.dtype(dtype)) + + b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') + assert_equal(b, np.sum(a)) + assert_equal(b.dtype, np.dtype(dtype)) + + def test_einsum_sums_int8(self): + self.check_einsum_sums('i1'); + + def test_einsum_sums_uint8(self): + self.check_einsum_sums('u1'); + + def test_einsum_sums_int16(self): + self.check_einsum_sums('i2'); + + def test_einsum_sums_uint16(self): + self.check_einsum_sums('u2'); + + def test_einsum_sums_int32(self): + self.check_einsum_sums('i4'); + + def test_einsum_sums_uint32(self): + self.check_einsum_sums('u4'); + + def test_einsum_sums_int64(self): + self.check_einsum_sums('i8'); + + def test_einsum_sums_uint64(self): + self.check_einsum_sums('u8'); + + def test_einsum_sums_float16(self): + self.check_einsum_sums('f2'); + + def test_einsum_sums_float32(self): + self.check_einsum_sums('f4'); + + def test_einsum_sums_float64(self): + self.check_einsum_sums('f8'); + + def test_einsum_sums_longdouble(self): + self.check_einsum_sums(np.longdouble); + + def test_einsum_sums_cfloat64(self): + self.check_einsum_sums('c8'); + + def test_einsum_sums_cfloat128(self): + self.check_einsum_sums('c16'); + + def test_einsum_sums_clongdouble(self): + 
self.check_einsum_sums(np.clongdouble); + +if __name__ == "__main__": + run_module_suite() diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 384f9745919e..009065bb449f 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -153,466 +153,6 @@ def test_zeroresize(self): Ar = resize(A, (0,)) assert_equal(Ar, array([])) -class TestEinSum(TestCase): - def test_einsum_errors(self): - # Need enough arguments - assert_raises(ValueError, np.einsum) - assert_raises(ValueError, np.einsum, "") - - # subscripts must be a string - assert_raises(TypeError, np.einsum, 0, 0) - - # out parameter must be an array - assert_raises(TypeError, np.einsum, "", 0, out='test') - - # order parameter must be a valid order - assert_raises(TypeError, np.einsum, "", 0, order='W') - - # casting parameter must be a valid casting - assert_raises(ValueError, np.einsum, "", 0, casting='blah') - - # dtype parameter must be a valid dtype - assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type') - - # other keyword arguments are rejected - assert_raises(TypeError, np.einsum, "", 0, bad_arg=0) - - # number of operands must match count in subscripts string - assert_raises(ValueError, np.einsum, "", 0, 0) - assert_raises(ValueError, np.einsum, ",", 0, [0], [0]) - assert_raises(ValueError, np.einsum, ",", [0]) - - # can't have more subscripts than dimensions in the operand - assert_raises(ValueError, np.einsum, "i", 0) - assert_raises(ValueError, np.einsum, "ij", [0,0]) - assert_raises(ValueError, np.einsum, "...i", 0) - assert_raises(ValueError, np.einsum, "i...j", [0,0]) - assert_raises(ValueError, np.einsum, "i...", 0) - assert_raises(ValueError, np.einsum, "ij...", [0,0]) - - # invalid ellipsis - assert_raises(ValueError, np.einsum, "i..", [0,0]) - assert_raises(ValueError, np.einsum, ".i...", [0,0]) - assert_raises(ValueError, np.einsum, "j->..j", [0,0]) - assert_raises(ValueError, np.einsum, "j->.j...", [0,0]) - - # invalid 
subscript character - assert_raises(ValueError, np.einsum, "i%...", [0,0]) - assert_raises(ValueError, np.einsum, "...j$", [0,0]) - assert_raises(ValueError, np.einsum, "i->&", [0,0]) - - # output subscripts must appear in input - assert_raises(ValueError, np.einsum, "i->ij", [0,0]) - - # output subscripts may only be specified once - assert_raises(ValueError, np.einsum, "ij->jij", [[0,0],[0,0]]) - - # dimensions much match when being collapsed - assert_raises(ValueError, np.einsum, "ii", np.arange(6).reshape(2,3)) - assert_raises(ValueError, np.einsum, "ii->i", np.arange(6).reshape(2,3)) - - # broadcasting to new dimensions must be enabled explicitly - assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2,3)) - assert_raises(ValueError, np.einsum, "i->i", [[0,1],[0,1]], - out=np.arange(4).reshape(2,2)) - - def test_einsum_views(self): - # pass-through - a = np.arange(6).reshape(2,3) - - b = np.einsum("...", a) - assert_(b.base is a) - - b = np.einsum(a, [Ellipsis]) - assert_(b.base is a) - - b = np.einsum("ij", a) - assert_(b.base is a) - assert_equal(b, a) - - b = np.einsum(a, [0,1]) - assert_(b.base is a) - assert_equal(b, a) - - # transpose - a = np.arange(6).reshape(2,3) - - b = np.einsum("ji", a) - assert_(b.base is a) - assert_equal(b, a.T) - - b = np.einsum(a, [1,0]) - assert_(b.base is a) - assert_equal(b, a.T) - - # diagonal - a = np.arange(9).reshape(3,3) - - b = np.einsum("ii->i", a) - assert_(b.base is a) - assert_equal(b, [a[i,i] for i in range(3)]) - - b = np.einsum(a, [0,0], [0]) - assert_(b.base is a) - assert_equal(b, [a[i,i] for i in range(3)]) - - # diagonal with various ways of broadcasting an additional dimension - a = np.arange(27).reshape(3,3,3) - - b = np.einsum("...ii->...i", a) - assert_(b.base is a) - assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) - - b = np.einsum(a, [Ellipsis,0,0], [Ellipsis,0]) - assert_(b.base is a) - assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) - - b = np.einsum("ii...->...i", a) - 
assert_(b.base is a) - assert_equal(b, [[x[i,i] for i in range(3)] - for x in a.transpose(2,0,1)]) - - b = np.einsum(a, [0,0,Ellipsis], [Ellipsis,0]) - assert_(b.base is a) - assert_equal(b, [[x[i,i] for i in range(3)] - for x in a.transpose(2,0,1)]) - - b = np.einsum("...ii->i...", a) - assert_(b.base is a) - assert_equal(b, [a[:,i,i] for i in range(3)]) - - b = np.einsum(a, [Ellipsis,0,0], [0,Ellipsis]) - assert_(b.base is a) - assert_equal(b, [a[:,i,i] for i in range(3)]) - - b = np.einsum("jii->ij", a) - assert_(b.base is a) - assert_equal(b, [a[:,i,i] for i in range(3)]) - - b = np.einsum(a, [1,0,0], [0,1]) - assert_(b.base is a) - assert_equal(b, [a[:,i,i] for i in range(3)]) - - b = np.einsum("ii...->i...", a) - assert_(b.base is a) - assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) - - b = np.einsum(a, [0,0,Ellipsis], [0,Ellipsis]) - assert_(b.base is a) - assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) - - b = np.einsum("i...i->i...", a) - assert_(b.base is a) - assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) - - b = np.einsum(a, [0,Ellipsis,0], [0,Ellipsis]) - assert_(b.base is a) - assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) - - b = np.einsum("i...i->...i", a) - assert_(b.base is a) - assert_equal(b, [[x[i,i] for i in range(3)] - for x in a.transpose(1,0,2)]) - - b = np.einsum(a, [0,Ellipsis,0], [Ellipsis,0]) - assert_(b.base is a) - assert_equal(b, [[x[i,i] for i in range(3)] - for x in a.transpose(1,0,2)]) - - # triple diagonal - a = np.arange(27).reshape(3,3,3) - - b = np.einsum("iii->i", a) - assert_(b.base is a) - assert_equal(b, [a[i,i,i] for i in range(3)]) - - b = np.einsum(a, [0,0,0], [0]) - assert_(b.base is a) - assert_equal(b, [a[i,i,i] for i in range(3)]) - - # swap axes - a = np.arange(24).reshape(2,3,4) - - b = np.einsum("ijk->jik", a) - assert_(b.base is a) - assert_equal(b, a.swapaxes(0,1)) - - b = np.einsum(a, [0,1,2], [1,0,2]) - assert_(b.base is a) - assert_equal(b, 
a.swapaxes(0,1)) - - def check_einsum_sums(self, dtype): - # Check various sums. Does many sizes to exercise unrolled loops. - - # sum(a, axis=-1) - for n in range(1,17): - a = np.arange(n, dtype=dtype) - assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype)) - assert_equal(np.einsum(a, [0], []), - np.sum(a, axis=-1).astype(dtype)) - - for n in range(1,17): - a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) - assert_equal(np.einsum("...i->...", a), - np.sum(a, axis=-1).astype(dtype)) - assert_equal(np.einsum(a, [Ellipsis,0], [Ellipsis]), - np.sum(a, axis=-1).astype(dtype)) - - # sum(a, axis=0) - for n in range(1,17): - a = np.arange(2*n, dtype=dtype).reshape(2,n) - assert_equal(np.einsum("i...->...", a), - np.sum(a, axis=0).astype(dtype)) - assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), - np.sum(a, axis=0).astype(dtype)) - - for n in range(1,17): - a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) - assert_equal(np.einsum("i...->...", a), - np.sum(a, axis=0).astype(dtype)) - assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), - np.sum(a, axis=0).astype(dtype)) - - # trace(a) - for n in range(1,17): - a = np.arange(n*n, dtype=dtype).reshape(n,n) - assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype)) - assert_equal(np.einsum(a, [0,0]), np.trace(a).astype(dtype)) - - # multiply(a, b) - for n in range(1,17): - a = np.arange(3*n, dtype=dtype).reshape(3,n) - b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) - assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b)) - assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]), - np.multiply(a, b)) - - # inner(a,b) - for n in range(1,17): - a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) - b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b)) - assert_equal(np.einsum(a, [Ellipsis,0], b, [Ellipsis,0]), - np.inner(a, b)) - - for n in range(1,11): - a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2) - b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("i..., 
i...", a, b), np.inner(a.T, b.T).T) - assert_equal(np.einsum(a, [0,Ellipsis], b, [0,Ellipsis]), - np.inner(a.T, b.T).T) - - # outer(a,b) - for n in range(1,17): - a = np.arange(3, dtype=dtype)+1 - b = np.arange(n, dtype=dtype)+1 - assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) - assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b)) - - # Suppress the complex warnings for the 'as f8' tests - ctx = WarningManager() - ctx.__enter__() - try: - warnings.simplefilter('ignore', np.ComplexWarning) - - # matvec(a,b) / a.dot(b) where a is matrix, b is vector - for n in range(1,17): - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("ij, j", a, b), np.dot(a, b)) - assert_equal(np.einsum(a, [0,1], b, [1]), np.dot(a, b)) - - c = np.arange(4, dtype=dtype) - np.einsum("ij,j", a, b, out=c, - dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) - c[...] = 0 - np.einsum(a, [0,1], b, [1], out=c, - dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) - - for n in range(1,17): - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n, dtype=dtype) - assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) - assert_equal(np.einsum(a.T, [1,0], b.T, [1]), np.dot(b.T, a.T)) - - c = np.arange(4, dtype=dtype) - np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(b.T.astype('f8'), - a.T.astype('f8')).astype(dtype)) - c[...] 
= 0 - np.einsum(a.T, [1,0], b.T, [1], out=c, - dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(b.T.astype('f8'), - a.T.astype('f8')).astype(dtype)) - - # matmat(a,b) / a.dot(b) where a is matrix, b is matrix - for n in range(1,17): - if n < 8 or dtype != 'f2': - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n*6, dtype=dtype).reshape(n,6) - assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) - assert_equal(np.einsum(a, [0,1], b, [1,2]), np.dot(a, b)) - - for n in range(1,17): - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n*6, dtype=dtype).reshape(n,6) - c = np.arange(24, dtype=dtype).reshape(4,6) - np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) - c[...] = 0 - np.einsum(a, [0,1], b, [1,2], out=c, - dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), - b.astype('f8')).astype(dtype)) - - # matrix triple product (note this is not currently an efficient - # way to multiply 3 matrices) - a = np.arange(12, dtype=dtype).reshape(3,4) - b = np.arange(20, dtype=dtype).reshape(4,5) - c = np.arange(30, dtype=dtype).reshape(5,6) - if dtype != 'f2': - assert_equal(np.einsum("ij,jk,kl", a, b, c), - a.dot(b).dot(c)) - assert_equal(np.einsum(a, [0,1], b, [1,2], c, [2,3]), - a.dot(b).dot(c)) - - d = np.arange(18, dtype=dtype).reshape(3,6) - np.einsum("ij,jk,kl", a, b, c, out=d, - dtype='f8', casting='unsafe') - assert_equal(d, a.astype('f8').dot(b.astype('f8') - ).dot(c.astype('f8')).astype(dtype)) - d[...] 
= 0 - np.einsum(a, [0,1], b, [1,2], c, [2,3], out=d, - dtype='f8', casting='unsafe') - assert_equal(d, a.astype('f8').dot(b.astype('f8') - ).dot(c.astype('f8')).astype(dtype)) - - # tensordot(a, b) - if np.dtype(dtype) != np.dtype('f2'): - a = np.arange(60, dtype=dtype).reshape(3,4,5) - b = np.arange(24, dtype=dtype).reshape(4,3,2) - assert_equal(np.einsum("ijk, jil -> kl", a, b), - np.tensordot(a,b, axes=([1,0],[0,1]))) - assert_equal(np.einsum(a, [0,1,2], b, [1,0,3], [2,3]), - np.tensordot(a,b, axes=([1,0],[0,1]))) - - c = np.arange(10, dtype=dtype).reshape(5,2) - np.einsum("ijk,jil->kl", a, b, out=c, - dtype='f8', casting='unsafe') - assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), - axes=([1,0],[0,1])).astype(dtype)) - c[...] = 0 - np.einsum(a, [0,1,2], b, [1,0,3], [2,3], out=c, - dtype='f8', casting='unsafe') - assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), - axes=([1,0],[0,1])).astype(dtype)) - finally: - ctx.__exit__() - - # logical_and(logical_and(a!=0, b!=0), c!=0) - a = np.array([1, 3, -2, 0, 12, 13, 0, 1], dtype=dtype) - b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12], dtype=dtype) - c = np.array([True,True,False,True,True,False,True,True]) - assert_equal(np.einsum("i,i,i->i", a, b, c, - dtype='?', casting='unsafe'), - logical_and(logical_and(a!=0, b!=0), c!=0)) - assert_equal(np.einsum(a, [0], b, [0], c, [0], [0], - dtype='?', casting='unsafe'), - logical_and(logical_and(a!=0, b!=0), c!=0)) - - a = np.arange(9, dtype=dtype) - assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a)) - assert_equal(np.einsum(3, [], a, [0], []), 3*np.sum(a)) - assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a)) - assert_equal(np.einsum(a, [0], 3, [], []), 3*np.sum(a)) - - # Various stride0, contiguous, and SSE aligned variants - for n in range(1,25): - a = np.arange(n, dtype=dtype) - if np.dtype(dtype).itemsize > 1: - assert_equal(np.einsum("...,...",a,a), np.multiply(a,a)) - assert_equal(np.einsum("i,i", a, a), np.dot(a,a)) - assert_equal(np.einsum("i,->i", 
a, 2), 2*a) - assert_equal(np.einsum(",i->i", 2, a), 2*a) - assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a)) - assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a)) - - assert_equal(np.einsum("...,...",a[1:],a[:-1]), - np.multiply(a[1:],a[:-1])) - assert_equal(np.einsum("i,i", a[1:], a[:-1]), - np.dot(a[1:],a[:-1])) - assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:]) - assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:]) - assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:])) - assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:])) - - # An object array, summed as the data type - a = np.arange(9, dtype=object) - - b = np.einsum("i->", a, dtype=dtype, casting='unsafe') - assert_equal(b, np.sum(a)) - assert_equal(b.dtype, np.dtype(dtype)) - - b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') - assert_equal(b, np.sum(a)) - assert_equal(b.dtype, np.dtype(dtype)) - - def test_einsum_sums_int8(self): - self.check_einsum_sums('i1'); - - def test_einsum_sums_uint8(self): - self.check_einsum_sums('u1'); - - def test_einsum_sums_int16(self): - self.check_einsum_sums('i2'); - - def test_einsum_sums_uint16(self): - self.check_einsum_sums('u2'); - - def test_einsum_sums_int32(self): - self.check_einsum_sums('i4'); - - def test_einsum_sums_uint32(self): - self.check_einsum_sums('u4'); - - def test_einsum_sums_int64(self): - self.check_einsum_sums('i8'); - - def test_einsum_sums_uint64(self): - self.check_einsum_sums('u8'); - - def test_einsum_sums_float16(self): - self.check_einsum_sums('f2'); - - def test_einsum_sums_float32(self): - self.check_einsum_sums('f4'); - - def test_einsum_sums_float64(self): - self.check_einsum_sums('f8'); - - def test_einsum_sums_longdouble(self): - self.check_einsum_sums(np.longdouble); - - def test_einsum_sums_cfloat64(self): - self.check_einsum_sums('c8'); - - def test_einsum_sums_cfloat128(self): - self.check_einsum_sums('c16'); - - def test_einsum_sums_clongdouble(self): - self.check_einsum_sums(np.clongdouble); - class 
TestNonarrayArgs(TestCase): # check that non-array arguments to functions wrap them in arrays def test_squeeze(self): From 31f6ff9fe0cad4d30299136e72ea2d61fa53c486 Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Tue, 1 Feb 2011 12:53:06 -0800 Subject: [PATCH 6/7] ENH: iter: Catch another case with fixed strides. --- numpy/core/src/multiarray/einsum.c.src | 49 ++++++++++++++--- numpy/core/src/multiarray/new_iterator.c.src | 58 ++++++++++++++++---- 2 files changed, 88 insertions(+), 19 deletions(-) diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index d681016b7ff3..488c9c721011 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -45,6 +45,16 @@ #define EINSUM_IS_SSE_ALIGNED(x) ((((npy_intp)x)&0xf) == 0) +/********** PRINTF DEBUG TRACING **************/ +#define NPY_EINSUM_DBG_TRACING 0 + +#if NPY_EINSUM_DBG_TRACING +#define NPY_EINSUM_DBG_PRINTF(...) printf(__VA_ARGS__) +#else +#define NPY_EINSUM_DBG_PRINTF(...) 
+#endif +/**********************************************/ + typedef enum { BROADCAST_NONE, BROADCAST_LEFT, @@ -104,6 +114,8 @@ static void npy_intp stride_out = strides[@nop@]; #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_@noplabel@ (%d)\n", (int)count); + while (count--) { #if !@complex@ # if @nop@ == 1 @@ -187,6 +199,8 @@ static void npy_@name@ *data0 = (npy_@name@ *)dataptr[0]; npy_@name@ *data_out = (npy_@name@ *)dataptr[1]; + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n", (int)count); + /* This is placed before the main loop to make small counts faster */ finish_after_unrolled_loop: switch (count) { @@ -251,6 +265,8 @@ static void __m128 a, b; #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n", (int)count); + /* This is placed before the main loop to make small counts faster */ finish_after_unrolled_loop: switch (count) { @@ -334,6 +350,9 @@ static void __m128 a, b, value0_sse; #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_stride0_contig_outcontig_two (%d)\n", + (int)count); + /* This is placed before the main loop to make small counts faster */ finish_after_unrolled_loop: switch (count) { @@ -415,6 +434,9 @@ static void __m128 a, b, value1_sse; #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_stride0_outcontig_two (%d)\n", + (int)count); + /* This is placed before the main loop to make small counts faster */ finish_after_unrolled_loop: switch (count) { @@ -496,6 +518,9 @@ static void __m128 a, accum_sse = _mm_setzero_ps(); #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_contig_outstride0_two (%d)\n", + (int)count); + /* This is placed before the main loop to make small counts faster */ finish_after_unrolled_loop: switch (count) { @@ -596,6 +621,9 @@ static void __m128 a, accum_sse = _mm_setzero_ps(); #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_stride0_contig_outstride0_two (%d)\n", + (int)count); + /* This is placed before the main loop to make small counts 
faster */ finish_after_unrolled_loop: switch (count) { @@ -692,6 +720,9 @@ static void __m128 a, accum_sse = _mm_setzero_ps(); #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_stride0_outstride0_two (%d)\n", + (int)count); + /* This is placed before the main loop to make small counts faster */ finish_after_unrolled_loop: switch (count) { @@ -825,6 +856,9 @@ static void @name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr, npy_intp *NPY_UNUSED(strides), npy_intp count) { + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_@noplabel@ (%d)\n", + (int)count); + while (count--) { #if !@complex@ npy_@temp@ temp = @from@(*(npy_@name@ *)dataptr[0]); @@ -892,6 +926,9 @@ static void npy_intp stride2 = strides[2]; #endif + NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_outstride0_@noplabel@ (%d)\n", + (int)count); + while (count--) { #if !@complex@ # if @nop@ == 1 @@ -2413,14 +2450,9 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, NpyIter_GetDescrArray(iter)[0]->elsize, fixed_strides); - #if 0 +#if NPY_EINSUM_DBG_TRACING NpyIter_DebugPrint(iter); - printf("fixed strides:\n"); - for (iop = 0; iop <= nop; ++iop) { - printf("%ld ", fixed_strides[iop]); - } - printf("\n"); - #endif +#endif /* Finally, the main loop */ if (sop == NULL) { @@ -2450,9 +2482,8 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, if (!needs_api) { NPY_BEGIN_THREADS; } - /*printf("Einsum loop\n");*/ + NPY_EINSUM_DBG_PRINTF("Einsum loop\n"); do { - /*printf("Einsum inner loop count %d\n", (int)*countptr);*/ sop(nop, dataptr, stride, *countptr); } while(iternext(iter)); if (!needs_api) { diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src index ba7c687de064..4e3971313322 100644 --- a/numpy/core/src/multiarray/new_iterator.c.src +++ b/numpy/core/src/multiarray/new_iterator.c.src @@ -2466,28 +2466,57 @@ NPY_NO_EXPORT void NpyIter_GetInnerFixedStrideArray(NpyIter *iter, npy_intp *out_strides) { npy_uint32 itflags = 
NIT_ITFLAGS(iter); - /*npy_intp ndim = NIT_NDIM(iter);*/ + npy_intp ndim = NIT_NDIM(iter); npy_intp iiter, niter = NIT_NITER(iter); - NpyIter_AxisData *axisdata = NIT_AXISDATA(iter); + NpyIter_AxisData *axisdata0 = NIT_AXISDATA(iter); + npy_intp sizeof_axisdata = NIT_AXISDATA_SIZEOF(itflags, ndim, niter); if (itflags&NPY_ITFLAG_BUFFER) { NpyIter_BufferData *data = NIT_BUFFERDATA(iter); char *op_itflags = NIT_OPITFLAGS(iter); npy_intp stride, *strides = NBF_STRIDES(data), - *ad_strides = NAD_STRIDES(axisdata); + *ad_strides = NAD_STRIDES(axisdata0); PyArray_Descr **dtypes = NIT_DTYPES(iter); for (iiter = 0; iiter < niter; ++iiter) { stride = strides[iiter]; - /* Operands which are always/never buffered have fixed strides */ - if (op_itflags[iiter]& - (NPY_OP_ITFLAG_CAST|NPY_OP_ITFLAG_BUFNEVER)) { + /* + * Operands which are always/never buffered have fixed strides, + * and everything has fixed strides when ndim is 0 or 1 + */ + if (ndim <= 1 || (op_itflags[iiter]& + (NPY_OP_ITFLAG_CAST|NPY_OP_ITFLAG_BUFNEVER))) { out_strides[iiter] = stride; } - /* Reductions in the inner loop have fixed strides */ - else if (stride == 0 && (op_itflags[iiter]&NPY_OP_ITFLAG_REDUCE)) { - out_strides[iiter] = stride; + /* If it's a reduction, 0-stride inner loop may have fixed stride */ + else if (stride == 0 && (itflags&NPY_ITFLAG_REDUCE)) { + /* If it's a reduction operand, definitely fixed stride */ + if (op_itflags[iiter]&NPY_OP_ITFLAG_REDUCE) { + out_strides[iiter] = stride; + } + /* + * Otherwise it's a fixed stride if the stride is 0 + * for all inner dimensions of the reduction double loop + */ + else { + NpyIter_AxisData *axisdata = axisdata0; + npy_intp idim, + reduce_outerdim = NBF_REDUCE_OUTERDIM(data); + for (idim = 0; idim < reduce_outerdim; ++idim) { + if (NAD_STRIDES(axisdata)[iiter] != 0) { + break; + } + NIT_ADVANCE_AXISDATA(axisdata, 1); + } + /* If all the strides were 0, the stride won't change */ + if (idim == reduce_outerdim) { + out_strides[iiter] = stride; + } 
+ else { + out_strides[iiter] = NPY_MAX_INTP; + } + } } /* * Inner loop contiguous array means its stride won't change when @@ -2507,7 +2536,7 @@ NpyIter_GetInnerFixedStrideArray(NpyIter *iter, npy_intp *out_strides) } else { /* If there's no buffering, the strides are always fixed */ - memcpy(out_strides, NAD_STRIDES(axisdata), niter*NPY_SIZEOF_INTP); + memcpy(out_strides, NAD_STRIDES(axisdata0), niter*NPY_SIZEOF_INTP); } } @@ -5955,6 +5984,15 @@ NpyIter_DebugPrint(NpyIter *iter) for (iiter = 0; iiter < niter; ++iiter) printf("%d ", (int)NBF_STRIDES(bufferdata)[iiter]); printf("\n"); + /* Print the fixed strides when there's no inner loop */ + if (itflags&NPY_ITFLAG_NOINNER) { + npy_intp fixedstrides[NPY_MAXDIMS]; + printf("| Fixed Strides: "); + NpyIter_GetInnerFixedStrideArray(iter, fixedstrides); + for (iiter = 0; iiter < niter; ++iiter) + printf("%d ", (int)fixedstrides[iiter]); + printf("\n"); + } printf("| Ptrs: "); for (iiter = 0; iiter < niter; ++iiter) printf("%p ", (void *)NBF_PTRS(bufferdata)[iiter]); From 5c5d026aafd9336c3461834150c407405c13593f Mon Sep 17 00:00:00 2001 From: Mark Wiebe Date: Tue, 1 Feb 2011 17:43:16 -0800 Subject: [PATCH 7/7] ENH: einsum: Write specialized unbuffered loops for several cases Also converted the used inner loops to SSE2, to bring einsum fairly close to numpy.dot's performance for matrix-vector and matrix-matrix multiplication. 
--- numpy/core/code_generators/numpy_api.py | 22 +- numpy/core/src/multiarray/einsum.c.src | 375 ++++++++++++++++++- numpy/core/src/multiarray/new_iterator.c.src | 134 ++++--- numpy/core/tests/test_new_iterator.py | 3 +- 4 files changed, 473 insertions(+), 61 deletions(-) diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py index 47c292ac89dd..a4aa169461c4 100644 --- a/numpy/core/code_generators/numpy_api.py +++ b/numpy/core/code_generators/numpy_api.py @@ -298,17 +298,19 @@ 'NpyIter_GetInnerFixedStrideArray': 262, 'NpyIter_RemoveAxis': 263, 'NpyIter_GetAxisStrideArray': 264, + 'NpyIter_RequiresBuffering': 265, + 'NpyIter_GetInitialDataPtrArray': 266, # - 'PyArray_CastingConverter': 265, - 'PyArray_CountNonzero': 266, - 'PyArray_PromoteTypes': 267, - 'PyArray_MinScalarType': 268, - 'PyArray_ResultType': 269, - 'PyArray_CanCastArrayTo': 270, - 'PyArray_CanCastTypeTo': 271, - 'PyArray_EinsteinSum': 272, - 'PyArray_FillWithZero': 273, - 'PyArray_NewLikeArray': 274, + 'PyArray_CastingConverter': 267, + 'PyArray_CountNonzero': 268, + 'PyArray_PromoteTypes': 269, + 'PyArray_MinScalarType': 270, + 'PyArray_ResultType': 271, + 'PyArray_CanCastArrayTo': 272, + 'PyArray_CanCastTypeTo': 273, + 'PyArray_EinsteinSum': 274, + 'PyArray_FillWithZero': 275, + 'PyArray_NewLikeArray': 276, } ufunc_types_api = { diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 488c9c721011..4799c0f443f8 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -26,11 +26,10 @@ #endif /* - * TODO: Only SSE for float32 is implemented in the loops, - * no SSE2 for float64 + * TODO: Only some SSE2 for float64 is implemented. 
*/ #ifdef __SSE2__ -#define EINSUM_USE_SSE2 0 +#define EINSUM_USE_SSE2 1 #else #define EINSUM_USE_SSE2 0 #endif @@ -87,6 +86,10 @@ typedef enum { * 0*5, * 0,1,0,0, * 0*3# + * #float64 = 0*5, + * 0*5, + * 0,0,1,0, + * 0*3# */ /**begin repeat1 @@ -348,6 +351,8 @@ static void #if EINSUM_USE_SSE1 && @float32@ __m128 a, b, value0_sse; +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, b, value0_sse; #endif NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_stride0_contig_outcontig_two (%d)\n", @@ -389,7 +394,40 @@ finish_after_unrolled_loop: } /* Finish off the loop */ - goto finish_after_unrolled_loop; + if (count > 0) { + goto finish_after_unrolled_loop; + } + else { + return; + } + } +#elif EINSUM_USE_SSE2 && @float64@ + value0_sse = _mm_set1_pd(value0); + + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(value0_sse, _mm_load_pd(data1+@i@)); + b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); + _mm_store_pd(data_out+@i@, b); +/**end repeat2**/ + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + if (count > 0) { + goto finish_after_unrolled_loop; + } + else { + return; + } } #endif @@ -405,6 +443,14 @@ finish_after_unrolled_loop: b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); _mm_storeu_ps(data_out+@i@, b); /**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(value0_sse, _mm_loadu_pd(data1+@i@)); + b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); + _mm_storeu_pd(data_out+@i@, b); +/**end repeat2**/ #else /**begin repeat2 * #i = 0, 1, 2, 3, 4, 5, 6, 7# @@ -419,7 +465,9 @@ finish_after_unrolled_loop: } /* Finish off the loop */ - goto finish_after_unrolled_loop; + if (count > 0) { + goto finish_after_unrolled_loop; + } } static void @@ -516,6 +564,8 @@ static void #if EINSUM_USE_SSE1 && @float32@ __m128 a, 
accum_sse = _mm_setzero_ps(); +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, accum_sse = _mm_setzero_pd(); #endif NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_contig_outstride0_two (%d)\n", @@ -556,14 +606,41 @@ finish_after_unrolled_loop: data1 += 8; } -#if EINSUM_USE_SSE1 && @float32@ /* Add the four SSE values and put in accum */ a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); accum_sse = _mm_add_ps(a, accum_sse); a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); accum_sse = _mm_add_ps(a, accum_sse); _mm_store_ss(&accum, accum_sse); -#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@)); + accum_sse = _mm_add_pd(accum_sse, a); +/**end repeat2**/ + data0 += 8; + data1 += 8; + } + + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); /* Finish off the loop */ goto finish_after_unrolled_loop; @@ -585,6 +662,17 @@ finish_after_unrolled_loop: a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@)); accum_sse = _mm_add_ps(accum_sse, a); /**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. 
+ */ + a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@)); + accum_sse = _mm_add_pd(accum_sse, a); +/**end repeat2**/ #else /**begin repeat2 * #i = 0, 1, 2, 3, 4, 5, 6, 7# @@ -603,6 +691,11 @@ finish_after_unrolled_loop: a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); accum_sse = _mm_add_ps(a, accum_sse); _mm_store_ss(&accum, accum_sse); +#elif EINSUM_USE_SSE2 && @float64@ + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); #endif /* Finish off the loop */ @@ -2101,6 +2194,222 @@ prepare_op_axes(int ndim, int iop, char *labels, npy_intp *axes, return 1; } +static int +unbuffered_loop_nop1_ndim2(NpyIter *iter) +{ + npy_intp coord, shape[2], strides[2][2]; + char *ptrs[2][2], *ptr; + sum_of_products_fn sop; + +#if NPY_EINSUM_DBG_TRACING + NpyIter_DebugPrint(iter); +#endif + NPY_EINSUM_DBG_PRINTF("running hand-coded 1-op 2-dim loop\n"); + + NpyIter_GetShape(iter, shape); + memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0), + 2*sizeof(npy_intp)); + memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1), + 2*sizeof(npy_intp)); + memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter), + 2*sizeof(char *)); + memcpy(ptrs[1], ptrs[0], 2*sizeof(char*)); + + sop = get_sum_of_products_function(1, + NpyIter_GetDescrArray(iter)[0]->type_num, + NpyIter_GetDescrArray(iter)[0]->elsize, + strides[0]); + + if (sop == NULL) { + PyErr_SetString(PyExc_TypeError, + "invalid data type for einsum"); + return -1; + } + + /* + * Since the iterator wasn't tracking coordinates, the + * loop provided by the iterator is in Fortran-order. 
+ */ + for (coord = shape[1]; coord > 0; --coord) { + sop(1, ptrs[0], strides[0], shape[0]); + + ptr = ptrs[1][0] + strides[1][0]; + ptrs[0][0] = ptrs[1][0] = ptr; + ptr = ptrs[1][1] + strides[1][1]; + ptrs[0][1] = ptrs[1][1] = ptr; + } + + return 0; +} + +static int +unbuffered_loop_nop1_ndim3(NpyIter *iter) +{ + npy_intp coords[2], shape[3], strides[3][2]; + char *ptrs[3][2], *ptr; + sum_of_products_fn sop; + +#if NPY_EINSUM_DBG_TRACING + NpyIter_DebugPrint(iter); +#endif + NPY_EINSUM_DBG_PRINTF("running hand-coded 1-op 3-dim loop\n"); + + NpyIter_GetShape(iter, shape); + memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0), + 2*sizeof(npy_intp)); + memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1), + 2*sizeof(npy_intp)); + memcpy(strides[2], NpyIter_GetAxisStrideArray(iter, 2), + 2*sizeof(npy_intp)); + memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter), + 2*sizeof(char *)); + memcpy(ptrs[1], ptrs[0], 2*sizeof(char*)); + memcpy(ptrs[2], ptrs[0], 2*sizeof(char*)); + + sop = get_sum_of_products_function(1, + NpyIter_GetDescrArray(iter)[0]->type_num, + NpyIter_GetDescrArray(iter)[0]->elsize, + strides[0]); + + if (sop == NULL) { + PyErr_SetString(PyExc_TypeError, + "invalid data type for einsum"); + return -1; + } + + /* + * Since the iterator wasn't tracking coordinates, the + * loop provided by the iterator is in Fortran-order. 
+ */ + for (coords[1] = shape[2]; coords[1] > 0; --coords[1]) { + for (coords[0] = shape[1]; coords[0] > 0; --coords[0]) { + sop(1, ptrs[0], strides[0], shape[0]); + + ptr = ptrs[1][0] + strides[1][0]; + ptrs[0][0] = ptrs[1][0] = ptr; + ptr = ptrs[1][1] + strides[1][1]; + ptrs[0][1] = ptrs[1][1] = ptr; + } + ptr = ptrs[2][0] + strides[2][0]; + ptrs[0][0] = ptrs[1][0] = ptrs[2][0] = ptr; + ptr = ptrs[2][1] + strides[2][1]; + ptrs[0][1] = ptrs[1][1] = ptrs[2][1] = ptr; + } + + return 0; +} + +static int +unbuffered_loop_nop2_ndim2(NpyIter *iter) +{ + npy_intp coord, shape[2], strides[2][3]; + char *ptrs[2][3], *ptr; + sum_of_products_fn sop; + +#if NPY_EINSUM_DBG_TRACING + NpyIter_DebugPrint(iter); +#endif + NPY_EINSUM_DBG_PRINTF("running hand-coded 2-op 2-dim loop\n"); + + NpyIter_GetShape(iter, shape); + memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0), + 3*sizeof(npy_intp)); + memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1), + 3*sizeof(npy_intp)); + memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter), + 3*sizeof(char *)); + memcpy(ptrs[1], ptrs[0], 3*sizeof(char*)); + + sop = get_sum_of_products_function(2, + NpyIter_GetDescrArray(iter)[0]->type_num, + NpyIter_GetDescrArray(iter)[0]->elsize, + strides[0]); + + if (sop == NULL) { + PyErr_SetString(PyExc_TypeError, + "invalid data type for einsum"); + return -1; + } + + /* + * Since the iterator wasn't tracking coordinates, the + * loop provided by the iterator is in Fortran-order. 
+ */ + for (coord = shape[1]; coord > 0; --coord) { + sop(2, ptrs[0], strides[0], shape[0]); + + ptr = ptrs[1][0] + strides[1][0]; + ptrs[0][0] = ptrs[1][0] = ptr; + ptr = ptrs[1][1] + strides[1][1]; + ptrs[0][1] = ptrs[1][1] = ptr; + ptr = ptrs[1][2] + strides[1][2]; + ptrs[0][2] = ptrs[1][2] = ptr; + } + + return 0; +} + +static int +unbuffered_loop_nop2_ndim3(NpyIter *iter) +{ + npy_intp coords[2], shape[3], strides[3][3]; + char *ptrs[3][3], *ptr; + sum_of_products_fn sop; + +#if NPY_EINSUM_DBG_TRACING + NpyIter_DebugPrint(iter); +#endif + NPY_EINSUM_DBG_PRINTF("running hand-coded 2-op 3-dim loop\n"); + + NpyIter_GetShape(iter, shape); + memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0), + 3*sizeof(npy_intp)); + memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1), + 3*sizeof(npy_intp)); + memcpy(strides[2], NpyIter_GetAxisStrideArray(iter, 2), + 3*sizeof(npy_intp)); + memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter), + 3*sizeof(char *)); + memcpy(ptrs[1], ptrs[0], 3*sizeof(char*)); + memcpy(ptrs[2], ptrs[0], 3*sizeof(char*)); + + sop = get_sum_of_products_function(2, + NpyIter_GetDescrArray(iter)[0]->type_num, + NpyIter_GetDescrArray(iter)[0]->elsize, + strides[0]); + + if (sop == NULL) { + PyErr_SetString(PyExc_TypeError, + "invalid data type for einsum"); + return -1; + } + + /* + * Since the iterator wasn't tracking coordinates, the + * loop provided by the iterator is in Fortran-order. 
+ */ + for (coords[1] = shape[2]; coords[1] > 0; --coords[1]) { + for (coords[0] = shape[1]; coords[0] > 0; --coords[0]) { + sop(2, ptrs[0], strides[0], shape[0]); + + ptr = ptrs[1][0] + strides[1][0]; + ptrs[0][0] = ptrs[1][0] = ptr; + ptr = ptrs[1][1] + strides[1][1]; + ptrs[0][1] = ptrs[1][1] = ptr; + ptr = ptrs[1][2] + strides[1][2]; + ptrs[0][2] = ptrs[1][2] = ptr; + } + ptr = ptrs[2][0] + strides[2][0]; + ptrs[0][0] = ptrs[1][0] = ptrs[2][0] = ptr; + ptr = ptrs[2][1] + strides[2][1]; + ptrs[0][1] = ptrs[1][1] = ptrs[2][1] = ptr; + ptr = ptrs[2][2] + strides[2][2]; + ptrs[0][2] = ptrs[1][2] = ptrs[2][2] = ptr; + } + + return 0; +} + /*NUMPY_API * This function provides summation of array elements according to @@ -2435,6 +2744,57 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, ret = NpyIter_GetOperandArray(iter)[nop]; Py_INCREF(ret); PyArray_FillWithZero(ret); + + + /***************************/ + /* + * Acceleration for some specific loop structures. Note + * that with axis coalescing, inputs with more dimensions can + * be reduced to fit into these patterns. 
+ */ + if (!NpyIter_RequiresBuffering(iter)) { + npy_intp ndim = NpyIter_GetNDim(iter); + switch (nop) { + case 1: + if (ndim == 2) { + if (unbuffered_loop_nop1_ndim2(iter) < 0) { + Py_DECREF(ret); + ret = NULL; + goto fail; + } + goto finish; + } + else if (ndim == 3) { + if (unbuffered_loop_nop1_ndim3(iter) < 0) { + Py_DECREF(ret); + ret = NULL; + goto fail; + } + goto finish; + } + break; + case 2: + if (ndim == 2) { + if (unbuffered_loop_nop2_ndim2(iter) < 0) { + Py_DECREF(ret); + ret = NULL; + goto fail; + } + goto finish; + } + else if (ndim == 3) { + if (unbuffered_loop_nop2_ndim3(iter) < 0) { + Py_DECREF(ret); + ret = NULL; + goto fail; + } + goto finish; + } + break; + } + } + /***************************/ + if (NpyIter_Reset(iter, NULL) != NPY_SUCCEED) { Py_DECREF(ret); goto fail; @@ -2497,6 +2857,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, } } +finish: NpyIter_Deallocate(iter); for (iop = 0; iop < nop; ++iop) { Py_DECREF(op[iop]); diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src index 4e3971313322..3c74b30edf7b 100644 --- a/numpy/core/src/multiarray/new_iterator.c.src +++ b/numpy/core/src/multiarray/new_iterator.c.src @@ -2058,7 +2058,7 @@ NpyIter_GetGetCoords(NpyIter *iter, char **errmsg) /*NUMPY_API * Whether the buffer allocation is being delayed */ -NPY_NO_EXPORT int +NPY_NO_EXPORT npy_bool NpyIter_HasDelayedBufAlloc(NpyIter *iter) { return (NIT_ITFLAGS(iter)&NPY_ITFLAG_DELAYBUF) != 0; @@ -2067,7 +2067,7 @@ NpyIter_HasDelayedBufAlloc(NpyIter *iter) /*NUMPY_API * Whether the iterator handles the inner loop */ -NPY_NO_EXPORT int +NPY_NO_EXPORT npy_bool NpyIter_HasInnerLoop(NpyIter *iter) { return (NIT_ITFLAGS(iter)&NPY_ITFLAG_NOINNER) == 0; @@ -2076,7 +2076,7 @@ NpyIter_HasInnerLoop(NpyIter *iter) /*NUMPY_API * Whether the iterator is tracking coordinates */ -NPY_NO_EXPORT int +NPY_NO_EXPORT npy_bool NpyIter_HasCoords(NpyIter *iter) { return (NIT_ITFLAGS(iter)&NPY_ITFLAG_HASCOORDS) != 
0; @@ -2085,18 +2085,46 @@ NpyIter_HasCoords(NpyIter *iter) /*NUMPY_API * Whether the iterator is tracking an index */ -NPY_NO_EXPORT int +NPY_NO_EXPORT npy_bool NpyIter_HasIndex(NpyIter *iter) { return (NIT_ITFLAGS(iter)&NPY_ITFLAG_HASINDEX) != 0; } +/*NUMPY_API + * Whether the iteration requires buffering, i.e. could not be done without it. + */ +NPY_NO_EXPORT npy_bool +NpyIter_RequiresBuffering(NpyIter *iter) +{ + npy_uint32 itflags = NIT_ITFLAGS(iter); + /*npy_intp ndim = NIT_NDIM(iter);*/ + npy_intp iiter, niter = NIT_NITER(iter); + + char *op_itflags; + + if (!(itflags&NPY_ITFLAG_BUFFER)) { + return 0; + } + + op_itflags = NIT_OPITFLAGS(iter); + + /* If any operand requires a cast, buffering is mandatory */ + for (iiter = 0; iiter < niter; ++iiter) { + if (op_itflags[iiter]&NPY_OP_ITFLAG_CAST) { + return 1; + } + } + + return 0; +} + /*NUMPY_API * Whether the iteration loop, and in particular the iternext() * function, needs API access. If this is true, the GIL must * be retained while iterating. */ -NPY_NO_EXPORT int +NPY_NO_EXPORT npy_bool NpyIter_IterationNeedsAPI(NpyIter *iter) { return (NIT_ITFLAGS(iter)&NPY_ITFLAG_NEEDSAPI) != 0; @@ -2179,7 +2207,9 @@ NpyIter_GetIterIndexRange(NpyIter *iter, } /*NUMPY_API - * Gets the broadcast shape (if coords are enabled) + * Gets the broadcast shape if coords are enabled, otherwise + * gets the shape of the iteration as Fortran-order (fastest-changing + * coordinate first) */ NPY_NO_EXPORT int NpyIter_GetShape(NpyIter *iter, npy_intp *outshape) @@ -2192,23 +2222,27 @@ NpyIter_GetShape(NpyIter *iter, npy_intp *outshape) NpyIter_AxisData *axisdata; char *perm; - if (!(itflags&NPY_ITFLAG_HASCOORDS)) { - PyErr_SetString(PyExc_ValueError, - "Cannot get the shape of an iterator " - "without coordinates requested in the constructor"); - return NPY_FAIL; - } - - perm = NIT_PERM(iter); axisdata = NIT_AXISDATA(iter); sizeof_axisdata = NIT_AXISDATA_SIZEOF(itflags, ndim, niter); - for(idim = 0; idim < ndim; ++idim, NIT_ADVANCE_AXISDATA(axisdata, 1)) 
{ - char p = perm[idim]; - if (p < 0) { - outshape[ndim+p] = NAD_SHAPE(axisdata); + + if (itflags&NPY_ITFLAG_HASCOORDS) { + perm = NIT_PERM(iter); + for(idim = 0; idim < ndim; ++idim) { + char p = perm[idim]; + if (p < 0) { + outshape[ndim+p] = NAD_SHAPE(axisdata); + } + else { + outshape[ndim-p-1] = NAD_SHAPE(axisdata); + } + + NIT_ADVANCE_AXISDATA(axisdata, 1); } - else { - outshape[ndim-p-1] = NAD_SHAPE(axisdata); + } + else { + for(idim = 0; idim < ndim; ++idim) { + outshape[idim] = NAD_SHAPE(axisdata); + NIT_ADVANCE_AXISDATA(axisdata, 1); } } @@ -2237,6 +2271,28 @@ NpyIter_GetDataPtrArray(NpyIter *iter) } } +/*NUMPY_API + * Get the array of data pointers (1 per object being iterated), + * directly into the arrays (never pointing to a buffer), for starting + * unbuffered iteration. This always returns the addresses for the + * iterator position as reset to iterator index 0. + * + * These pointers are different from the pointers accepted by + * NpyIter_ResetBasePointers, because the direction along some + * axes may have been reversed, requiring base offsets. + * + * This function may be safely called without holding the Python GIL. + */ +NPY_NO_EXPORT char ** +NpyIter_GetInitialDataPtrArray(NpyIter *iter) +{ + /*npy_uint32 itflags = NIT_ITFLAGS(iter);*/ + /*npy_intp ndim = NIT_NDIM(iter);*/ + npy_intp niter = NIT_NITER(iter); + + return NIT_RESETDATAPTR(iter); +} + /*NUMPY_API * Get the array of data type pointers (1 per object being iterated) */ @@ -2403,9 +2459,10 @@ NpyIter_GetInnerStrideArray(NpyIter *iter) } /*NUMPY_API - * Gets the array of strides for the specified axis. Requires - * that the iterator be tracking coordinates, and that buffering - * is not enabled. + * Gets the array of strides for the specified axis. + * If the iterator is tracking coordinates, gets the strides + * for the axis specified, otherwise gets the strides for + * the iteration axis as Fortran order (fastest-changing axis first). * * Returns NULL if an error occurs. 
*/ @@ -2420,33 +2477,26 @@ NpyIter_GetAxisStrideArray(NpyIter *iter, npy_intp axis) NpyIter_AxisData *axisdata = NIT_AXISDATA(iter); npy_intp sizeof_axisdata = NIT_AXISDATA_SIZEOF(itflags, ndim, niter); - if (!(itflags&NPY_ITFLAG_HASCOORDS)) { - PyErr_SetString(PyExc_RuntimeError, - "Iterator GetAxisStrideArray may only be called " - "if coordinates are being tracked"); - return NULL; - } - else if (itflags&NPY_ITFLAG_BUFFER) { - PyErr_SetString(PyExc_RuntimeError, - "Iterator GetAxisStrideArray may not be called on " - "a buffered iterator"); - return NULL; - } - else if (axis < 0 || axis >= ndim) { + if (axis < 0 || axis >= ndim) { PyErr_SetString(PyExc_ValueError, "axis out of bounds in iterator GetStrideAxisArray"); return NULL; } - /* Reverse axis, since the iterator treats them that way */ - axis = ndim-1-axis; + if (itflags&NPY_ITFLAG_HASCOORDS) { + /* Reverse axis, since the iterator treats them that way */ + axis = ndim-1-axis; - /* First find the axis in question */ - for (idim = 0; idim < ndim; ++idim, NIT_ADVANCE_AXISDATA(axisdata, 1)) { - if (perm[idim] == axis || -1-perm[idim] == axis) { - return NAD_STRIDES(axisdata); + /* First find the axis in question */ + for (idim = 0; idim < ndim; ++idim, NIT_ADVANCE_AXISDATA(axisdata, 1)) { + if (perm[idim] == axis || -1-perm[idim] == axis) { + return NAD_STRIDES(axisdata); + } } } + else { + return NAD_STRIDES(NIT_INDEX_AXISDATA(axisdata, axis)); + } PyErr_SetString(PyExc_RuntimeError, "internal error in iterator perm"); diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py index 04218ed3f595..f115cb75e8c2 100644 --- a/numpy/core/tests/test_new_iterator.py +++ b/numpy/core/tests/test_new_iterator.py @@ -643,10 +643,9 @@ def test_iter_flags_errors(): assert_raises(ValueError, newiter, a, [], [['writeonly']]) assert_raises(ValueError, newiter, a, [], [['readwrite']]) a.flags.writeable = True - # Coords and shape available only with the coords flag + # Coords available only 
with the coords flag i = newiter(arange(6), [], [['readonly']]) assert_raises(ValueError, lambda i:i.coords, i) - assert_raises(ValueError, lambda i:i.shape, i) # Index available only with an index flag assert_raises(ValueError, lambda i:i.index, i) # GotoCoords and GotoIndex incompatible with buffering or no_inner