8000 MAINT: factored out _PyArray_ArgMinMaxCommon by czgdp1807 · Pull Request #19440 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

MAINT: factored out _PyArray_ArgMinMaxCommon #19440

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 12, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
factored out _PyArray_ArgMinMaxCommon
  • Loading branch information
czgdp1807 committed Jul 9, 2021
commit fa8bbbbea9755ecf55ecd382ea3fbb5d8ac6f851
162 changes: 20 additions & 142 deletions numpy/core/src/multiarray/calculation.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ power_of_ten(int n)
}

NPY_NO_EXPORT PyObject *
_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
_PyArray_ArgMinMaxCommon(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims,
npy_bool is_argmax)
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
PyArray_ArgFunc* arg_func = NULL;
char *ip;
npy_intp *rptr;
npy_intp i, n, m;
Expand Down Expand Up @@ -115,7 +116,12 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
}
}

arg_func = PyArray_DESCR(ap)->f->argmax;
if (is_argmax) {
arg_func = PyArray_DESCR(ap)->f->argmax;
}
else {
arg_func = PyArray_DESCR(ap)->f->argmin;
}
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
"data type not ordered");
Expand Down Expand Up @@ -179,155 +185,27 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
return NULL;
}

NPY_NO_EXPORT PyObject*
_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 1);
}

/*NUMPY_API
* ArgMax
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
{
return _PyArray_ArgMaxWithKeepdims(op, axis, out, 0);
return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 1);
}

NPY_NO_EXPORT PyObject *
_PyArray_ArgMinWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
char *ip;
npy_intp *rptr;
npy_intp i, n, m;
int elsize;
// Keep a copy because axis changes via call to PyArray_CheckAxis
int axis_copy = axis;
npy_intp _shape_buf[NPY_MAXDIMS];
npy_intp *out_shape;
// Keep the number of dimensions and shape of
// original array. Helps when `keepdims` is True.
npy_intp* original_op_shape = PyArray_DIMS(op);
int out_ndim = PyArray_NDIM(op);
NPY_BEGIN_THREADS_DEF;

if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
return NULL;
}
/*
* We need to permute the array so that axis is placed at the end.
* And all other dimensions are shifted left.
*/
if (axis != PyArray_NDIM(ap)-1) {
PyArray_Dims newaxes;
npy_intp dims[NPY_MAXDIMS];
int i;

newaxes.ptr = dims;
newaxes.len = PyArray_NDIM(ap);
for (i = 0; i < axis; i++) {
dims[i] = i;
}
for (i = axis; i < PyArray_NDIM(ap) - 1; i++) {
dims[i] = i + 1;
}
dims[PyArray_NDIM(ap) - 1] = axis;
op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
Py_DECREF(ap);
if (op == NULL) {
return NULL;
}
}
else {
op = ap;
}

/* Will get native-byte order contiguous copy. */
ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
PyArray_DESCR(op)->type_num, 1, 0);
Py_DECREF(op);
if (ap == NULL) {
return NULL;
}

// Decides the shape of the output array.
if (!keepdims) {
out_ndim = PyArray_NDIM(ap) - 1;
out_shape = PyArray_DIMS(ap);
} else {
out_shape = _shape_buf;
if (axis_copy == NPY_MAXDIMS) {
for (int i = 0; i < out_ndim; i++) {
out_shape[i] = 1;
}
} else {
/*
* While `ap` may be transposed, we can ignore this for `out` because the
* transpose only reorders the size 1 `axis` (not changing memory layout).
*/
memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
out_shape[axis] = 1;
}
}

arg_func = PyArray_DESCR(ap)->f->argmin;
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
"data type not ordered");
goto fail;
}
elsize = PyArray_DESCR(ap)->elsize;
m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
if (m == 0) {
PyErr_SetString(PyExc_ValueError,
"attempt to get argmin of an empty sequence");
goto fail;
}

if (!out) {
rp = (PyArrayObject *)PyArray_NewFromDescr(
Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
out_ndim, out_shape, NULL, NULL,
0, (PyObject *)ap);
if (rp == NULL) {
goto fail;
}
}
else {
if ((PyArray_NDIM(out) != out_ndim) ||
!PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) {
PyErr_SetString(PyExc_ValueError,
"output array does not match result of np.argmin.");
goto fail;
}
rp = (PyArrayObject *)PyArray_FromArray(out,
PyArray_DescrFromType(NPY_INTP),
NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY);
if (rp == NULL) {
goto fail;
}
}

NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap));
n = PyArray_SIZE(ap)/m;
rptr = (npy_intp *)PyArray_DATA(rp);
for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) {
arg_func(ip, m, rptr, ap);
rptr += 1;
}
NPY_END_THREADS_DESCR(PyArray_DESCR(ap));

Py_DECREF(ap);
/* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */
if (out != NULL && out != rp) {
PyArray_ResolveWritebackIfCopy(rp);
Py_DECREF(rp);
rp = out;
Py_INCREF(rp);
}
return (PyObject *)rp;

fail:
Py_DECREF(ap);
Py_XDECREF(rp);
return NULL;
return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 0);
}

/*NUMPY_API
Expand All @@ -336,7 +214,7 @@ _PyArray_ArgMinWithKeepdims(PyArrayObject *op,
NPY_NO_EXPORT PyObject *
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
return _PyArray_ArgMinWithKeepdims(op, axis, out, 0);
return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 0);
}

/*NUMPY_API
Expand Down
0