Merge pull request #13739 from eric-wieser/bit_shifts · numpy/numpy@31ffdec · GitHub
[go: up one dir, main page]

Skip to content

Commit 31ffdec

Browse files
authored
Merge pull request #13739 from eric-wieser/bit_shifts
BUG: Don't produce undefined behavior for a << b if b >= bitsof(a)
2 parents 79cb45d + 6cf6ece commit 31ffdec

File tree

7 files changed

+188
-10
lines changed

7 files changed

+188
-10
lines changed

numpy/core/include/numpy/npy_math.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,28 @@ NPY_INPLACE npy_long npy_lcml(npy_long a, npy_long b);
177177
NPY_INPLACE npy_longlong npy_gcdll(npy_longlong a, npy_longlong b);
178178
NPY_INPLACE npy_longlong npy_lcmll(npy_longlong a, npy_longlong b);
179179

180+
NPY_INPLACE npy_ubyte npy_rshiftuhh(npy_ubyte a, npy_ubyte b);
181+
NPY_INPLACE npy_ubyte npy_lshiftuhh(npy_ubyte a, npy_ubyte b);
182+
NPY_INPLACE npy_ushort npy_rshiftuh(npy_ushort a, npy_ushort b);
183+
NPY_INPLACE npy_ushort npy_lshiftuh(npy_ushort a, npy_ushort b);
184+
NPY_INPLACE npy_uint npy_rshiftu(npy_uint a, npy_uint b);
185+
NPY_INPLACE npy_uint npy_lshiftu(npy_uint a, npy_uint b);
186+
NPY_INPLACE npy_ulong npy_rshiftul(npy_ulong a, npy_ulong b);
187+
NPY_INPLACE npy_ulong npy_lshiftul(npy_ulong a, npy_ulong b);
188+
NPY_INPLACE npy_ulonglong npy_rshiftull(npy_ulonglong a, npy_ulonglong b);
189+
NPY_INPLACE npy_ulonglong npy_lshiftull(npy_ulonglong a, npy_ulonglong b);
190+
191+
NPY_INPLACE npy_byte npy_rshifthh(npy_byte a, npy_byte b);
192+
NPY_INPLACE npy_byte npy_lshifthh(npy_byte a, npy_byte b);
193+
NPY_INPLACE npy_short npy_rshifth(npy_short a, npy_short b);
194+
NPY_INPLACE npy_short npy_lshifth(npy_short a, npy_short b);
195+
NPY_INPLACE npy_int npy_rshift(npy_int a, npy_int b);
196+
NPY_INPLACE npy_int npy_lshift(npy_int a, npy_int b);
197+
NPY_INPLACE npy_long npy_rshiftl(npy_long a, npy_long b);
198+
NPY_INPLACE npy_long npy_lshiftl(npy_long a, npy_long b);
199+
NPY_INPLACE npy_longlong npy_rshiftll(npy_longlong a, npy_longlong b);
200+
NPY_INPLACE npy_longlong npy_lshiftll(npy_longlong a, npy_longlong b);
201+
180202
/*
181203
* avx function has a common API for both sin & cos. This enum is used to
182204
* distinguish between the two

numpy/core/setup.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,12 @@ def generate_config_h(ext, build_dir):
463463
rep = check_long_double_representation(config_cmd)
464464
moredefs.append(('HAVE_LDOUBLE_%s' % rep, 1))
465465

466+
if check_for_right_shift_internal_compiler_error(config_cmd):
467+
moredefs.append('NPY_DO_NOT_OPTIMIZE_LONG_right_shift')
468+
moredefs.append('NPY_DO_NOT_OPTIMIZE_ULONG_right_shift')
469+
moredefs.append('NPY_DO_NOT_OPTIMIZE_LONGLONG_right_shift')
470+
moredefs.append('NPY_DO_NOT_OPTIMIZE_ULONGLONG_right_shift')
471+
466472
# Py3K check
467473
if sys.version_info[0] >= 3:
468474
moredefs.append(('NPY_PY3K', 1))

numpy/core/setup_common.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
import copy
77
import binascii
8+
import textwrap
89

910
from numpy.distutils.misc_util import mingw32
1011

@@ -415,3 +416,41 @@ def long_double_representation(lines):
415416
else:
416417
# We never detected the after_sequence
417418
raise ValueError("Could not lock sequences (%s)" % saw)
419+
420+
421+
def check_for_right_shift_internal_compiler_error(cmd):
422+
"""
423+
On our ARM CI, compiling the optimized right-shift loop fails with an internal compiler error.
424+
425+
The failure looks like the following, and can be reproduced on ARM64 GCC 5.4:
426+
427+
<source>: In function 'right_shift':
428+
<source>:4:20: internal compiler error: in expand_shift_1, at expmed.c:2349
429+
ip1[i] = ip1[i] >> in2;
430+
^
431+
Please submit a full bug report,
432+
with preprocessed source if appropriate.
433+
See <http://gcc.gnu.org/bugs.html> for instructions.
434+
Compiler returned: 1
435+
436+
This function returns True if this compiler bug is present, in which case we
437+
need to turn off optimization for the affected right_shift loops.
438+
"""
439+
cmd._check_compiler()
440+
has_optimize = cmd.try_compile(textwrap.dedent("""\
441+
__attribute__((optimize("O3"))) void right_shift() {}
442+
"""), None, None)
443+
if not has_optimize:
444+
return False
445+
446+
no_err = cmd.try_compile(textwrap.dedent("""\
447+
typedef long the_type; /* fails also for unsigned and long long */
448+
__attribute__((optimize("O3"))) void right_shift(the_type in2, the_type *ip1, int n) {
449+
for (int i = 0; i < n; i++) {
450+
if (in2 < (the_type)sizeof(the_type) * 8) {
451+
ip1[i] = ip1[i] >> in2;
452+
}
453+
}
454+
}
455+
"""), None, None)
456+
return not no_err

numpy/core/src/npymath/npy_math_internal.h.src

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,3 +716,44 @@ npy_@func@@c@(@type@ a, @type@ b)
716716
return npy_@func@u@c@(a < 0 ? -a : a, b < 0 ? -b : b);
717717
}
718718
/**end repeat**/
719+
720+
/* Unlike LCM and GCD, we need byte and short variants for the shift operators,
721+
* since the result is dependent on the width of the type
722+
*/
723+
/**begin repeat
724+
*
725+
* #type = byte, short, int, long, longlong#
726+
* #c = hh,h,,l,ll#
727+
*/
728+
/**begin repeat1
729+
*
730+
* #u = u,#
731+
* #is_signed = 0,1#
732+
*/
733+
NPY_INPLACE npy_@u@@type@
734+
npy_lshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
735+
{
736+
if (NPY_LIKELY((size_t)b < sizeof(a) * CHAR_BIT)) {
737+
return a << b;
738+
}
739+
else {
740+
return 0;
741+
}
742+
}
743+
NPY_INPLACE npy_@u@@type@
744+
npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
745+
{
746+
if (NPY_LIKELY((size_t)b < sizeof(a) * CHAR_BIT)) {
747+
return a >> b;
748+
}
749+
#if @is_signed@
750+
else if (a < 0) {
751+
return (npy_@u@@type@)-1; /* preserve the sign bit */
752+
}
753+
#endif
754+
else {
755+
return 0;
756+
}
757+
}
758+
/**end repeat1**/
759+
/**end repeat**/

numpy/core/src/umath/loops.c.src

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ BOOL_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED
699699
* #ftype = npy_float, npy_float, npy_float, npy_float, npy_double, npy_double,
700700
* npy_double, npy_double, npy_double, npy_double#
701701
* #SIGNED = 1, 0, 1, 0, 1, 0, 1, 0, 1, 0#
702+
* #c = hh,uhh,h,uh,,u,l,ul,ll,ull#
702703
*/
703704

704705
#define @TYPE@_floor_divide @TYPE@_divide
@@ -776,16 +777,15 @@ NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
776777

777778
/**begin repeat2
778779
* Arithmetic
779-
* #kind = add, subtract, multiply, bitwise_and, bitwise_or, bitwise_xor,
780-
* left_shift, right_shift#
781-
* #OP = +, -,*, &, |, ^, <<, >>#
780+
* #kind = add, subtract, multiply, bitwise_and, bitwise_or, bitwise_xor#
781+
* #OP = +, -, *, &, |, ^#
782782
*/
783783

784784
#if @CHK@
785785
NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
786786
@TYPE@_@kind@@isa@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
787787
{
788-
if(IS_BINARY_REDUCE) {
788+
if (IS_BINARY_REDUCE) {
789789
BINARY_REDUCE_LOOP(@type@) {
790790
io1 @OP@= *(@type@ *)ip2;
791791
}
@@ -799,6 +799,47 @@ NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
799799

800800
/**end repeat2**/
801801

802+
/*
803+
* Arithmetic bit shift operations.
804+
*
805+
* Intel hardware masks bit shift values, so large shifts wrap around
806+
* and can produce surprising results. The special handling ensures that
807+
* behavior is independent of compiler or hardware.
808+
* TODO: We could implement consistent behavior for negative shift counts,
809+
* which are undefined behavior in C.
810+
*/
811+
812+
#define INT_left_shift_needs_clear_floatstatus
813+
#define UINT_left_shift_needs_clear_floatstatus
814+
815+
NPY_NO_EXPORT NPY_GCC_OPT_3 void
816+
@TYPE@_left_shift@isa@(char **args, npy_intp *dimensions, npy_intp *steps,
817+
void *NPY_UNUSED(func))
818+
{
819+
BINARY_LOOP_FAST(@type@, @type@, *out = npy_lshift@c@(in1, in2));
820+
821+
#ifdef @TYPE@_left_shift_needs_clear_floatstatus
822+
// For some reason, our macOS CI sets an "invalid" flag here, but only
823+
// for some types (INT and UINT, per the *_needs_clear_floatstatus defines).
824+
npy_clear_floatstatus_barrier((char*)dimensions);
825+
#endif
826+
}
827+
828+
#undef INT_left_shift_needs_clear_floatstatus
829+
#undef UINT_left_shift_needs_clear_floatstatus
830+
831+
NPY_NO_EXPORT
832+
#ifndef NPY_DO_NOT_OPTIMIZE_@TYPE@_right_shift
833+
NPY_GCC_OPT_3
834+
#endif
835+
void
836+
@TYPE@_right_shift@isa@(char **args, npy_intp *dimensions, npy_intp *steps,
837+
void *NPY_UNUSED(func))
838+
{
839+
BINARY_LOOP_FAST(@type@, @type@, *out = npy_rshift@c@(in1, in2));
840+
}
841+
842+
802843
/**begin repeat2
803844
* #kind = equal, not_equal, greater, greater_equal, less, less_equal,
804845
* logical_and, logical_or#

numpy/core/src/umath/scalarmath.c.src

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,25 +246,26 @@ static void
246246
/**end repeat**/
247247

248248

249-
250-
/* QUESTION: Should we check for overflow / underflow in (l,r)shift? */
251-
252249
/**begin repeat
253250
* #name = byte, ubyte, short, ushort, int, uint,
254251
* long, ulong, longlong, ulonglong#
255252
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
256253
* npy_long, npy_ulong, npy_longlong, npy_ulonglong#
254+
* #suffix = hh,uhh,h,uh,,u,l,ul,ll,ull#
257255
*/
258256

259257
/**begin repeat1
260-
* #oper = and, xor, or, lshift, rshift#
261-
* #op = &, ^, |, <<, >>#
258+
* #oper = and, xor, or#
259+
* #op = &, ^, |#
262260
*/
263261

264262
#define @name@_ctype_@oper@(arg1, arg2, out) *(out) = (arg1) @op@ (arg2)
265263

266264
/**end repeat1**/
267265

266+
#define @name@_ctype_lshift(arg1, arg2, out) *(out) = npy_lshift@suffix@(arg1, arg2)
267+
#define @name@_ctype_rshift(arg1, arg2, out) *(out) = npy_rshift@suffix@(arg1, arg2)
268+
268269
/**end repeat**/
269270

270271
/**begin repeat
@@ -570,7 +571,7 @@ static void
570571
* 1) Convert the types to the common type if both are scalars (0 return)
571572
* 2) If both are not scalars use ufunc machinery (-2 return)
572573
* 3) If both are scalars but cannot be cast to the right type
573-
* return NotImplmented (-1 return)
574+
* return NotImplemented (-1 return)
574575
*
575576
* 4) Perform the function on the C-type.
576577
* 5) If an error condition occurred, check to see

numpy/core/tests/test_scalarmath.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,3 +664,31 @@ def test_builtin_abs(self):
664664

665665
def test_numpy_abs(self):
666666
self._test_abs_func(np.abs)
667+
668+
669+
class TestBitShifts(object):
670+
671+
@pytest.mark.parametrize('type_code', np.typecodes['AllInteger'])
672+
@pytest.mark.parametrize('op',
673+
[operator.rshift, operator.lshift], ids=['>>', '<<'])
674+
def test_shift_all_bits(self, type_code, op):
675+
""" Shifts where the shift amount is the width of the type or wider """
676+
# gh-2449
677+
dt = np.dtype(type_code)
678+
nbits = dt.itemsize * 8
679+
for val in [5, -5]:
680+
for shift in [nbits, nbits + 4]:
681+
val_scl = dt.type(val)
682+
shift_scl = dt.type(shift)
683+
res_scl = op(val_scl, shift_scl)
684+
if val_scl < 0 and op is operator.rshift:
685+
# sign bit is preserved
686+
assert_equal(res_scl, -1)
687+
else:
688+
assert_equal(res_scl, 0)
689+
690+
# Result on scalars should be the same as on arrays
691+
val_arr = np.array([val]*32, dtype=dt)
692+
shift_arr = np.array([shift]*32, dtype=dt)
693+
res_arr = op(val_arr, shift_arr)
694+
assert_equal(res_arr, res_scl)

0 commit comments

Comments
 (0)
0