8000 Make npy_half a strong type · numpy/numpy@e0d9b39 · GitHub
[go: up one dir, main page]

Skip to content

Commit e0d9b39

Browse files
Make npy_half a strong type
Instead of using a type alias, make npy_half a struct. As a consequence, cleanup npy_half usage to always reference function declared in numpy/halffloat.h. This avoids vodoo incantation in files that should now nothing of npy_half internal. Some of numpy_half manipulating functions are promoted to inline to avoid performance regression.
1 parent 7e7020f commit e0d9b39

File tree

9 files changed

+292
-133
lines changed

9 files changed

+292
-133
lines changed

numpy/core/include/numpy/halffloat.h

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,51 +8,130 @@
88
extern "C" {
99
#endif
1010

11+
/* To keep ABI compatibility with older version of numpy, npy routines that
12+
* manipulates npy_float as a struct are renamed using the convention defined in
13+
* macro NPY_HALF_INTERNAL_API.
14+
* Each symbol is then redefined based on that naming scheme to keep client code
15+
* mostly unchanged at API level
16+
*/
17+
#define NPY_HALF_INTERNAL_API(name) npy_internal_##name
18+
1119
/*
1220
* Half-precision routines
1321
*/
1422

1523
/* Conversions */
24+
#define npy_half_to_float NPY_HALF_INTERNAL_API(half_to_float)
1625
float npy_half_to_float(npy_half h);
26+
27+
#define npy_half_to_double NPY_HALF_INTERNAL_API(half_to_double)
1728
double npy_half_to_double(npy_half h);
29+
30+
#define npy_float_to_half NPY_HALF_INTERNAL_API(float_to_half)
1831
npy_half npy_float_to_half(float f);
32+
33+
#define npy_double_to_half NPY_HALF_INTERNAL_API(double_to_half)
1934
npy_half npy_double_to_half(double d);
35+
2036
/* Comparisons */
37+
#define npy_half_eq NPY_HALF_INTERNAL_API(half_eq)
2138
int npy_half_eq(npy_half h1, npy_half h2);
39+
40+
#define npy_half_ne NPY_HALF_INTERNAL_API(half_ne)
2241
int npy_half_ne(npy_half h1, npy_half h2);
42+
43+
#define npy_half_le NPY_HALF_INTERNAL_API(half_le)
2344
int npy_half_le(npy_half h1, npy_half h2);
45+
46+
#define npy_half_lt NPY_HALF_INTERNAL_API(half_lt)
2447
int npy_half_lt(npy_half h1, npy_half h2);
48+
49+
#define npy_half_ge NPY_HALF_INTERNAL_API(half_ge)
2550
int npy_half_ge(npy_half h1, npy_half h2);
51+
52+
#define npy_half_gt NPY_HALF_INTERNAL_API(half_gt)
2653
int npy_half_gt(npy_half h1, npy_half h2);
54+
2755
/* faster *_nonan variants for when you know h1 and h2 are not NaN */
56+
#define npy_half_eq_nonan NPY_HALF_INTERNAL_API(half_eq_nonan)
2857
int npy_half_eq_nonan(npy_half h1, npy_half h2);
58+
59+
#define npy_half_lt_nonan NPY_HALF_INTERNAL_API(half_lt_nonan)
2960
int npy_half_lt_nonan(npy_half h1, npy_half h2);
61+
62+
#define npy_half_le_nonan NPY_HALF_INTERNAL_API(half_le_nonan)
3063
int npy_half_le_nonan(npy_half h1, npy_half h2);
64+
3165
/* Miscellaneous functions */
32-
int npy_half_iszero(npy_half h);
33-
int npy_half_isnan(npy_half h);
34-
int npy_half_isinf(npy_half h);
35-
int npy_half_isfinite(npy_half h);
36-
int npy_half_signbit(npy_half h);
66+
#define npy_half_copysign NPY_HALF_INTERNAL_API(half_copysign)
3767
npy_half npy_half_copysign(npy_half x, npy_half y);
68+
69+
#define npy_half_spacing NPY_HALF_INTERNAL_API(half_spacing)
3870
npy_half npy_half_spacing(npy_half h);
71+
72+
#define npy_half_nextafter NPY_HALF_INTERNAL_API(half_nextafter)
3973
npy_half npy_half_nextafter(npy_half x, npy_half y);
74+
75+
#define npy_half_divmod NPY_HALF_INTERNAL_API(half_divmod)
4076
npy_half npy_half_divmod(npy_half x, npy_half y, npy_half *modulus);
4177

78+
#define npy_half_iszero NPY_HALF_INTERNAL_API(half_iszero)
79+
NPY_INLINE int npy_half_iszero(npy_half h) {
80+
return (h.bits&0x7fff) == 0;
81+
}
82+
83+
#define npy_half_isnan NPY_HALF_INTERNAL_API(half_isnan)
84+
NPY_INLINE int npy_half_isnan(npy_half h) {
85+
return ((h.bits&0x7c00u) == 0x7c00u) && ((h.bits&0x03ffu) != 0x0000u);
86+
}
87+
88+
#define npy_half_isinf NPY_HALF_INTERNAL_API(half_isinf)
89+
NPY_INLINE int npy_half_isinf(npy_half h) {
90+
return ((h.bits&0x7fffu) == 0x7c00u);
91+
}
92+
93+
#define npy_half_isfinite NPY_HALF_INTERNAL_API(half_isfinite)
94+
NPY_INLINE int npy_half_isfinite(npy_half h) {
95+
return (h.bits&0x7c00u) != 0x7c00u;
96+
}
97+
98+
#define npy_half_signbit NPY_HALF_INTERNAL_API(half_signbit)
99+
NPY_INLINE int npy_half_signbit(npy_half h) {
100+
return (h.bits&0x8000u) != 0;
101+
}
102+
103+
#define npy_half_neg NPY_HALF_INTERNAL_API(half_neg)
104+
NPY_INLINE npy_half npy_half_neg(npy_half h) {
105+
npy_half res = {(npy_uint16)(h.bits^0x8000u)};
106+
return res;
107+
}
108+
109+
#define npy_half_abs NPY_HALF_INTERNAL_API(half_abs)
110+
NPY_INLINE npy_half npy_half_abs(npy_half h) {
111+
npy_half res = {(npy_uint16)(h.bits&0x7fffu)};
112+
return res;
113+
}
114+
115+
#define npy_half_pos NPY_HALF_INTERNAL_API(half_pos)
116+
NPY_INLINE npy_half npy_half_pos(npy_half h) {
117+
npy_half res = {(npy_uint16)(+h.bits)};
118+
return res;
119+
}
120+
42121
/*
43122
* Half-precision constants
44123
*/
45124

46-
#define NPY_HALF_ZERO (0x0000u)
47-
#define NPY_HALF_PZERO (0x0000u)
48-
#define NPY_HALF_NZERO (0x8000u)
49-
#define NPY_HALF_ONE (0x3c00u)
50-
#define NPY_HALF_NEGONE (0xbc00u)
51-
#define NPY_HALF_PINF (0x7c00u)
52-
#define NPY_HALF_NINF (0xfc00u)
53-
#define NPY_HALF_NAN (0x7e00u)
125+
#define NPY_HALF_ZERO (npy_half){0x0000u}
126+
#define NPY_HALF_PZERO (npy_half){0x0000u}
127+
#define NPY_HALF_NZERO (npy_half){0x8000u}
128+
#define NPY_HALF_ONE (npy_half){0x3c00u}
129+
#define NPY_HALF_NEGONE (npy_half){0xbc00u}
130+
#define NPY_HALF_PINF (npy_half){0x7c00u}
131+
#define NPY_HALF_NINF (npy_half){0xfc00u}
132+
#define NPY_HALF_NAN (npy_half){0x7e00u}
54133

55-
#define NPY_MAX_HALF (0x7bffu)
134+
#define NPY_MAX_HALF (npy_half){0x7bffu}
56135

57136
/*
58137
* Bit-level conversions

numpy/core/include/numpy/npy_common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,8 @@ typedef struct { npy_longdouble real, imag; } npy_clongdouble;
10301030

10311031
/* half/float16 isn't a floating-point type in C */
10321032
#define NPY_FLOAT16 NPY_HALF
1033-
typedef npy_uint16 npy_half;
1033+
typedef npy_uint16 npy_half_bits_t;
1034+
typedef struct npy_half { npy_half_bits_t bits;} npy_half;
10341035
typedef npy_half npy_float16;
10351036

10361037
#if NPY_BITSOF_LONGDOUBLE == 32

numpy/core/src/multiarray/arraytypes.c.src

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,11 @@ static PyObject *
219219

220220
if ((ap == NULL) || PyArray_ISBEHAVED_RO(ap)) {
221221
t1 = *((@type@ *)ip);
222-
return @func1@((@type1@)t1);
222+
return @func1@(t1);
223223
}
224224
else {
225225
PyArray_DESCR(ap)->f->copyswap(&t1, ip, PyArray_ISBYTESWAPPED(ap), ap);
226-
return @func1@((@type1@)t1);
226+
return @func1@(t1);
227227
}
228228
}
229229

