8000 ENH: Changes (and tests) to allow exporting half-floats through the b… · numpy/numpy@f9c7bde · GitHub
[go: up one dir, main page]

Skip to content

Commit f9c7bde

Browse files
Eli Stevensmwiebe
Eli Stevens
authored andcommitted
ENH: Changes (and tests) to allow exporting half-floats through the buffer interface. (#1789)
Code: Added NPY_HALF to switch (descr->type_num) in _buffer_format_string. Added 'e' keys to the _pep3118_native_map and _pep3118_standard_map. Tests: Added entries to the generic round-trip tests. Added specialized half-float test that round-trips example values from the wikipedia page. http://en.wikipedia.org/wiki/Half_precision_floating-point_format http://mail.scipy.org/pipermail/numpy-discussion/2011-March/055795.html
1 parent 65b77ee commit f9c7bde

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

numpy/core/_internal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def _index_fields(ary, fields):
363363
'L': 'L',
364364
'q': 'q',
365365
'Q': 'Q',
366+
'e': 'e',
366367
'f': 'f',
367368
'd': 'd',
368369
'g': 'g',
@@ -388,6 +389,7 @@ def _index_fields(ary, fields):
388389
'L': 'u4',
389390
'q': 'i8',
390391
'Q': 'u8',
392+
'e': 'f2',
391393
'f': 'f',
392394
'd': 'd',
393395
'Zf': 'F',

numpy/core/src/multiarray/buffer.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,
358358
break;
359359
case NPY_LONGLONG: if (_append_char(str, 'q')) return -1; break;
360360
case NPY_ULONGLONG: if (_append_char(str, 'Q')) return -1; break;
361+
case NPY_HALF: if (_append_char(str, 'e')) return -1; break;
361362
case NPY_FLOAT: if (_append_char(str, 'f')) return -1; break;
362363
case NPY_DOUBLE: if (_append_char(str, 'd')) return -1; break;
363364
case NPY_LONGDOUBLE: if (_append_char(str, 'g')) return -1; break;

numpy/core/tests/test_multiarray.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,9 +2008,11 @@ def test_roundtrip(self):
20082008
('l', 'S4'),
20092009
('m', 'U4'),
20102010
('n', 'V3'),
2011-
('o', '?')]
2011+
('o', '?'),
2012+
('p', np.half),
2013+
]
20122014
x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2013-
asbytes('aaaa'), 'bbbb', asbytes('xxx'), True)],
2015+
asbytes('aaaa'), 'bbbb', asbytes('xxx'), True, 1.0)],
20142016
dtype=dt)
20152017
self._check_roundtrip(x)
20162018

@@ -2042,6 +2044,25 @@ def test_roundtrip(self):
20422044
x = np.array([1,2,3], dtype='<q')
20432045
assert_raises(ValueError, self._check_roundtrip, x)
20442046

2047+
def test_roundtrip_half(self):
2048+
half_list = [
2049+
1.0,
2050+
-2.0,
2051+
6.5504 * 10**4, # (max half precision)
2052+
2**-14, # ~= 6.10352 * 10**-5 (minimum positive normal)
2053+
2**-24, # ~= 5.96046 * 10**-8 (minimum strictly positive subnormal)
2054+
0.0,
2055+
-0.0,
2056+
float('+inf'),
2057+
float('-inf'),
2058+
0.333251953125, # ~= 1/3
2059+
]
2060+
2061+
x = np.array(half_list, dtype='>e')
2062+
self._check_roundtrip(x)
2063+
x = np.array(half_list, dtype='<e')
2064+
self._check_roundtrip(x)
2065+
20452066
def test_export_simple_1d(self):
20462067
x = np.array([1,2,3,4,5], dtype='i')
20472068
y = memoryview(x)
@@ -2092,9 +2113,11 @@ def test_export_record(self):
20922113
('l', 'S4'),
20932114
('m', 'U4'),
20942115
('n', 'V3'),
2095-
('o', '?')]
2116+
('o', '?'),
2117+
('p', np.half),
2118+
]
20962119
x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2097-
asbytes('aaaa'), 'bbbb', asbytes(' '), True)],
2120+
asbytes('aaaa'), 'bbbb', asbytes(' '), True, 1.0)],
20982121
dtype=dt)
20992122
y = memoryview(x)
21002123
assert_equal(y.shape, (1,))
@@ -2103,9 +2126,9 @@ def test_export_record(self):
21032126

21042127
sz = sum([dtype(b).itemsize for a, b in dt])
21052128
if dtype('l').itemsize == 4:
2106-
assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}')
2129+
assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}')
21072130
else:
2108-
assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}')
2131+
assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}')
21092132
assert_equal(y.strides, (sz,))
21102133
assert_equal(y.itemsize, sz)
21112134

0 commit comments

Comments
 (0)
0