10000 MAINT: Respond to review comments on gh-7473 · eric-wieser/numpy@fca077c · GitHub
[go: up one dir, main page]

Skip to content

Commit fca077c

Browse files
committed
MAINT: Respond to review comments on numpygh-7473
This: * Inlines the macros in loops.c.src * Replaces 8 with `CHAR_BIT `. The `NPY_SIZEOF_*` macros are not used here because it's too much work to apply them to the signed types, and they expand to the same thing anyway. * Removes the reduce loop specializations which likely no one cares about * Uses pytest.mark.parametrize to shorten the test
1 parent ff11d01 commit fca077c

File tree

3 files changed

+77
-129
lines changed

3 files changed

+77
-129
lines changed

numpy/core/src/umath/loops.c.src

Lines changed: 27 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
777777
/**begin repeat2
778778
* Arithmetic
779779
* #kind = add, subtract, multiply, bitwise_and, bitwise_or, bitwise_xor#
780-
* #OP = +, -,*, &, |, ^#
780+
* #OP = +, -, *, &, |, ^#
781781
*/
782782

783783
#if @CHK@
@@ -808,98 +808,48 @@ NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
808808
* which is undefined in C.
809809
*/
810810

811-
#define LEFT_SHIFT_OP \
812-
do { \
813-
if (NPY_LIKELY(in2 < sizeof(@type@) * 8)) { \
814-
*out = in1 << in2; \
815-
} \
816-
else { \
817-
*out = 0; \
818-
} \
819-
} while (0)
820-
821-
822811
NPY_NO_EXPORT NPY_GCC_OPT_3 void
823812
@TYPE@_left_shift@isa@(char **args, npy_intp *dimensions, npy_intp *steps,
824813
void *NPY_UNUSED(func))
825814
{
826-
if (IS_BINARY_REDUCE) {
827-
BINARY_REDUCE_LOOP(@type@) {
828-
@type@ ip2_val = *(@type@ *)ip2;
829-
830-
if (NPY_LIKELY(ip2_val < sizeof(@type@) * 8)) {
831-
io1 <<= ip2_val;
832-
}
833-
else {
834-
io1 = 0;
835-
}
815+
BINARY_LOOP_FAST(@type@, @type@,
816+
if (NPY_LIKELY(in2 < sizeof(@type@) * CHAR_BIT)) {
817+
*out = in1 << in2;
836818
}
837-
*((@type@ *)iop1) = io1;
838-
}
839-
else {
840-
BINARY_LOOP_FAST(@type@, @type@, LEFT_SHIFT_OP);
841-
}
842-
}
843-
844-
#undef LEFT_SHIFT_OP
845-
846-
#define RIGHT_SHIFT_OP_SIGNED \
847-
do { \
848-
if (NPY_LIKELY(in2 < sizeof(@type@) * 8)) { \
849-
*out = in1 >> in2; \
850-
} \
851-
else if (in1 < 0) { \
852-
*out = -1; \
853-
} \
854-
else { \
855-
*out = 0; \
856-
} \
857-
} while (0)
858-
859-
#define RIGHT_SHIFT_OP_UNSIGNED \
860-
do { \
861-
if (NPY_LIKELY(in2 < sizeof(@type@) * 8)) { \
862-
*out = in1 >> in2; \
863-
} \
864-
else { \
865-
*out = 0; \
866-
} \
867-
} while (0)
819+
else {
820+
*out = 0;
821+
}
822+
);
823+
}
868824