@@ -237,7 +237,7 @@ static int
237237
temp = PyArrayScalar_VAL(op, @kind@);
238238
}
239239
else {
240-
temp = (@type@)@func2@(op);
240+
temp = @func2@(op);
241241
}
242242
if (PyErr_Occurred()) {
243243
PyObject *type, *value, *traceback;
@@ -1252,7 +1252,7 @@ static void
12521252
npy_half *op = output;
12531253

12541254
while (n--) {
1255-
*op++ = npy_@name@bits_to_halfbits(*ip);
1255+
*op++ = (npy_half){npy_@name@bits_to_halfbits(*ip)};
12561256
#if @iscomplex@
12571257
ip += 2;
12581258
#else
@@ -1269,7 +1269,7 @@ HALF_to_@TYPE@(void *input, void *output, npy_intp n,
12691269
@itype@ *op = output;
12701270

12711271
while (n--) {
1272-
*op++ = npy_halfbits_to_@name@bits(*ip++);
1272+
*op++ = npy_halfbits_to_@name@bits((*ip++).bits);
12731273
#if @iscomplex@
12741274
*op++ = 0;
12751275
#endif
@@ -1383,7 +1383,7 @@ BOOL_to_@TOTYPE@(void *input, void *output, npy_intp n,
13831383
@totype@ *op = output;
13841384

13851385
while (n--) {
1386-
*op++ = (@totype@)((*ip++ != NPY_FALSE) ? @one@ : @zero@);
1386+
*op++ = (*ip++ != NPY_FALSE) ? @one@ : @zero@;
13871387
}
13881388
}
13891389
/**end repeat**/

numpy/core/src/multiarray/dragon4.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,7 +2216,7 @@ Dragon4_PrintFloat_IEEE_binary16(
22162216
const npy_uint32 bufferSize = sizeof(scratch->repr);
22172217
BigInt *bigints = scratch->bigints;
22182218

2219-
npy_uint16 val = *value;
2219+
npy_uint16 bits = value->bits;
22202220
npy_uint32 floatExponent, floatMantissa, floatSign;
22212221

22222222
npy_uint32 mantissa;
@@ -2226,9 +2226,9 @@ Dragon4_PrintFloat_IEEE_binary16(
22262226
char signbit = '\0';
22272227

22282228
/* deconstruct the floating point value */
2229-
floatMantissa = val & bitmask_u32(10);
2230-
floatExponent = (val >> 10) & bitmask_u32(5);
2231-
floatSign = val >> 15;
2229+
floatMantissa = bits & bitmask_u32(10);
2230+
floatExponent = (bits >> 10) & bitmask_u32(5);
2231+
floatSign = bits >> 15;
22322232

22332233
/* output the sign */
22342234
if (floatSign != 0) {

numpy/core/src/multiarray/lowlevel_strided_loops.c.src

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -802,9 +802,9 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
802802
#if @is_half1@
803803

804804
# if @is_float2@
805-
# define _CONVERT_FN(x) npy_halfbits_to_floatbits(x)
805+
# define _CONVERT_FN(x) npy_halfbits_to_floatbits((x).bits)
806806
# elif @is_double2@
807-
# define _CONVERT_FN(x) npy_halfbits_to_doublebits(x)
807+
# define _CONVERT_FN(x) npy_halfbits_to_doublebits((x).bits)
808808
# elif @is_half2@
809809
# define _CONVERT_FN(x) (x)
810810
# elif @is_bool2@
@@ -816,9 +816,9 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
816816
#elif @is_half2@
817817

818818
# if @is_float1@
819-
# define _CONVERT_FN(x) npy_floatbits_to_halfbits(x)
819+
# define _CONVERT_FN(x) (npy_half){npy_floatbits_to_halfbits(x)}
820820
# elif @is_double1@
821-
# define _CONVERT_FN(x) npy_doublebits_to_halfbits(x)
821+
# define _CONVERT_FN(x) (npy_half){npy_doublebits_to_halfbits(x)}
822822
# elif @is_half1@
823823
# define _CONVERT_FN(x) (x)
824824
# elif @is_bool1@

0 commit comments

Comments
 (0)
0