@@ -708,6 +708,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
708
708
709
709
/************* STRIDED CASTING SPECIALIZED FUNCTIONS *************/
710
710
711
+ #if defined(NPY_HAVE_NEON_FP16)
712
+ #define EMULATED_FP16 0
713
+ #define NATIVE_FP16 1
714
+ typedef _Float16 _npy_half;
715
+ #else
716
+ #define EMULATED_FP16 1
717
+ #define NATIVE_FP16 0
718
+ typedef npy_half _npy_half;
719
+ #endif
720
+
711
721
/**begin repeat
712
722
*
713
723
* #NAME1 = BOOL,
@@ -723,15 +733,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
723
733
* #type1 = npy_bool,
724
734
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
725
735
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
726
- * npy_half , npy_float, npy_double, npy_longdouble,
736
+ * _npy_half , npy_float, npy_double, npy_longdouble,
727
737
* npy_cfloat, npy_cdouble, npy_clongdouble#
728
738
* #rtype1 = npy_bool,
729
739
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
730
740
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
731
- * npy_half , npy_float, npy_double, npy_longdouble,
741
+ * _npy_half , npy_float, npy_double, npy_longdouble,
732
742
* npy_float, npy_double, npy_longdouble#
733
743
* #is_bool1 = 1, 0*17#
734
- * #is_half1 = 0*11, 1, 0*6#
744
+ * #is_emu_half1 = 0*11, EMULATED_FP16, 0*6#
745
+ * #is_native_half1 = 0*11, NATIVE_FP16, 0*6#
735
746
* #is_float1 = 0*12, 1, 0, 0, 1, 0, 0#
736
747
* #is_double1 = 0*13, 1, 0, 0, 1, 0#
737
748
* #is_complex1 = 0*15, 1*3#
@@ -752,15 +763,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
752
763
* #type2 = npy_bool,
753
764
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
754
765
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
755
- * npy_half , npy_float, npy_double, npy_longdouble,
766
+ * _npy_half , npy_float, npy_double, npy_longdouble,
756
767
* npy_cfloat, npy_cdouble, npy_clongdouble#
757
768
* #rtype2 = npy_bool,
758
769
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
759
770
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
760
- * npy_half , npy_float, npy_double, npy_longdouble,
771
+ * _npy_half , npy_float, npy_double, npy_longdouble,
761
772
* npy_float, npy_double, npy_longdouble#
762
773
* #is_bool2 = 1, 0*17#
763
- * #is_half2 = 0*11, 1, 0*6#
774
+ * #is_emu_half2 = 0*11, EMULATED_FP16, 0*6#
775
+ * #is_native_half2 = 0*11, NATIVE_FP16, 0*6#
764
776
* #is_float2 = 0*12, 1, 0, 0, 1, 0, 0#
765
777
* #is_double2 = 0*13, 1, 0, 0, 1, 0#
766
778
* #is_complex2 = 0*15, 1*3#
@@ -774,8 +786,8 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
774
786
775
787
#if !(NPY_USE_UNALIGNED_ACCESS && !@aligned@)
776
788
777
- /* For half types, don't use actual double/float types in conversion */
778
- #if @is_half1 @ || @is_half2 @
789
+ /* For emulated half types, don't use actual double/float types in conversion */
790
+ #if @is_emu_half1 @ || @is_emu_half2 @
779
791
780
792
# if @is_float1@
781
793
# define _TYPE1 npy_uint32
@@ -801,27 +813,27 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
801
813
#endif
802
814
803
815
/* Determine an appropriate casting conversion function */
804
- #if @is_half1 @
816
+ #if @is_emu_half1 @
805
817
806
818
# if @is_float2@
807
819
# define _CONVERT_FN(x) npy_halfbits_to_floatbits(x)
808
820
# elif @is_double2@
809
821
# define _CONVERT_FN(x) npy_halfbits_to_doublebits(x)
810
- # elif @is_half2 @
822
+ # elif @is_emu_half2 @
811
823
# define _CONVERT_FN(x) (x)
812
824
# elif @is_bool2@
813
825
# define _CONVERT_FN(x) ((npy_bool)!npy_half_iszero(x))
814
826
# else
815
827
# define _CONVERT_FN(x) ((_TYPE2)npy_half_to_float(x))
816
828
# endif
817
829
818
- #elif @is_half2 @
830
+ #elif @is_emu_half2 @
819
831
820
832
# if @is_float1@
821
833
# define _CONVERT_FN(x) npy_floatbits_to_halfbits(x)
822
834
# elif @is_double1@
823
835
# define _CONVERT_FN(x) npy_doublebits_to_halfbits(x)
824
- # elif @is_half1 @
836
+ # elif @is_emu_half1 @
825
837
# define _CONVERT_FN(x) (x)
826
838
# elif @is_bool1@
827
839
# define _CONVERT_FN(x) npy_float_to_half((float)(x!=0))
@@ -839,7 +851,29 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
839
851
840
852
#endif
841
853
842
- static NPY_GCC_OPT_3 int
854
+ // Enable auto-vectorization for floating point casts with clang
855
+ #if @is_native_half1@ || @is_float1@ || @is_double1@
856
+ #if @is_native_half2@ || @is_float2@ || @is_double2@
857
+ #if defined(__clang__) && !defined(__EMSCRIPTEN__)
858
+ #if __clang_major__ >= 12
859
+ _Pragma("clang fp exceptions(ignore)")
860
+ #endif
861
+ #endif
862
+ #endif
863
+ #endif
864
+
865
+ // Work around GCC bug for double->half casts. For SVE and
866
+ // OPT_LEVEL > 1, it implements this as double->single->half
867
+ // which is incorrect as it introduces double rounding with
868
+ // narrowing casts.
869
+ #if (@is_double1@ && @is_native_half2@) && \
870
+ defined(NPY_HAVE_SVE) && defined(__GNUC__)
871
+ #define GCC_CAST_OPT_LEVEL __attribute__((optimize("O1")))
872
+ #else
873
+ #define GCC_CAST_OPT_LEVEL NPY_GCC_OPT_3
874
+ #endif
875
+
876
+ static GCC_CAST_OPT_LEVEL int
843
877
@prefix@_cast_@name1@_to_@name2@(
844
878
PyArrayMethod_Context *context, char *const *args,
845
879
const npy_intp *dimensions, const npy_intp *strides,
@@ -933,6 +967,17 @@ static NPY_GCC_OPT_3 int
933
967
return 0;
934
968
}
935
969
970
+ #if @is_native_half1@ || @is_float1@ || @is_double1@
971
+ #if @is_native_half2@ || @is_float2@ || @is_double2@
972
+ #if defined(__clang__) && !defined(__EMSCRIPTEN__)
973
+ #if __clang_major__ >= 12
974
+ _Pragma("clang fp exceptions(strict)")
975
+ #endif
976
+ #endif
977
+ #endif
978
+ #endif
979
+
980
+ #undef GCC_CAST_OPT_LEVEL
936
981
#undef _CONVERT_FN
937
982
#undef _TYPE2
938
983
#undef _TYPE1
0 commit comments