@@ -90,8 +90,6 @@ namespace at::native {
90
90
#if !defined(C10_MOBILE)
91
91
DEFINE_DISPATCH (fp16_gemv_trans_stub);
92
92
DEFINE_DISPATCH (bf16_gemv_trans_stub);
93
- DEFINE_DISPATCH (fp16_dot_stub);
94
- DEFINE_DISPATCH (bf16_dot_stub);
95
93
#endif // !defined(C10_MOBILE)
96
94
97
95
namespace blas_impl {
@@ -122,15 +120,6 @@ void fp16_gemv_trans(
122
120
fp16_gemv_trans_stub (kCPU , m, n, alpha, a, lda, x, incx, beta, y, incy);
123
121
}
124
122
125
- static float fp16_dot (
126
- const int64_t n,
127
- const Half* x,
128
- const int64_t incx,
129
- const Half* y,
130
- const int64_t incy) {
131
- return fp16_dot_stub (kCPU , n, x, incx, y, incy);
132
- }
133
-
134
123
#endif // !defined(C10_MOBILE)
135
124
136
125
#if defined(__aarch64__) && !defined(C10_MOBILE)
@@ -395,16 +384,6 @@ void gemv_fast_path<at::BFloat16>(
395
384
y,
396
385
*incy);
397
386
}
398
-
399
- static float bf16_dot (
400
- const int64_t n,
401
- const BFloat16* x,
402
- const int64_t incx,
403
- const BFloat16* y,
404
- const int64_t incy) {
405
- return bf16_dot_stub (kCPU , n, x, incx, y, incy);
406
- }
407
-
408
387
#if !defined(__aarch64__)
409
388
// Currently, only fp16_gemv_trans is built for non-aarch64.
410
389
template <>
@@ -716,34 +695,6 @@ c10::complex<float> dot_impl(int64_t n, const c10::complex<float>* x, int64_t in
716
695
return dot_impl_floating (n, x, incx, y, incy);
717
696
}
718
697
719
- template <>
720
- Half dot_impl (int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
721
- if (n == 1 ) {
722
- incx = 1 ;
723
- incy = 1 ;
724
- }
725
- #if !defined(C10_MOBILE)
726
- if (incx == 1 && incy == 1 ) {
727
- return blas_impl::fp16_dot (n, x, incx, y, incy);
728
- }
729
- #endif // !defined(C10_MOBILE)
730
- return blas_impl::dot_naive (n, x, incx, y, incy, std::multiplies<float >{});
731
- }
732
-
733
- template <>
734
- BFloat16 dot_impl (int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
735
- if (n == 1 ) {
736
- incx = 1 ;
737
- incy = 1 ;
738
- }
739
- #if !defined(C10_MOBILE)
740
- if (incx == 1 && incy == 1 ) {
741
- return blas_impl::bf16_dot (n, x, incx, y, incy);
742
- }
743
- #endif // !defined(C10_MOBILE)
744
- return blas_impl::dot_naive (n, x, incx, y, incy, std::multiplies<float >{});
745
- }
746
-
747
698
namespace {
748
699
template <typename scalar_t >
749
700
struct vdot_op {
@@ -770,7 +721,7 @@ scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y
770
721
#endif
771
722
}
772
723
773
- // Skip reinstantiating the explicitly specialized types `float`, `double`, `half` & `bfloat16`.
724
+ // Skip reinstantiating the explicitly specialized types `float` and `double`.
774
725
#define INSTANTIATE_DOT_IMPL (scalar_t ) \
775
726
template scalar_t dot_impl<scalar_t >( \
776
727
int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
@@ -779,6 +730,8 @@ INSTANTIATE_DOT_IMPL(int8_t)
779
730
INSTANTIATE_DOT_IMPL (int16_t )
780
731
INSTANTIATE_DOT_IMPL (int )
781
732
INSTANTIATE_DOT_IMPL (int64_t )
733
+ INSTANTIATE_DOT_IMPL (c10::Half)
734
+ INSTANTIATE_DOT_IMPL (c10::BFloat16)
782
735
783
736
#define INSTANTIATE_VDOT_IMPL (scalar_t ) \
784
737
template scalar_t vdot_impl<scalar_t >( \
0 commit comments