8000 implemented newaxis support for zero rank arrays · numpy/numpy@b32744e · GitHub
[go: up one dir, main page]

Skip to content

Commit b32744e

Browse files
author
sasha
committed
implemented newaxis support for zero rank arrays
1 parent 8b2fae0 commit b32744e

File tree

2 files changed

+79
-7
lines changed

2 files changed

+79
-7
lines changed

numpy/core/src/arrayobject.c

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,59 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
17621762

17631763
static PyObject *iter_subscript(PyArrayIterObject *, PyObject *);
17641764

1765+
int
1766+
count_new_axes_0d(PyObject *tuple)
1767+
{
1768+
int i, argument_count;
1769+
int ellipsis_count = 0;
1770+
int newaxis_count = 0;
1771+
argument_count = PyTuple_GET_SIZE(tuple);
1772+
for (i = 0; i < argument_count; ++i) {
1773+
PyObject *arg = PyTuple_GET_ITEM(tuple, i);
1774+
ellipsis_count += (arg == Py_Ellipsis);
1775+
newaxis_count += (arg == Py_None);
1776+
}
1777+
if (newaxis_count + ellipsis_count != argument_count) {
1778+
PyErr_SetString(PyExc_IndexError,
1779+
"0-d arrays can use a single ()"
1780+
" or a list of ellipses and newaxis"
1781+
" as an index");
1782+
return -1;
1783+
}
1784+
if (newaxis_count > MAX_DIMS) {
1785+
PyErr_SetString(PyExc_IndexError,
1786+
"too many dimensions");
1787+
return -1;
1788+
}
1789+
return newaxis_count;
1790+
1791+
}
1792+
static PyObject *
1793+
add_new_axes_0d(PyArrayObject *arr, int newaxis_count)
1794+
{
1795+
PyArrayObject *other;
1796+
intp dimensions[MAX_DIMS], strides[MAX_DIMS];
1797+
int i;
1798+
for (i = 0; i < newaxis_count; ++i) {
1799+
dimensions[i] = strides[i] = 1;
1800+
}
1801+
Py_INCREF(arr->descr);
1802+
if ((other = (PyArrayObject *)
1803+
PyArray_NewFromDescr(arr->ob_type, arr->descr,
1804+
newaxis_count, dimensions,
1805+
strides, arr->data,
1806+
arr->flags,
1807+
(PyObject *)arr)) == NULL)
1808+
return NULL;
1809+
1810+
other->base = (PyObject *)arr;
1811+
Py_INCREF(arr);
1812+
1813+
other ->flags &= ~OWNDATA;
1814+
1815+
return (PyObject *)other;
1816+
}
1817+
17651818
static PyObject *
17661819
array_subscript(PyArrayObject *self, PyObject *op)
17671820
{
@@ -1796,9 +1849,17 @@ array_subscript(PyArrayObject *self, PyObject *op)
17961849
return NULL;
17971850
}
17981851
if (self->nd == 0) {
1799-
if (op == Py_Ellipsis || (PyTuple_Check(op) && \
1800-
0 == PyTuple_GET_SIZE(op)))
1852+
if (op == Py_Ellipsis)
18011853
return PyArray_ToScalar(self->data, self);
1854+
if (op == Py_None)
1855+
return add_new_axes_0d(self, 1);
1856+
if (PyTuple_Check(op)) {
1857+
if (0 == PyTuple_GET_SIZE(op))
1858+
return PyArray_ToScalar(self->data, self);
1859+
if ((nd = count_new_axes_0d(op)) == -1)
1860+
return NULL;
1861+
return add_new_axes_0d(self, nd);
1862+
}
18021863
PyErr_SetString(PyExc_IndexError,
18031864
"0-d arrays can't be indexed.");
18041865
return NULL;

numpy/core/tests/test_multiarray.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,13 @@ def setUp(self):
7575

7676
def check_ellipsis_subscript(self):
7777
a,b = self.d
78-
7978
self.failUnlessEqual(a[...], 0)
8079
self.failUnlessEqual(b[...].item(), 'x')
8180
self.failUnless(type(a[...]) is a.dtype)
8281
self.failUnless(type(b[...]) is b.dtype)
8382

8483
def check_empty_subscript(self):
8584
a,b = self.d
86-
8785
self.failUnlessEqual(a[()], 0)
8886
self.failUnlessEqual(b[()].item(), 'x')
8987
self.failUnless(type(a[()]) is a.dtype)
@@ -98,15 +96,13 @@ def check_invalid_subscript(self):
9896

9997
def check_ellipsis_subscript_assignment(self):
10098
a,b = self.d
101-
10299
a[...] = 42
103100
self.failUnlessEqual(a, 42)
104101
b[...] = ''
105102
self.failUnlessEqual(b.item(), '')
106103

107104
def check_empty_subscript_assignment(self):
108105
a,b = self.d
109-
110106
a[()] = 42
111107
self.failUnlessEqual(a, 42)
112108
b[()] = ''
@@ -120,7 +116,22 @@ def assign(x, i, v):
120116
self.failUnlessRaises(IndexError, assign, b, 0, '')
121117
self.failUnlessRaises(TypeError, assign, a, (), '')
122118

123-
119+
def check_newaxis(self):
120+
a,b = self.d
121+
self.failUnlessEqual(a[newaxis].shape, (1,))
122+
self.failUnlessEqual(a[..., newaxis].shape, (1,))
123+
self.failUnlessEqual(a[newaxis, ...].shape, (1,))
124+
self.failUnlessEqual(a[..., newaxis].shape, (1,))
125+
self.failUnlessEqual(a[newaxis, ..., newaxis].shape, (1,1))
126+
self.failUnlessEqual(a[..., newaxis, newaxis].shape, (1,1))
127+
self.failUnlessEqual(a[newaxis, newaxis, ...].shape, (1,1))
128+
self.failUnlessEqual(a[(newaxis,)*10].shape, (1,)*10)
129+
130+
def check_invalid_newaxis(self):
131+
a,b = self.d
132+
def subscript(x, i): x[i]
133+
self.failUnlessRaises(IndexError, subscript, a, (newaxis, 0))
134+
self.failUnlessRaises(IndexError, subscript, a, (newaxis,)*50)
124135

125136
if __name__ == "__main__":
126137
ScipyTest('numpy.core.multiarray').run()

0 commit comments

Comments
 (0)
0