8000 API,ENH: Change definition of complex sign and use it in copysign · numpy/numpy@7a1836b · GitHub
[go: up one dir, main page]

Skip to content

Commit 7a1836b

Browse files
committed
API,ENH: Change definition of complex sign and use it in copysign
Following the API Array standard, the complex sign is now calculated as ``z / |z|`` (instead of the rather less logical case where the sign of the real part was taken, unless the real part was zero, in which case the sign of the imaginary part was returned). Like for real numbers, zero is returned if ``z==0``. With this, it has become possible to extend ``np.copysign(x1, x2)`` to complex numbers, since it can now generally return ``|x1| * sign(x2)`` with the sign as defined above (with no special treatment for zero).
1 parent 0032ede commit 7a1836b

File tree

8 files changed

+181
-43
lines changed

8 files changed

+181
-43
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Change in how complex sign is calculated
2+
----------------------------------------
3+
Following the API Array standard, the complex sign is now calculated as
4+
``z / |z|`` (instead of the rather less logical case where the sign of
5+
the real part was taken, unless the real part was zero, in which case
6+
the sign of the imaginary part was returned). Like for real numbers,
7+
zero is returned if ``z==0``.
8+
9+
With this, it has become possible to extend ``np.copysign(x1, x2)`` to
10+
complex numbers, since it can now generally return ``|x1| * sign(x2)``
11+
with the sign as defined above (with no special treatment for zero).

numpy/_core/code_generators/generate_umath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,7 @@ def english_upper(s):
10691069
Ufunc(2, 1, None,
10701070
docstrings.get('numpy._core.umath.copysign'),
10711071
None,
1072-
TD(flts),
1072+
TD(inexact),
10731073
),
10741074
'nextafter':
10751075
Ufunc(2, 1, None,

numpy/_core/code_generators/ufunc_docstrings.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3540,10 +3540,11 @@ def add_newdoc(place, name, doc):
35403540
The `sign` function returns ``-1 if x < 0, 0 if x==0, 1 if x > 0``. nan
35413541
is returned for nan inputs.
35423542
3543-
For complex inputs, the `sign` function returns
3544-
``sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j``.
3543+
For complex inputs, the `sign` function returns ``x / abs(x)``, the
3544+
generalization of the above (and ``0 if x==0``).
35453545
3546-
complex(nan, 0) is returned for complex nan inputs.
3546+
.. versionchanged:: 2.0.0
3547+
Definition of complex sign changed to follow the Array API standard.
35473548
35483549
Parameters
35493550
----------
@@ -3569,8 +3570,8 @@ def add_newdoc(place, name, doc):
35693570
array([-1., 1.])
35703571
>>> np.sign(0)
35713572
0
3572-
>>> np.sign(5-2j)
3573-
(1+0j)
3573+
>>> np.sign([3-4j, 8j])
3574+
array([0.6-0.8j, 0. +1.j ])
35743575
35753576
""")
35763577

@@ -3603,7 +3604,14 @@ def add_newdoc(place, name, doc):
36033604
"""
36043605
Change the sign of x1 to that of x2, element-wise.
36053606
3606-
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
3607+
If `x2` is a scalar, its sign will be copied to all elements of `x1`,
3608+
i.e., the function returns ``abs(x1) * sign(x2)``, with the sign defined
3609+
generally as ``x2 / abs(x2)``. For the special case of ``x2 == 0``, for
3610+
real numbers the floating point sign of the zero is taken, while for
3611+
complex numbers the result is undefined (``nan+nanj``).
3612+
3613+
.. versionadded:: 2.0.0
3614+
Support complex numbers using the sign Array API definition of sign.
36073615
36083616
Parameters
36093617
----------

numpy/_core/function_base.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -430,26 +430,17 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
430430
start = start.astype(dt, copy=True)
431431
stop = stop.astype(dt, copy=True)
432432

433-
out_sign = _nx.ones(_nx.broadcast(start, stop).shape, dt)
434-
# Avoid negligible real or imaginary parts in output by rotating to
435-
# positive real, calculating, then undoing rotation
436-
if _nx.issubdtype(dt, _nx.complexfloating):
437-
all_imag = (start.real == 0.) & (stop.real == 0.)
438-
if _nx.any(all_imag):
439-
start[all_imag] = start[all_imag].imag
440-
stop[all_imag] = stop[all_imag].imag
441-
out_sign[all_imag] = 1j
442-
443-
both_negative = (_nx.sign(start) == -1) & (_nx.sign(stop) == -1)
444-
if _nx.any(both_negative):
445-
_nx.negative(start, out=start, where=both_negative)
446-
_nx.negative(stop, out=stop, where=both_negative)
447-
_nx.negative(out_sign, out=out_sign, where=both_negative)
433+
# Allow negative real values and ensure a consistent result for complex
434+
# (including avoiding negligible real or imaginary parts in output) by
435+
# rotating start to positive real, calculating, then undoing rotation.
436+
out_sign = _nx.sign(start)
437+
start /= out_sign
438+
stop = stop / out_sign
448439

449440
log_start = _nx.log10(start)
450441
log_stop = _nx.log10(stop)
451442
result = logspace(log_start, log_stop, num=num,
452-
endpoint=endpoint, base=10.0, dtype=dtype)
443+
endpoint=endpoint, base=10.0, dtype=dt)
453444

