8000 Update ufunc override to work properly with ufunc methods. by cowlicks · Pull Request #4626 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

Update ufunc override to work properly with ufunc methods. #4626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 192 additions & 29 deletions numpy/core/src/private/ufunc_override.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,154 @@
#include <npy_config.h>
#include "numpy/arrayobject.h"
#include "common.h"
#include <string.h>
#include "numpy/ufuncobject.h"

static void
normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds,
int nin)
{
/* ufunc.__call__(*args, **kwds) */
int nargs = PyTuple_GET_SIZE(args);
PyObject *obj;

*normal_args = PyTuple_GetSlice(args, 0, nin);

/* If we have more args than nin, they must be the output variables.*/
if (nargs > nin) {
if ((nargs - nin) == 1) {
obj = PyTuple_GET_ITEM(args, nargs - 1);
PyDict_SetItemString(*normal_kwds, "out", obj);
}
else {
obj = PyTuple_GetSlice(args, nin, nargs);
PyDict_SetItemString(*normal_kwds, "out", obj);
}
}
}

static void
normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.reduce(a[, axis, dtype, out, keepdims]) */
int nargs = PyTuple_GET_SIZE(args);
int i;
PyObject *obj;

for (i = 0; i < nargs; i++) {
obj = PyTuple_GET_ITEM(args, i);
if (i == 0) {
*normal_args = PyTuple_GetSlice(args, 0, 1);
}
else if (i == 1) {
/* axis */
PyDict_SetItemString(*normal_kwds, "axis", obj);
}
else if (i == 2) {
/* dtype */
PyDict_SetItemString(*normal_kwds, "dtype", obj);
}
else if (i == 3) {
/* out */
PyDict_SetItemString(*normal_kwds, "out", obj);
}
else {
/* keepdims */
PyDict_SetItemString(*normal_kwds, "keepdims", obj);
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to Py_DECREF(obj) here (and other places)? I'm not sure if PyDict_SetItemString has its own reference to obj now.

}
return;
}

static void
normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.accumulate(a[, axis, dtype, out]) */
int nargs = PyTuple_GET_SIZE(args);
int i;
PyObject *obj;

for (i = 0; i < nargs; i++) {
obj = PyTuple_GET_ITEM(args, i);
if (i == 0) {
*normal_args = PyTuple_GetSlice(args, 0, 1);
}
else if (i == 1) {
/* axis */
PyDict_SetItemString(*normal_kwds, "axis", obj);
}
else if (i == 2) {
/* dtype */
PyDict_SetItemString(*normal_kwds, "dtype", obj);
}
else {
/* out */
PyDict_SetItemString(*normal_kwds, "out", obj);
}
}
return;
}

static void
normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.reduceat(a, indicies[, axis, dtype, out]) */
int i;
int nargs = PyTuple_GET_SIZE(args);
PyObject *obj;

for (i = 0; i < nargs; i++) {
obj = PyTuple_GET_ITEM(args, i);
if (i == 0) {
/* a and indicies */
*normal_args = PyTuple_GetSlice(args, 0, 2);
}
else if (i == 1) {
/* Handled above, when i == 0. */
continue;
}
else if (i == 2) {
/* axis */
PyDict_SetItemString(*normal_kwds, "axis", obj);
}
else if (i == 3) {
/* dtype */
PyDict_SetItemString(*normal_kwds, "dtype", obj);
}
else {
/* out */
PyDict_SetItemString(*normal_kwds, "out", obj);
}
}
return;
}

static void
normalize_outer_args(PyUFuncObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.outer(A, B)
* This has no kwds so we don't need to do any kwd stuff.
*/
*normal_args = PyTuple_GetSlice(args, 0, 2);
return;
}

static void
normalize_at_args(PyUFuncObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.at(a, indices[, b]) */
int nargs = PyTuple_GET_SIZE(args);

*normal_args = PyTuple_GetSlice(args, 0, nargs);
return;
}

