8000 ENH: Correct identities for logical ufuncs and logaddexp · liwt31/numpy@4c277f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4c277f5

Browse files
eric-wieserliwt31
authored andcommitted
ENH: Correct identities for logical ufuncs and logaddexp
Fixes numpy#7702
1 parent 2b1b5fa commit 4c277f5

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed

numpy/core/code_generators/generate_umath.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import ufunc_docstrings as docstrings
1111
sys.path.pop(0)
1212

13-
Zero = "PyUFunc_Zero"
14-
One = "PyUFunc_One"
15-
None_ = "PyUFunc_None"
16-
AllOnes = "PyUFunc_MinusOne"
17-
ReorderableNone = "PyUFunc_ReorderableNone"
13+
Zero = "PyInt_FromLong(0)"
14+
One = "PyInt_FromLong(1)"
15+
True_ = "(Py_INCREF(Py_True), Py_True)"
16+
False_ = "(Py_INCREF(Py_False), Py_False)"
17+
None_ = object()
18+
AllOnes = "PyInt_FromLong(-1)"
19+
MinusInfinity = 'PyFloat_FromDouble(-NPY_INFINITY)'
20+
ReorderableNone = "(Py_INCREF(Py_None), Py_None)"
1821

1922
# Sentinel value to specify using the full type description in the
2023
# function name
@@ -458,7 +461,7 @@ def english_upper(s):
458461
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
459462
),
460463
'logical_and':
461-
Ufunc(2, 1, One,
464+
Ufunc(2, 1, True_,
462465
docstrings.get('numpy.core.umath.logical_and'),
463466
'PyUFunc_SimpleBinaryComparisonTypeResolver',
464467
TD(nodatetime_or_obj, out='?', simd=[('avx2', ints)]),
@@ -472,14 +475,14 @@ def english_upper(s):
472475
TD(O, f='npy_ObjectLogicalNot'),
473476
),
474477
'logical_or':
475-
Ufunc(2, 1, Zero,
478+
Ufunc(2, 1, False_,
476479
docstrings.get('numpy.core.umath.logical_or'),
477480
'PyUFunc_SimpleBinaryComparisonTypeResolver',
478481
TD(nodatetime_or_obj, out='?', simd=[('avx2', ints)]),
479482
TD(O, f='npy_ObjectLogicalOr'),
480483
),
481484
'logical_xor':
482-
Ufunc(2, 1, Zero,
485+
Ufunc(2, 1, False_,
483486
docstrings.get('numpy.core.umath.logical_xor'),
484487
'PyUFunc_SimpleBinaryComparisonTypeResolver',
485488
TD(nodatetime_or_obj, out='?'),
@@ -514,7 +517,7 @@ def english_upper(s):
514517
TD(O, f='npy_ObjectMin')
515518
),
516519
'logaddexp':
517-
Ufunc(2, 1, None,
520+
Ufunc(2, 1, MinusInfinity,
518521
docstrings.get('numpy.core.umath.logaddexp'),
519522
None,
520523
TD(flts, f="logaddexp", astype={'e':'f'})
@@ -1048,18 +1051,38 @@ def make_ufuncs(funcdict):
10481051
# do not play well with \n
10491052
docstring = '\\n\"\"'.join(docstring.split(r"\n"))
10501053
fmt = textwrap.dedent("""\
1051-
f = PyUFunc_FromFuncAndData(
1054+
identity = {identity_expr};
1055+
if ({has_identity} && identity == NULL) {{
1056+
return -1;
1057+
}}
1058+
f = PyUFunc_FromFuncAndDataAndSignatureAndIdentity(
10521059
{name}_functions, {name}_data, {name}_signatures, {nloops},
10531060
{nin}, {nout}, {identity}, "{name}",
1054-
"{doc}", 0
1061+
"{doc}", 0, NULL, identity
10551062
);
1063+
if ({has_identity}) {{
1064+
Py_DECREF(identity);
1065+
}}
10561066
if (f == NULL) {{
10571067
return -1;
1058-
}}""")
1059-
mlist.append(fmt.format(
1068+
}}
1069+
""")
1070+
args = dict(
10601071
name=name, nloops=len(uf.type_descriptions),
1061-
nin=uf.nin, nout=uf.nout, identity=uf.identity, doc=docstring
1062-
))
1072+
nin=uf.nin, nout=uf.nout,
1073+
has_identity='0' if uf.identity is None_ else '1',
1074+
identity='PyUFunc_IdentityValue',
1075+
identity_expr=uf.identity,
1076+
doc=docstring
1077+
)
1078+
1079+
# Only PyUFunc_None means don't reorder - we pass this using the old
1080+
# argument
1081+
if uf.identity is None_:
1082+
args['identity'] = 'PyUFunc_None'
1083+
args['identity_expr'] = 'NULL'
1084+
1085+
mlist.append(fmt.format(**args))
10631086
if uf.typereso is not None:
10641087
mlist.append(
10651088
r"((PyUFuncObject *)f)->type_resolver = &%s;" % uf.typereso)
@@ -1087,7 +1110,7 @@ def make_code(funcdict, filename):
10871110
10881111
static int
10891112
InitOperators(PyObject *dictionary) {
1090-
PyObject *f;
1113+
PyObject *f, *identity;
10911114
10921115
%s
10931116
%s

numpy/core/tests/test_umath.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,10 @@ def test_nan(self):
685685
assert_(np.isnan(np.logaddexp(0, np.nan)))
686686
assert_(np.isnan(np.logaddexp(np.nan, np.nan)))
687687

688+
def test_reduce(self):
689+
assert_equal(np.logaddexp.identity, -np.inf)
690+
assert_equal(np.logaddexp.reduce([]), -np.inf)
691+
688692

689693
class TestLog1p(object):
690694
def test_log1p(self):

0 commit comments

Comments
 (0)
0