10000 BUG: Fix use and errorchecking of ObjectType use by seberg · Pull Request #22566 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

BUG: Fix use and errorchecking of ObjectType use #22566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions numpy/core/src/multiarray/_multiarray_tests.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,14 @@ test_neighborhood_iterator(PyObject* NPY_UNUSED(self), PyObject* args)
return NULL;
}

typenum = PyArray_ObjectType(x, 0);
typenum = PyArray_ObjectType(x, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(fill, typenum);
if (typenum == NPY_NOTYPE) {
return NULL;
}

ax = (PyArrayObject*)PyArray_FromObject(x, typenum, 1, 10);
if (ax == NULL) {
Expand Down Expand Up @@ -343,7 +349,10 @@ test_neighborhood_iterator_oob(PyObject* NPY_UNUSED(self), PyObject* args)
return NULL;
}

typenum = PyArray_ObjectType(x, 0);
typenum = PyArray_ObjectType(x, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}

ax = (PyArrayObject*)PyArray_FromObject(x, typenum, 1, 10);
if (ax == NULL) {
Expand Down
12 changes: 9 additions & 3 deletions numpy/core/src/multiarray/arraytypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -3803,17 +3803,23 @@ BOOL_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
*((npy_bool *)op) = tmp;
}

/*
* `dot` does not make sense for times, for DATETIME it never worked.
* For timedelta it does/did , but should probably also just be removed.
*/
#define DATETIME_dot NULL

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a madness, but np.result_type(bool, "M8") never worked, and that prevented us (correctly) from even trying to get a datetime dot, vdot, or correlate.

Now it doesn't make any sense for timedelta as well (nor does it really work right, since units are always dropped), but not doing that here. This just "preserves" breaking things reliably, becuase we definitely cannot call the functions if they are not defined ;).

/**begin repeat
*
* #name = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG,
* LONGDOUBLE, DATETIME, TIMEDELTA#
* LONGDOUBLE, TIMEDELTA#
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_longdouble, npy_datetime, npy_timedelta#
* npy_longdouble, npy_timedelta#
* #out = npy_long, npy_ulong, npy_long, npy_ulong, npy_long, npy_ulong,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_longdouble, npy_datetime, npy_timedelta#
* npy_longdouble, npy_timedelta#
*/
static void
@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
Expand Down
40 changes: 33 additions & 7 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -899,11 +899,15 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
int i;
PyObject* ret = NULL;

typenum = PyArray_ObjectType(op1, 0);
if (typenum == NPY_NOTYPE && PyErr_Occurred()) {
typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
if (typenum == NPY_NOTYPE) {
return NULL;
}

typec = PyArray_DescrFromType(typenum);
if (typec == NULL) {
if (!PyErr_Occurred()) {
Expand Down Expand Up @@ -991,11 +995,15 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
PyArray_Descr *typec = NULL;
NPY_BEGIN_THREADS_DEF;

typenum = PyArray_ObjectType(op1, 0);
if (typenum == NPY_NOTYPE && PyErr_Occurred()) {
typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
if (typenum == NPY_NOTYPE) {
return NULL;
}

typec = PyArray_DescrFromType(typenum);
if (typec == NULL) {
if (!PyErr_Occurred()) {
Expand Down Expand Up @@ -1373,8 +1381,14 @@ PyArray_Correlate2(PyObject *op1, PyObject *op2, int mode)
int inverted;
int st;

typenum = PyArray_ObjectType(op1, 0);
typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
if (typenum == NPY_NOTYPE) {
return NULL;
}

typec = PyArray_DescrFromType(typenum);
Py_INCREF(typec);
Expand Down Expand Up @@ -1440,8 +1454,14 @@ PyArray_Correlate(PyObject *op1, PyObject *op2, int mode)
int unused;
PyArray_Descr *typec;

typenum = PyArray_ObjectType(op1, 0);
typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
if (typenum == NPY_NOTYPE) {
return NULL;
}

typec = PyArray_DescrFromType(typenum);
Py_INCREF(typec);
Expand Down Expand Up @@ -2541,8 +2561,14 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
* Conjugating dot product using the BLAS for vectors.
* Flattens both op1 and op2 before dotting.
*/
typenum = PyArray_ObjectType(op1, 0);
typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
if (typenum == NPY_NOTYPE) {
return NULL;
}

type = PyArray_DescrFromType(typenum);
Py_INCREF(type);
Expand Down
24 changes: 18 additions & 6 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,9 +1257,9 @@ def test_subarray_comparison(self):
# The main importance is that it does not return True:
with pytest.raises(TypeError):
x == y

def test_empty_structured_array_comparison(self):
# Check that comparison works on empty arrays with nontrivially
# Check that comparison works on empty arrays with nontrivially
# shaped fields
a = np.zeros(0, [('a', '<f8', (1, 1))])
assert_equal(a, a)
Expand Down Expand Up @@ -2232,7 +2232,7 @@ def assert_c(arr):
assert_c(a.copy('C'))
assert_fortran(a.copy('F'))
assert_c(a.copy('A'))

@pytest.mark.parametrize("dtype", ['O', np.int32, 'i,O'])
def test__deepcopy__(self, dtype):
# Force the entry of NULLs into array
Expand Down Expand Up @@ -2441,7 +2441,7 @@ def test_sort_unicode_kind(self):
np.array([0, 1, np.nan]),
])
def test_searchsorted_floats(self, a):
# test for floats arrays containing nans. Explicitly test
# test for floats arrays containing nans. Explicitly test
# half, single, and double precision floats to verify that
# the NaN-handling is correct.
msg = "Test real (%s) searchsorted with nans, side='l'" % a.dtype
Expand All @@ -2457,7 +2457,7 @@ def test_searchsorted_floats(self, a):
assert_equal(y, 2)

def test_searchsorted_complex(self):
# test for complex arrays containing nans.
# test for complex arrays containing nans.
# The search sorted routines use the compare functions for the
# array type, so this checks if that is consistent with the sort
# order.
Expand All @@ -2479,7 +2479,7 @@ def test_searchsorted_complex(self):
a = np.array([0, 128], dtype='>i4')
b = a.searchsorted(np.array(128, dtype='>i4'))
assert_equal(b, 1, msg)

def test_searchsorted_n_elements(self):
# Check 0 elements
a = np.ones(0)
Expand Down Expand Up @@ -6731,6 +6731,18 @@ def test_huge_vectordot(self, dtype):
res = np.dot(data, data)
assert res == 2**30+100

def test_dtype_discovery_fails(self):
# See gh-14247, error checking was missing for failed dtype discovery
class BadObject(object):
def __array__(self):
raise TypeError("just this tiny mint leaf")

with pytest.raises(TypeError):
np.dot(BadObject(), BadObject())

with pytest.raises(TypeError):
np.dot(3.0, BadObject())


class MatmulCommon:
"""Common tests for '@' operator and numpy.matmul.
Expand Down
0