@@ -35,12 +35,13 @@ power_of_ten(int n)
35
35
}
36
36
37
37
NPY_NO_EXPORT PyObject *
38
- _PyArray_ArgMaxWithKeepdims (PyArrayObject * op ,
39
- int axis , PyArrayObject * out , int keepdims )
38
+ _PyArray_ArgMinMaxCommon (PyArrayObject * op ,
39
+ int axis , PyArrayObject * out , int keepdims ,
40
+ npy_bool is_argmax )
40
41
{
41
42
PyArrayObject * ap = NULL , * rp = NULL ;
42
- PyArray_ArgFunc * arg_func ;
43
- char * ip ;
43
+ PyArray_ArgFunc * arg_func = NULL ;
44
+ char * ip , * func_name ;
44
45
npy_intp * rptr ;
45
46
npy_intp i , n , m ;
46
47
int elsize ;
@@ -115,7 +116,14 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
115
116
}
116
117
}
117
118
118
- arg_func = PyArray_DESCR (ap )-> f -> argmax ;
119
+ if (is_argmax ) {
120
+ func_name = "argmax" ;
121
+ arg_func = PyArray_DESCR (ap )-> f -> argmax ;
122
+ }
123
+ else {
124
+ func_name = "argmin" ;
125
+ arg_func = PyArray_DESCR (ap )-> f -> argmin ;
126
+ }
119
127
if (arg_func == NULL ) {
120
128
PyErr_SetString (PyExc_TypeError ,
121
129
"data type not ordered" );
@@ -124,8 +132,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
124
132
elsize = PyArray_DESCR (ap )-> elsize ;
125
133
m = PyArray_DIMS (ap )[PyArray_NDIM (ap )- 1 ];
126
134
if (m == 0 ) {
127
- PyErr_SetString (PyExc_ValueError ,
128
- "attempt to get argmax of an empty sequence" );
135
+ PyErr_Format (PyExc_ValueError ,
136
+ "attempt to get %s of an empty sequence" ,
137
+ func_name );
129
138
goto fail ;
130
139
}
131
140
@@ -142,8 +151,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
142
151
if ((PyArray_NDIM (out ) != out_ndim ) ||
143
152
!PyArray_CompareLists (PyArray_DIMS (out ), out_shape ,
144
153
out_ndim )) {
145
- PyErr_SetString (PyExc_ValueError ,
146
- "output array does not match result of np.argmax." );
154
+ PyErr_Format (PyExc_ValueError ,
155
+ "output array does not match result of np.%s." ,
156
+ func_name );
147
157
goto fail ;
148
158
}
149
159
rp = (PyArrayObject * )PyArray_FromArray (out ,
@@ -179,155 +189,27 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
179
189
return NULL ;
180
190
}
181
191
192
+ NPY_NO_EXPORT PyObject *
193
+ _PyArray_ArgMaxWithKeepdims (PyArrayObject * op ,
194
+ int axis , PyArrayObject * out , int keepdims )
195
+ {
196
+ return _PyArray_ArgMinMaxCommon (op , axis , out , keepdims , 1 );
197
+ }
198
+
182
199
/*NUMPY_API
183
200
* ArgMax
184
201
*/
185
202
NPY_NO_EXPORT PyObject *
186
203
PyArray_ArgMax (PyArrayObject * op , int axis , PyArrayObject * out )
187
204
{
188
- return _PyArray_ArgMaxWithKeepdims (op , axis , out , 0 );
205
+ return _PyArray_ArgMinMaxCommon (op , axis , out , 0 , 1 );
189
206
}
190
207
191
208
NPY_NO_EXPORT PyObject *
192
209
_PyArray_ArgMinWithKeepdims (PyArrayObject * op ,
193
210
int axis , PyArrayObject * out , int keepdims )
194
211
{
195
- PyArrayObject * ap = NULL , * rp = NULL ;
196
- PyArray_ArgFunc * arg_func ;
197
- char * ip ;
198
- npy_intp * rptr ;
199
- npy_intp i , n , m ;
200
- int elsize ;
201
- // Keep a copy because axis changes via call to PyArray_CheckAxis
202
- int axis_copy = axis ;
203
- npy_intp _shape_buf [NPY_MAXDIMS ];
204
- npy_intp * out_shape ;
205
- // Keep the number of dimensions and shape of
206
- // original array. Helps when `keepdims` is True.
207
- npy_intp * original_op_shape = PyArray_DIMS (op );
208
- int out_ndim = PyArray_NDIM (op );
209
- NPY_BEGIN_THREADS_DEF ;
210
-
211
- if ((ap = (PyArrayObject * )PyArray_CheckAxis (op , & axis , 0 )) == NULL ) {
212
- return NULL ;
213
- }
214
- /*
215
- * We need to permute the array so that axis is placed at the end.
216
- * And all other dimensions are shifted left.
217
- */
218
- if (axis != PyArray_NDIM (ap )- 1 ) {
219
- PyArray_Dims newaxes ;
220
- npy_intp dims [NPY_MAXDIMS ];
221
- int i ;
222
-
223
- newaxes .ptr = dims ;
224
- newaxes .len = PyArray_NDIM (ap );
225
- for (i = 0 ; i < axis ; i ++ ) {
226
- dims [i ] = i ;
227
- }
228
- for (i = axis ; i < PyArray_NDIM (ap ) - 1 ; i ++ ) {
229
- dims [i ] = i + 1 ;
230
- }
231
- dims [PyArray_NDIM (ap ) - 1 ] = axis ;
232
- op = (PyArrayObject * )PyArray_Transpose (ap , & newaxes );
233
- Py_DECREF (ap );
234
- if (op == NULL ) {
235
- return NULL ;
236
- }
237
- }
238
- else {
239
- op = ap ;
240
- }
241
-
242
- /* Will get native-byte order contiguous copy. */
243
- ap = (PyArrayObject * )PyArray_ContiguousFromAny ((PyObject * )op ,
244
- PyArray_DESCR (op )-> type_num , 1 , 0 );
245
- Py_DECREF (op );
246
- if (ap == NULL ) {
247
- return NULL ;
248
- }
249
-
250
- // Decides the shape of the output array.
251
- if (!keepdims ) {
252
- out_ndim = PyArray_NDIM (ap ) - 1 ;
253
- out_shape = PyArray_DIMS (ap );
254
- } else {
255
- out_shape = _shape_buf ;
256
- if (axis_copy == NPY_MAXDIMS ) {
257
- for (int i = 0 ; i < out_ndim ; i ++ ) {
258
- out_shape [i ] = 1 ;
259
- }
260
- } else {
261
- /*
262
- * While `ap` may be transposed, we can ignore this for `out` because the
263
- * transpose only reorders the size 1 `axis` (not changing memory layout).
264
- */
265
- memcpy (out_shape , original_op_shape , out_ndim * sizeof (npy_intp ));
266
- out_shape [axis ] = 1 ;
267
- }
268
- }
269
-
270
- arg_func = PyArray_DESCR (ap )-> f -> argmin ;
271
- if (arg_func == NULL ) {
272
- PyErr_SetString (PyExc_TypeError ,
273
- "data type not ordered" );
274
- goto fail ;
275
- }
276
- elsize = PyArray_DESCR (ap )-> elsize ;
277
- m = PyArray_DIMS (ap )[PyArray_NDIM (ap )- 1 ];
278
- if (m == 0 ) {
279
- PyErr_SetString (PyExc_ValueError ,
280
- "attempt to get argmin of an empty sequence" );
281
- goto fail ;
282
- }
283
-
284
- if (!out ) {
285
- rp = (PyArrayObject * )PyArray_NewFromDescr (
286
- Py_TYPE (ap ), PyArray_DescrFromType (NPY_INTP ),
287
- out_ndim , out_shape , NULL , NULL ,
288
- 0 , (PyObject * )ap );
289
- if (rp == NULL ) {
290
- goto fail ;
291
- }
292
- }
293
- else {
294
- if ((PyArray_NDIM (out ) != out_ndim ) ||
295
- !PyArray_CompareLists (PyArray_DIMS (out ), out_shape , out_ndim )) {
296
- PyErr_SetString (PyExc_ValueError ,
297
- "output array does not match result of np.argmin." );
298
- goto fail ;
299
- }
300
- rp = (PyArrayObject * )PyArray_FromArray (out ,
301
- PyArray_DescrFromType (NPY_INTP ),
302
- NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY );
303
- if (rp == NULL ) {
304
- goto fail ;
305
- }
306
- }
307
-
308
- NPY_BEGIN_THREADS_DESCR (PyArray_DESCR (ap ));
309
- n = PyArray_SIZE (ap )/m ;
310
- rptr = (npy_intp * )PyArray_DATA (rp );
311
- for (ip = PyArray_DATA (ap ), i = 0 ; i < n ; i ++ , ip += elsize * m ) {
312
- arg_func (ip , m , rptr , ap );
313
- rptr += 1 ;
314
- }
315
- NPY_END_THREADS_DESCR (PyArray_DESCR (ap ));
316
-
317
- Py_DECREF (ap );
318
- /* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */
319
- if (out != NULL && out != rp ) {
320
- PyArray_ResolveWritebackIfCopy (rp );
321
- Py_DECREF (rp );
322
- rp = out ;
323
- Py_INCREF (rp );
324
- }
325
- return (PyObject * )rp ;
326
-
327
- fail :
328
- Py_DECREF (ap );
329
- Py_XDECREF (rp );
330
- return NULL ;
212
+ return _PyArray_ArgMinMaxCommon (op , axis , out , keepdims , 0 );
331
213
}
332
214
333
215
/*NUMPY_API
@@ -336,7 +218,7 @@ _PyArray_ArgMinWithKeepdims(PyArrayObject *op,
336
218
NPY_NO_EXPORT PyObject *
337
219
PyArray_ArgMin (PyArrayObject * op , int axis , PyArrayObject * out )
338
220
{
339
- return _PyArray_ArgMinWithKeepdims (op , axis , out , 0 );
221
+ return _PyArray_ArgMinMaxCommon (op , axis , out , 0 , 0 );
340
222
}
341
223
342
224
/*NUMPY_API
0 commit comments