@@ -2207,7 +2207,21 @@ PyArray_Nonzero(PyArrayObject *self)
2207
2207
NpyIter_IterNextFunc * iternext ;
2208
2208
NpyIter_GetMultiIndexFunc * get_multi_index ;
2209
2209
char * * dataptr ;
2210
- int is_empty = 0 ;
2210
+
2211
+ /* Special case - nonzero(zero_d) is nonzero(atleast1d(zero_d)) */
2212
+ if (ndim == 0 ) {
2213
+ static npy_intp const zero_dim_shape [1 ] = {1 };
2214
+ static npy_intp const zero_dim_strides [1 ] = {0 };
2215
+
2216
+ PyArrayObject * self_1d = (PyArrayObject * )PyArray_NewFromDescrAndBase (
2217
+ Py_TYPE (self ), PyArray_DESCR (self ),
2218
+ 1 , zero_dim_shape , zero_dim_strides , PyArray_BYTES (self ),
2219
+ PyArray_FLAGS (self ), (PyObject * )self , (PyObject * )self );
2220
+ if (self_1d == NULL ) {
2221
+ return NULL ;
2222
+ }
2223
+ return PyArray_Nonzero (self_1d );
2224
+ }
2211
2225
2212
2226
/*
2213
2227
* First count the number of non-zeros in 'self'.
@@ -2219,7 +2233,7 @@ PyArray_Nonzero(PyArrayObject *self)
2219
2233
2220
2234
/* Allocate the result as a 2D array */
2221
2235
ret_dims [0 ] = nonzero_count ;
2222
- ret_dims [1 ] = ( ndim == 0 ) ? 1 : ndim ;
2236
+ ret_dims [1 ] = ndim ;
2223
2237
ret = (PyArrayObject * )PyArray_NewFromDescr (
2224
2238
& PyArray_Type , PyArray_DescrFromType (NPY_INTP ),
2225
2239
2 , ret_dims , NULL , NULL ,
@@ -2229,11 +2243,11 @@ PyArray_Nonzero(PyArrayObject *self)
2229
2243
}
2230
2244
2231
2245
/* If it's a one-dimensional result, don't use an iterator */
2232
- if (ndim < = 1 ) {
2246
+ if (ndim = = 1 ) {
2233
2247
npy_intp * multi_index = (npy_intp * )PyArray_DATA (ret );
2234
2248
char * data = PyArray_BYTES (self );
2235
- npy_intp stride = ( ndim == 0 ) ? 0 : PyArray_STRIDE (self , 0 );
2236
- npy_intp count = ( ndim == 0 ) ? 1 : PyArray_DIM (self , 0 );
2249
+ npy_intp stride = PyArray_STRIDE (self , 0 );
2250
+ npy_intp count = PyArray_DIM (self , 0 );
2237
2251
NPY_BEGIN_THREADS_DEF ;
2238
2252
2239
2253
/* nothing to do */
@@ -2351,29 +2365,17 @@ PyArray_Nonzero(PyArrayObject *self)
2351
2365
NpyIter_Deallocate (iter );
2352
2366
2353
2367
finish :
2354
- /* Treat zero-dimensional as shape (1,) */
2355
- if (ndim == 0 ) {
2356
- ndim = 1 ;
2357
- }
2358
-
2359
2368
ret_tuple = PyTuple_New (ndim );
2360
2369
if (ret_tuple == NULL ) {
2361
2370
Py_DECREF (ret );
2362
2371
return NULL ;
2363
2372
}
2364
2373
2365
- for (i = 0 ; i < PyArray_NDIM (ret ); ++ i ) {
2366
- if (PyArray_DIMS (ret )[i ] == 0 ) {
2367
- is_empty = 1 ;
2368
- break ;
2369
- }
2370
- }
2371
-
2372
2374
/* Create views into ret, one for each dimension */
2373
2375
for (i = 0 ; i < ndim ; ++ i ) {
2374
2376
npy_intp stride = ndim * NPY_SIZEOF_INTP ;
2375
2377
/* the result is an empty array, the view must point to valid memory */
2376
- npy_intp data_offset = is_empty ? 0 : i * NPY_SIZEOF_INTP ;
2378
+ npy_intp data_offset = nonzero_count == 0 ? 0 : i * NPY_SIZEOF_INTP ;
2377
2379
2378
2380
PyArrayObject * view = (PyArrayObject * )PyArray_NewFromDescrAndBase (
2379
2381
Py_TYPE (ret ), PyArray_DescrFromType (NPY_INTP ),
0 commit comments