ENH: Improve Floating Point Cast Performance on ARM (#28769) · r-devulap/numpy@d692fbc · GitHub
[go: up one dir, main page]

Skip to content

Commit d692fbc

Browse files
authored
ENH: Improve Floating Point Cast Performance on ARM (numpy#28769)
* WIP,Prototype: Use Neon SIMD to improve half->float cast performance [ci skip] [skip ci]
* Support Neon SIMD float32->float16 cast and update scalar path to use hardware cast
* Add missing header
* Relax VECTOR_ARITHMETIC check and add comment on need for SIMD routines
* Enable hardware cast on x86 when F16C is available
* Relax fp exceptions in Clang to enable vectorization for cast
* Ignore fp exceptions only for float casts
* Fix build
* Attempt to fix test failure on ARM64 native
* Work around gcc bug for double->half casts
* Add release note
1 parent 1672b69 commit d692fbc

File tree

2 files changed

+66
-13
lines changed

2 files changed

+66
-13
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Performance improvements for ``np.float16`` casts
2+
--------------------------------------------------
3+
Earlier, floating point casts to and from ``np.float16`` types
4+
were emulated in software on all platforms.
5+
6+
Now, on ARM devices that support Neon float16 intrinsics (such as
7+
recent Apple Silicon), the native float16 path is used to achieve
8+
the best performance.

numpy/_core/src/multiarray/lowlevel_strided_loops.c.src

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
708708

709709
/************* STRIDED CASTING SPECIALIZED FUNCTIONS *************/
710710

711+
/*
 * Select the half-precision representation used by the cast loops below.
 *
 * When NPY_HAVE_NEON_FP16 is defined, `_npy_half` is the compiler's
 * `_Float16` type and NATIVE_FP16 is 1. Otherwise `_npy_half` is an alias
 * for `npy_half` and EMULATED_FP16 is 1. Exactly one of the two flags is 1;
 * they feed the `is_emu_half*` / `is_native_half*` template parameters.
 */
#if defined(NPY_HAVE_NEON_FP16)
712+
#define EMULATED_FP16 0
713+
#define NATIVE_FP16 1
714+
typedef _Float16 _npy_half;
715+
#else
716+
#define EMULATED_FP16 1
717+
#define NATIVE_FP16 0
718+
typedef npy_half _npy_half;
719+
#endif
720+
711721
/**begin repeat
712722
*
713723
* #NAME1 = BOOL,
@@ -723,15 +733,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
723733
* #type1 = npy_bool,
724734
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
725735
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
726-
* npy_half, npy_float, npy_double, npy_longdouble,
736+
* _npy_half, npy_float, npy_double, npy_longdouble,
727737
* npy_cfloat, npy_cdouble, npy_clongdouble#
728738
* #rtype1 = npy_bool,
729739
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
730740
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
731-
* npy_half, npy_float, npy_double, npy_longdouble,
741+
* _npy_half, npy_float, npy_double, npy_longdouble,
732742
* npy_float, npy_double, npy_longdouble#
733743
* #is_bool1 = 1, 0*17#
734-
* #is_half1 = 0*11, 1, 0*6#
744+
* #is_emu_half1 = 0*11, EMULATED_FP16, 0*6#
745+
* #is_native_half1 = 0*11, NATIVE_FP16, 0*6#
735746
* #is_float1 = 0*12, 1, 0, 0, 1, 0, 0#
736747
* #is_double1 = 0*13, 1, 0, 0, 1, 0#
737748
* #is_complex1 = 0*15, 1*3#
@@ -752,15 +763,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
752763
* #type2 = npy_bool,
753764
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
754765
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
755-
* npy_half, npy_float, npy_double, npy_longdouble,
766+
* _npy_half, npy_float, npy_double, npy_longdouble,
756767
* npy_cfloat, npy_cdouble, npy_clongdouble#
757768
* #rtype2 = npy_bool,
758769
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
759770
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
760-
* npy_half, npy_float, npy_double, npy_longdouble,
771+
* _npy_half, npy_float, npy_double, npy_longdouble,
761772
* npy_float, npy_double, npy_longdouble#
762773
* #is_bool2 = 1, 0*17#
763-
* #is_half2 = 0*11, 1, 0*6#
774+
* #is_emu_half2 = 0*11, EMULATED_FP16, 0*6#
775+
* #is_native_half2 = 0*11, NATIVE_FP16, 0*6#
764776
* #is_float2 = 0*12, 1, 0, 0, 1, 0, 0#
765777
* #is_double2 = 0*13, 1, 0, 0, 1, 0#
766778
* #is_complex2 = 0*15, 1*3#
@@ -774,8 +786,8 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
774786

775787
#if !(NPY_USE_UNALIGNED_ACCESS && !@aligned@)
776788

777-
/* For half types, don't use actual double/float types in conversion */
778-
#if @is_half1@ || @is_half2@
789+
/* For emulated half types, don't use actual double/float types in conversion */
790+
#if @is_emu_half1@ || @is_emu_half2@
779791

780792
# if @is_float1@
781793
# define _TYPE1 npy_uint32
@@ -801,27 +813,27 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
801813
#endif
802814

803815
/* Determine an appropriate casting conversion function */
804-
#if @is_half1@
816+
#if @is_emu_half1@
805817

806818
# if @is_float2@
807819
# define _CONVERT_FN(x) npy_halfbits_to_floatbits(x)
808820
# elif @is_double2@
809821
# define _CONVERT_FN(x) npy_halfbits_to_doublebits(x)
810-
# elif @is_half2@
822+
# elif @is_emu_half2@
811823
# define _CONVERT_FN(x) (x)
812824
# elif @is_bool2@
813825
# define _CONVERT_FN(x) ((npy_bool)!npy_half_iszero(x))
814826
# else
815827
# define _CONVERT_FN(x) ((_TYPE2)npy_half_to_float(x))
816828
# endif
817829

818-
#elif @is_half2@
830+
#elif @is_emu_half2@
819831

820832
# if @is_float1@
821833
# define _CONVERT_FN(x) npy_floatbits_to_halfbits(x)
822834
# elif @is_double1@
823835
# define _CONVERT_FN(x) npy_doublebits_to_halfbits(x)
824-
# elif @is_half1@
836+
# elif @is_emu_half1@
825837
# define _CONVERT_FN(x) (x)
826838
# elif @is_bool1@
827839
# define _CONVERT_FN(x) npy_float_to_half((float)(x!=0))
@@ -839,7 +851,29 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
839851

840852
#endif
841853

842-
static NPY_GCC_OPT_3 int
854+
// Enable auto-vectorization for floating point casts with clang.
// The pragma is emitted only when both the source type (suffix 1) and the
// destination type (suffix 2) have one of the native-half / float / double
// template flags set, and only for clang >= 12 outside Emscripten builds.
// It switches clang's floating point exception behavior to "ignore" for the
// cast function defined right after this block.
855+
#if @is_native_half1@ || @is_float1@ || @is_double1@
856+
#if @is_native_half2@ || @is_float2@ || @is_double2@
857+
#if defined(__clang__) && !defined(__EMSCRIPTEN__)
858+
#if __clang_major__ >= 12
859+
_Pragma("clang fp exceptions(ignore)")
860+
#endif
861+
#endif
862+
#endif
863+
#endif
864+
865+
// Work around GCC bug for double->half casts. For SVE and
866+
// OPT_LEVEL > 1, it implements this as double->single->half
867+
// which is incorrect as it introduces double rounding with
868+
// narrowing casts.
// Clang also defines __GNUC__ but does not implement GCC's `optimize`
// function attribute (it is ignored with a warning), so Clang is excluded
// explicitly and takes the default NPY_GCC_OPT_3 branch instead.
869+
#if (@is_double1@ && @is_native_half2@) && \
870+
defined(NPY_HAVE_SVE) && defined(__GNUC__) && !defined(__clang__)
871+
#define GCC_CAST_OPT_LEVEL __attribute__((optimize("O1")))
872+
#else
873+
#define GCC_CAST_OPT_LEVEL NPY_GCC_OPT_3
874+
#endif
875+
876+
static GCC_CAST_OPT_LEVEL int
843877
@prefix@_cast_@name1@_to_@name2@(
844878
PyArrayMethod_Context *context, char *const *args,
845879
const npy_intp *dimensions, const npy_intp *strides,
@@ -933,6 +967,17 @@ static NPY_GCC_OPT_3 int
933967
return 0;
934968
}
935969

970+
// Counterpart of the "ignore" pragma emitted before the cast function:
// under the same type and compiler conditions, switch clang's floating
// point exception behavior to "strict" for the code that follows.
// NOTE(review): this sets "strict" unconditionally rather than restoring
// whatever mode was active before the cast function -- presumably the file
// is otherwise built with strict FP exception semantics under clang; confirm.
#if @is_native_half1@ || @is_float1@ || @is_double1@
971+
#if @is_native_half2@ || @is_float2@ || @is_double2@
972+
#if defined(__clang__) && !defined(__EMSCRIPTEN__)
973+
#if __clang_major__ >= 12
974+
_Pragma("clang fp exceptions(strict)")
975+
#endif
976+
#endif
977+
#endif
978+
#endif
979+
980+
// Defined per template instantiation just before the cast function; removed
// here so it can be defined again.
#undef GCC_CAST_OPT_LEVEL
936981
#undef _CONVERT_FN
937982
#undef _TYPE2
938983
#undef _TYPE1

0 commit comments

Comments
 (0)
0