454445
# Make sure the endpoints match the start and stop arguments. This is
455446
# necessary because np.exp(np.log(x)) is not necessarily equal to x.
@@ -458,7 +449,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
458449
if num > 1 and endpoint:
459450
result[-1] = stop
460451

461-
result = out_sign * result
452+
result *= out_sign
462453

463454
if axis != 0:
464455
result = _nx.moveaxis(result, 0, axis)

numpy/_core/src/umath/loops.c.src

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,20 +2229,74 @@ NPY_NO_EXPORT void
22292229
}
22302230
}
22312231

2232+
/*
2233+
* Define calculation of complex sign, z / |z|, which is used in both sign and copysign,
2234+
* with the only difference that for z=0 one should get 0 and NaN, respectively.
2235+
*/
2236+
#define COMPLEX_SIGN(IN, OUT, SIGN_ZERO) \
2237+
const @ftype@ in_r = IN[0]; \
2238+
const @ftype@ in_i = IN[1]; \
2239+
const @ftype@ in_abs = npy_hypot@c@(in_r, in_i); \
2240+
if (NPY_UNLIKELY(npy_isnan(in_abs))) { \
2241+
OUT[0] = NPY_NAN@C@; \
2242+
OUT[1] = NPY_NAN@C@; \
2243+
} \
2244+
else if (NPY_UNLIKELY(npy_isinf(in_abs))) { \
2245+
if (npy_isinf(in_r)) { \
2246+
if (npy_isinf(in_i)) { \
2247+
OUT[0] = NPY_NAN@C@; \
2248+
OUT[1] = NPY_NAN@C@; \
2249+
} \
2250+
else { \
2251+
OUT[0] = in_r > 0 ? 1.: -1.; \
2252+
OUT[1] = 0.; \
2253+
} \
2254+
} \
2255+
else { \
2256+
OUT[0] = 0.; \
2257+
OUT[1] = in_i > 0 ? 1.: -1.; \
2258+
} \
2259+
} \
2260+
else if (NPY_UNLIKELY(in_abs == 0)) { \
2261+
OUT[0] = SIGN_ZERO; \
2262+
OUT[1] = SIGN_ZERO; \
2263+
} \
2264+
else{ \
2265+
OUT[0] = in_r / in_abs; \
2266+
OUT[1] = in_i / in_abs; \
2267+
}
2268+
22322269
NPY_NO_EXPORT void
22332270
@TYPE@_sign(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
22342271
{
2235-
/* fixme: sign of nan is currently 0 */
22362272
UNARY_LOOP {
2237-
const @ftype@ in1r = ((@ftype@ *)ip1)[0];
2238-
const @ftype@ in1i = ((@ftype@ *)ip1)[1];
2239-
((@ftype@ *)op1)[0] = CGT(in1r, in1i, 0.0, 0.0) ? 1 :
2240-
(CLT(in1r, in1i, 0.0, 0.0) ? -1 :
2241-
(CEQ(in1r, in1i, 0.0, 0.0) ? 0 : NPY_NAN@C@));
2242-
((@ftype@ *)op1)[1] = 0;
2273+
COMPLEX_SIGN(((@ftype@ *)ip1), ((@ftype@ *)op1), 0.@C@)
2274+
}
2275+
}
2276+
2277+
NPY_NO_EXPORT void
2278+
@TYPE@_copysign(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
2279+
{
2280+
BINARY_LOOP {
2281+
const @ftype@ in1_abs = npy_hypot@c@(((@ftype@ *)ip1)[0], ((@ftype@ *)ip1)[1]);
2282+
/*
2283+
* Generically, inf * exp(phase) should given NaN,
2284+
* since one cannot recover the phase.
2285+
*/
2286+
if (NPY_UNLIKELY(npy_isinf(in1_abs))) {
2287+
((@ftype@ *)op1)[0] = NPY_NAN@C@;
2288+
((@ftype@ *)op1)[1] = NPY_NAN@C@;
2289+
}
2290+
else {
2291+
COMPLEX_SIGN(((@ftype@ *)ip2), ((@ftype@ *)op1), NPY_NAN@C@)
2292+
((@ftype@ *)op1)[0] *= in1_abs;
2293+
((@ftype@ *)op1)[1] *= in1_abs;
2294+
}
22432295
}
22442296
}
22452297

2298+
#undef COMPLEX_SIGN
2299+
22462300
/**begin repeat1
22472301
* #kind = maximum, minimum#
22482302
* #OP = CGE, CLE#

numpy/_core/src/umath/loops.h.src

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,9 @@ C@TYPE@__arg(char **args, npy_intp const *dimensions, npy_intp const *steps, voi
666666
NPY_NO_EXPORT void
667667
C@TYPE@_sign(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
668668

669+
NPY_NO_EXPORT void
670+
C@TYPE@_copysign(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
671+
669672
/**begin repeat1
670673
* #kind = maximum, minimum#
671674
* #OP = CGE, CLE#

numpy/_core/tests/test_regression.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,9 +1189,10 @@ def test_unaligned_unicode_access(self):
11891189
def test_sign_for_complex_nan(self):
11901190
# Ticket 794.
11911191
with np.errstate(invalid='ignore'):
1192-
C = np.array([-np.inf, -2+1j, 0, 2-1j, np.inf, np.nan])
1192+
C = np.array([-np.inf, -3+4j, 0, 4-3j, np.inf, np.nan])
11931193
have = np.sign(C)
1194-
want = np.array([-1+0j, -1+0j, 0+0j, 1+0j, 1+0j, np.nan])
1194+
want = np.array([-1+0j, -0.6+0.8j, 0+0j, 0.8-0.6j, 1+0j,
1195+
complex(np.nan, np.nan)])
11951196
assert_equal(have, want)
11961197

11971198
def test_for_equal_names(self):
@@ -1481,7 +1482,7 @@ def test_buffer_hashlib(self):
14811482

14821483
x = np.array([1, 2, 3], dtype=np.dtype('<i4'))
14831484
assert_equal(
1484-
sha256(x).hexdigest(),
1485+
sha256(x).hexdigest(),
14851486
'4636993d3e1da4e9d6b8f87b79e8f7c6d018580d52661950eabc3845c5897a4d'
14861487
)
14871488

@@ -1941,7 +1942,7 @@ def test_pickle_py2_scalar_latin1_hack(self):
19411942
'invalid'),
19421943

19431944
# different 8-bit code point in KOI8-R vs latin1
1944-
(np.bytes_(b'\x9c'),
1945+
(np.bytes_(b'\x9c'),
19451946
b"cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n(S'S1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'|'\np5\nNNNI1\nI1\nI0\ntp6\nbS'\\x9c'\np7\ntp8\nRp9\n.", # noqa
19461947
'different'),
19471948
]

numpy/_core/tests/test_umath.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,6 +2825,28 @@ def test_sign(self):
28252825
assert_equal(res, tgt)
28262826
assert_equal(out, tgt)
28272827

2828+
def test_sign_complex(self):
2829+
a = np.array([
2830+
np.inf, -np.inf, complex(0, np.inf), complex(0, -np.inf),
2831+
complex(np.inf, np.inf), complex(np.inf, -np.inf), # nan
2832+
np.nan, complex(0, np.nan), complex(np.nan, np.nan), # nan
2833+
0.0, # 0.
2834+
3.0, -3.0, -2j, 3.0+4.0j, -8.0+6.0j
2835+
])
2836+
out = np.zeros(a.shape, a.dtype)
2837+
tgt = np.array([
2838+
1., -1., 1j, -1j,
2839+
] + [complex(np.nan, np.nan)] * 5 + [
2840+
0.0,
2841+
1.0, -1.0, -1j, 0.6+0.8j, -0.8+0.6j])
2842+
2843+
with np.errstate(invalid='ignore'):
2844+
res = ncu.sign(a)
2845+
assert_equal(res, tgt)
2846+
res = ncu.sign(a, out)
2847+
assert_(res is out)
2848+
assert_equal(res, tgt)
2849+
28282850
def test_sign_dtype_object(self):
28292851
# In reference to github issue #6229
28302852

@@ -2843,6 +2865,62 @@ def test_nan():
28432865

28442866
assert_raises(TypeError, test_nan)
28452867

2868+
2869+
class TestCopySign:
2870+
def test_copysign(self):
2871+
assert_(np.copysign(1, -1) == -1)
2872+
with np.errstate(divide="ignore"):
2873+
assert_(1 / np.copysign(0, -1) < 0)
2874+
assert_(1 / np.copysign(0, 1) > 0)
2875+
assert_(np.signbit(np.copysign(np.nan, -1)))
2876+
assert_(not np.signbit(np.copysign(np.nan, 1)))
2877+
2878+
def test_copysign_complex(self):
2879+
a = np.array([
2880+
np.inf, -np.inf, complex(0, np.inf), complex(0, -np.inf),
2881+
complex(np.inf, np.inf), complex(np.inf, -np.inf), # nan
2882+
np.nan, complex(0, np.nan), complex(np.nan, np.nan), # nan
2883+
0.0, # 0.
2884+
3.0, -3.0, -2j, 3.0+4.0j, -8.0+6.0j
2885+
])
2886+
with np.errstate(invalid='ignore'):
2887+
t1 = np.copysign(a, a.conj())
2888+
isnan1 = np.isnan(t1)
2889+
assert_array_equal(np.isnan(t1.real), np.isnan(t1.imag))
2890+
# If a is not finite, multiplying with a phasor should give NaN.
2891+
# If a == 0, the sign is not defined, and thus the phasor is NaN.
2892+
assert_array_equal(isnan1, ~np.isfinite(a) | (a == 0))
2893+
# Check the rest is correct.
2894+
assert_array_equal(t1[~isnan1], a[~isnan1].conj())
2895+
2896+
# Now check propagating signs of the finite numbers to all others.
2897+
b = a[np.isfinite(a) & (a != 0)][:, np.newaxis]
2898+
with np.errstate(invalid='ignore'):
2899+
t2 = np.copysign(a, b)
2900+
2901+
isnan2 = np.isnan(t2)
2902+
assert_array_equal(np.isnan(t2.real), np.isnan(t2.imag))
2903+
# Only non-finite a should give NaN.
2904+
assert_array_equal(isnan2[0], ~np.isfinite(a))
2905+
# Which is the same for all b.
2906+
assert np.all(isnan2[0] == isnan2)
2907+
# Should get sign of b, except for a=0, when no sign can be assigned.
2908+
expected_sign = np.sign(np.broadcast_to(b, t2.shape))
2909+
expected_sign[:, a == 0] = 0
2910+
assert_array_almost_equal(np.sign(t2)[~isnan2], expected_sign[~isnan2])
2911+
2912+
# Finally, check propagating the possibly ill-defined sign of a to b
2913+
with np.errstate(invalid='ignore'):
2914+
t3 = np.copysign(b, a)
2915+
2916+
isnan3 = np.isnan(t3)
2917+
assert_array_equal(np.isnan(t3.real), np.isnan(t3.imag))
2918+
# Undefined or zero sign should give NaN (a.real or a.imag inf is OK).
2919+
sign_a = np.sign(np.broadcast_to(a, t2.shape))
2920+
assert_array_equal(isnan3, np.isnan(sign_a) | (a == 0))
2921+
assert_array_almost_equal(np.sign(t3)[~isnan3], sign_a[~isnan3])
2922+
2923+
28462924
class TestMinMax:
28472925
def test_minmax_blocked(self):
28482926
# simd tests on max/min, test all alignments, slow but important
@@ -4426,14 +4504,6 @@ def _check_branch_cut(f, x0, dx, re_sign=1, im_sign=-1, sig_zero_ok=False,
44264504
assert_(np.all(np.absolute(y0[ji].real - ym.real*re_sign) < atol), (y0[ji], ym))
44274505
assert_(np.all(np.absolute(y0[ji].imag - ym.imag*im_sign) < atol), (y0[ji], ym))
44284506

4429-
def test_copysign():
4430-
assert_(np.copysign(1, -1) == -1)
4431-
with np.errstate(divide="ignore"):
4432-
assert_(1 / np.copysign(0, -1) < 0)
4433-
assert_(1 / np.copysign(0, 1) > 0)
4434-
assert_(np.signbit(np.copysign(np.nan, -1)))
4435-
assert_(not np.signbit(np.copysign(np.nan, 1)))
4436-
44374507
def _test_nextafter(t):
44384508
one = t(1)
44394509
two = t(2)

0 commit comments

Comments
 (0)
0