/*
* Check a set of args for the `__numpy_ufunc__` method. If more than one of
* the input arguments implements `__numpy_ufunc__`, they are tried in the
Expand All @@ -18,7 +164,7 @@
*/
static int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
PyObject *args, PyObject *kwds,
PyObject **result,
int nin)
{
Expand All @@ -36,23 +182,23 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */
PyObject *normal_kwds = NULL;

PyObject *with_override[NPY_MAXARGS];
PyObject *with_override[NPY_MAXARGS];

/* Pos of each override in args */
int with_override_pos[NPY_MAXARGS];

/*
/*
* Check inputs
*/
if (!PyTuple_Check(args)) {
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"Internal Numpy error: call to PyUFunc_CheckOverride "
"with non-tuple");
goto fail;
}

if (PyTuple_GET_SIZE(args) > NPY_MAXARGS) {
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"Internal Numpy error: too many arguments in call "
"to PyUFunc_CheckOverride");
goto fail;
Expand Down Expand Up @@ -81,14 +227,15 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
return 0;
}

/*
* Normalize ufunc arguments.
*/
normal_args = PyTuple_GetSlice(args, 0, nin);
if (normal_args == NULL) {
method_name = PyUString_FromString(method);
if (method_name == NULL) {
goto fail;
}

/*
* Normalize ufunc arguments.
*/

/* Build new kwds */
if (kwds && PyDict_CheckExact(kwds)) {
normal_kwds = PyDict_Copy(kwds);
Expand All @@ -100,21 +247,38 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
goto fail;
}

/* If we have more args than nin, they must be the output variables.*/
if (nargs > nin) {
if ((nargs - nin) == 1) {
obj = PyTuple_GET_ITEM(args, nargs - 1);
PyDict_SetItemString(normal_kwds, "out", obj);
}
else {
obj = PyTuple_GetSlice(args, nin, nargs);
PyDict_SetItemString(normal_kwds, "out", obj);
Py_DECREF(obj);
}
/* decide what to do based on the method. */
/* ufunc.__call__ */
if (strcmp(method, "__call__") == 0) {
normalize___call___args(ufunc, args, &normal_args, &normal_kwds, nin);
}

method_name = PyUString_FromString(method);
if (method_name == NULL) {
/* ufunc.reduce */
else if (strcmp(method, "reduce") == 0) {
normalize_reduce_args(ufunc, args, &normal_args, &normal_kwds);
}

/* ufunc.accumulate */
else if (strcmp(method, "accumulate") == 0) {
normalize_accumulate_args(ufunc, args, &normal_args, &normal_kwds);
}

/* ufunc.reduceat */
else if (strcmp(method, "reduceat") == 0) {
normalize_reduceat_args(ufunc, args, &normal_args, &normal_kwds);
}

/* ufunc.outer */
else if (strcmp(method, "outer") == 0) {
normalize_outer_args(ufunc, args, &normal_args, &normal_kwds);
}

/* ufunc.at */
else if (strcmp(method, "at") == 0) {
normalize_at_args(ufunc, args, &normal_args, &normal_kwds);
}

if (normal_args == NULL) {
goto fail;
}

6D40 Expand Down Expand Up @@ -144,7 +308,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
for (j = i + 1; j < noa; j++) {
other_obj = with_override[j];
if (PyObject_Type(other_obj) != PyObject_Type(obj) &&
PyObject_IsInstance(other_obj,
PyObject_IsInstance(other_obj,
PyObject_Type(override_obj))) {
override_obj = NULL;
break;
Expand All @@ -161,27 +325,27 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
/* Check if there is a method left to call */
if (!override_obj) {
/* No acceptable override found. */
PyErr_SetString(PyExc_TypeError,
PyErr_SetString(PyExc_TypeError,
"__numpy_ufunc__ not implemented for this type.");
goto fail;
}

/* Call the override */
numpy_ufunc = PyObject_GetAttrString(override_obj,
numpy_ufunc = PyObject_GetAttrString(override_obj,
"__numpy_ufunc__");
if (numpy_ufunc == NULL) {
goto fail;
}

override_args = Py_BuildValue("OOiO", ufunc, method_name,
override_args = Py_BuildValue("OOiO", ufunc, method_name,
override_pos, normal_args);
if (override_args == NULL) {
Py_DECREF(numpy_ufunc);
goto fail;
}

*result = PyObject_Call(numpy_ufunc, override_args, normal_kwds);

Py_DECREF(numpy_ufunc);
Py_DECREF(override_args);

Expand Down Expand Up @@ -212,5 +376,4 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
Py_XDECREF(normal_kwds);
return 1;
}

#endif
Loading
0