869825
NPY_NO_EXPORT NPY_GCC_OPT_3 void
870826
@TYPE@_right_shift@isa@(char **args, npy_intp *dimensions, npy_intp *steps,
871827
void *NPY_UNUSED(func))
872828
{
873-
if (IS_BINARY_REDUCE) {
874-
BINARY_REDUCE_LOOP(@type@) {
875-
@type@ ip2_val = *(@type@ *)ip2;
876-
877-
if (NPY_LIKELY(ip2_val < sizeof(@type@) * 8)) {
878-
io1 >>= ip2_val;
879-
}
880829
#if @SIGNED@
881-
else if (io1 < 0) {
882-
io1 = -1;
883-
}
884-
#endif
885-
else {
886-
io1 = 0;
887-
}
830+
BINARY_LOOP_FAST(@type@, @type@, {
831+
if (NPY_LIKELY(in2 < sizeof(@type@) * CHAR_BIT)) {
832+
*out = in1 >> in2;
888833
}
889-
*((@type@ *)iop1) = io1;
890-
}
891-
else {
892-
#if @SIGNED@
893-
BINARY_LOOP_FAST(@type@, @type@, RIGHT_SHIFT_OP_SIGNED);
834+
else if (in1 < 0) {
835+
*out = (@type@)-1; /* shift right preserves the sign bit */
836+
}
837+
else {
838+
*out = 0;
839+
}
840+
});
894841
#else
895-
BINARY_LOOP_FAST(@type@, @type@, RIGHT_SHIFT_OP_UNSIGNED);
842+
BINARY_LOOP_FAST(@type@, @type@, {
843+
if (NPY_LIKELY(in2 < sizeof(@type@) * CHAR_BIT)) {
844+
*out = in1 >> in2;
845+
}
846+
else {
847+
*out = 0;
848+
}
849+
});
896850
#endif
897-
}
898851
}
899852

900-
#undef RIGHT_SHIFT_OP_SIGNED
901-
#undef RIGHT_SHIFT_OP_UNSIGNED
902-
903853

904854
/**begin repeat2
905855
* #kind = equal, not_equal, greater, greater_equal, less, less_equal,

numpy/core/src/umath/scalarmath.c.src

Lines changed: 29 additions & 27 deletions
< 10000 td data-grid-cell-id="diff-7355537178d1b6f85a40acf970ac3317b184721c75828f33ed136ea0f6c92ae9-290-291-2" data-line-anchor="diff-7355537178d1b6f85a40acf970ac3317b184721c75828f33ed136ea0f6c92ae9L290" data-selected="false" role="gridcell" style="background-color:var(--diffBlob-deletionLine-bgColor, var(--diffBlob-deletion-bgColor-line));padding-right:24px" tabindex="-1" valign="top" class="focusable-grid-cell diff-text-cell left-side-diff-cell border-right left-side">-
#define @name@_ctype_rshift(arg1, arg2, out) \
Original file line numberDiff line numberDiff line change
@@ -263,38 +263,40 @@ static void
263263

264264
/**end repeat1**/
265265

266-
#define @name@_ctype_lshift(arg1, arg2, out) \
267-
do { \
268-
if (NPY_LIKELY((arg2) < sizeof(@type@) * 8)) { \
269< F438 /td>-
*(out) = (arg1) << (arg2); \
270-
} \
271-
else { \
272-
*(out) = 0; \
273-
} \
266+
/* Note: these need to be kept in sync with the shift ufuncs */
267+
268+
#define @name@_ctype_lshift(arg1, arg2, out) \
269+
do { \
270+
if (NPY_LIKELY((arg2) < sizeof(@type@) * CHAR_BIT)) { \
271+
*(out) = (arg1) << (arg2); \
272+
} \
273+
else { \
274+
*(out) = 0; \
275+
} \
274276
} while (0)
275277

276278
#if @issigned@
277-
#define @name@_ctype_rshift(arg1, arg2, out) \
278-
do { \
279-
if (NPY_LIKELY((arg2) < sizeof(@type@) * 8)) { \
280-
*(out) = (arg1) >> (arg2); \
281-
} \
282-
else if ((arg1) < 0) { \
283-
*(out) = -1; \
284-
} \
285-
else { \
286-
*(out) = 0; \
287-
} \
279+
#define @name@_ctype_rshift(arg1, arg2, out) \
280+
do { \
281+
if (NPY_LIKELY((arg2) < sizeof(@type@) * CHAR_BIT)) { \
282+
*(out) = (arg1) >> (arg2); \
283+
} \
284+
else if ((arg1) < 0) { \
285+
*(out) = -1; \
286+
} \
287+
else { \
288+
*(out) = 0; \
289+
} \
288290
} while (0)
289291
#else
290
291-
do { \
292-
if (NPY_LIKELY((arg2) < sizeof(@type@) * 8)) { \
293-
*(out) = (arg1) >> (arg2); \
294-
} \
295-
else { \
296-
*(out) = 0; \
297-
} \
292+
#define @name@_ctype_rshift(arg1, arg2, out) \
293+
do { \
294+
if (NPY_LIKELY((arg2) < sizeof(@type@) * CHAR_BIT)) { \
295+
*(out) = (arg1) >> (arg2); \
296+
} \
297+
else { \
298+
*(out) = 0; \
299+
} \
298300
} while (0)
299301
#endif
300302

numpy/core/tests/test_scalarmath.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -668,31 +668,27 @@ def test_numpy_abs(self):
668668

669669
class TestBitShifts(object):
670670

671-
def test_left_shift(self):
671+
@pytest.mark.parametrize('type_code', np.typecodes['AllInteger'])
672+
@pytest.mark.parametrize('op',
673+
[operator.rshift, operator.lshift], ids=['>>', '<<'])
674+
def test_shift_all_bits(self, type_code, op):
675+
""" Shifts where the shift amount is the width of the type or wider """
672676
# gh-2449
673-
for dt in np.typecodes['AllInteger']:
674-
arr = np.array([5, -5], dtype=dt)
675-
scl_pos, scl_neg = arr
676-
for shift in np.array([arr.dtype.itemsize * 8], dtype=dt):
677-
res_pos = scl_pos << shift
678-
res_neg = scl_neg << shift
679-
assert_equal(res_pos, 0)
680-
assert_equal(res_neg, 0)
681-
# Result on scalars should be the same as on arrays
682-
assert_array_equal(arr << shift, [res_pos, res_neg])
683-
684-
def test_right_shift(self):
685-
# gh-2449
686-
for dt in np.typecodes['AllInteger']:
687-
arr = np.array([5, -5], dtype=dt)
688-
scl_pos, scl_neg = arr
689-
for shift in np.array([arr.dtype.itemsize * 8], dtype=dt):
690-
res_pos = scl_pos >> shift
691-
res_neg = scl_neg >> shift
692-
assert_equal(res_pos, 0)
693-
if dt in np.typecodes['UnsignedInteger']:
694-
assert_equal(res_neg, 0)
677+
dt = np.dtype(type_code)
678+
nbits = dt.itemsize * 8
679+
for val in [5, -5]:
680+
for shift in [nbits, nbits + 4]:
681+
val_scl = dt.type(val)
682+
shift_scl = dt.type(shift)
683+
res_scl = op(val_scl, shift_scl)
684+
if val_scl < 0 and op is operator.rshift:
685+
# sign bit is preserved
686+
assert_equal(res_scl, -1)
695687
else:
696-
assert_equal(res_neg, -1)
688+
assert_equal(res_scl, 0)
689+
697690
# Result on scalars should be the same as on arrays
698-
assert_array_equal(arr >> shift, [res_pos, res_neg], dt)
691+
val_arr = np.array([val]*32, dtype=dt)
692+
shift_arr = np.array([shift]*32, dtype=dt)
693+
res_arr = op(val_arr, shift_arr)
694+
assert_equal(res_arr, res_scl)

0 commit comments

Comments
 (0)
0