Revert "[float16]: Fast path for torch.dot with float16/bfloat16 (#152799)" · pytorch/pytorch@fdadda2

Commit fdadda2

Revert "[float16]: Fast path for torch.dot with float16/bfloat16 (#152799)"
This reverts commit d57bf53. Reverted #152799 on behalf of https://github.com/malfet because it broke C10_MOBILE builds; it is not clear why this was not surfaced on pull. See https://hud.pytorch.org/hud/pytorch/pytorch/a766c1d117bbdf348d1a9d02a514651b3e05b1e4/1?per_page=50&name_filter=lightweight&mergeEphemeralLF=true (comment on #152799).
1 parent a766c1d commit fdadda2
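For context, code gated behind #if !defined(C10_MOBILE) is compiled out of the lightweight mobile builds, so a change can pass regular desktop CI yet still break the C10_MOBILE configuration. A hypothetical sketch of that general failure mode (illustrative only; per the note above, the actual cause was not diagnosed at revert time):

#include <cstdint>

// Hypothetical example, not the actual breakage here: a declaration that
// exists only in non-mobile builds but is referenced unconditionally
// compiles on desktop and fails once C10_MOBILE removes it.
#if !defined(C10_MOBILE)
float fast_dot(const float* x, const float* y, int64_t n);
#endif

float call_site(const float* x, const float* y, int64_t n) {
  return fast_dot(x, y, n); // error in C10_MOBILE builds: 'fast_dot' was not declared
}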

File tree: 3 files changed, +3 -79 lines

aten/src/ATen/native/BlasKernel.cpp

Lines changed: 3 additions & 50 deletions
@@ -90,8 +90,6 @@ namespace at::native {
 #if !defined(C10_MOBILE)
 DEFINE_DISPATCH(fp16_gemv_trans_stub);
 DEFINE_DISPATCH(bf16_gemv_trans_stub);
-DEFINE_DISPATCH(fp16_dot_stub);
-DEFINE_DISPATCH(bf16_dot_stub);
 #endif // !defined(C10_MOBILE)
 
 namespace blas_impl {
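The revert touches three files because an ATen CPU dispatch stub has three pieces, all visible in this diff. Roughly, using the identifiers being removed (a simplified sketch; the real macros live in ATen's DispatchStub machinery):

// Header (ReducedPrecisionFloatGemvFastPathKernel.h): declare the stub's
// function-pointer type and the stub itself.
using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)

// BlasKernel.cpp: define the stub once for the library.
DEFINE_DISPATCH(fp16_dot_stub);

// Kernel file compiled per CPU capability: register the implementation.
REGISTER_DISPATCH(fp16_dot_stub, &fp16_dot)

// Callers then dispatch through the stub, as in the deleted wrapper below:
//   fp16_dot_stub(kCPU, n, x, incx, y, incy);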
@@ -122,15 +120,6 @@ void fp16_gemv_trans(
   fp16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
 }
 
-static float fp16_dot(
-    const int64_t n,
-    const Half* x,
-    const int64_t incx,
-    const Half* y,
-    const int64_t incy) {
-  return fp16_dot_stub(kCPU, n, x, incx, y, incy);
-}
-
 #endif // !defined(C10_MOBILE)
 
 #if defined(__aarch64__) && !defined(C10_MOBILE)
@@ -395,16 +384,6 @@ void gemv_fast_path<at::BFloat16>(
       y,
       *incy);
 }
-
-static float bf16_dot(
-    const int64_t n,
-    const BFloat16* x,
-    const int64_t incx,
-    const BFloat16* y,
-    const int64_t incy) {
-  return bf16_dot_stub(kCPU, n, x, incx, y, incy);
-}
-
 #if !defined(__aarch64__)
 // Currently, only fp16_gemv_trans is built for non-aarch64.
 template <>
@@ -716,34 +695,6 @@ c10::complex<float> dot_impl(int64_t n, const c10::complex<float>* x, int64_t in
   return dot_impl_floating(n, x, incx, y, incy);
 }
 
-template <>
-Half dot_impl(int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
-  if (n == 1) {
-    incx = 1;
-    incy = 1;
-  }
-#if !defined(C10_MOBILE)
-  if (incx == 1 && incy == 1) {
-    return blas_impl::fp16_dot(n, x, incx, y, incy);
-  }
-#endif // !defined(C10_MOBILE)
-  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
-}
-
-template <>
-BFloat16 dot_impl(int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
-  if (n == 1) {
-    incx = 1;
-    incy = 1;
-  }
-#if !defined(C10_MOBILE)
-  if (incx == 1 && incy == 1) {
-    return blas_impl::bf16_dot(n, x, incx, y, incy);
-  }
-#endif // !defined(C10_MOBILE)
-  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
-}
-
 namespace {
 template <typename scalar_t>
 struct vdot_op {
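With the two specializations above removed, Half and BFloat16 dots fall back to the generic dot_impl path, which bottoms out in blas_impl::dot_naive with std::multiplies<float>. A rough sketch of that fallback's shape, assuming dot_naive walks both strides and accumulates in fp32 (the functor's float arguments force the widening); this is not the actual implementation:

#include <cstdint>
#include <functional>

// Sketch of a dot_naive-style loop: widen each element to float, combine
// with the supplied functor, and keep the running sum in fp32 so that
// reduced-precision inputs do not lose accuracy during accumulation.
template <typename scalar_t, typename Functor>
scalar_t dot_naive_sketch(int64_t n, const scalar_t* x, int64_t incx,
                          const scalar_t* y, int64_t incy, Functor op) {
  float sum = 0.f;
  for (int64_t i = 0; i < n; ++i) {
    sum += op(static_cast<float>(x[i * incx]), static_cast<float>(y[i * incy]));
  }
  return static_cast<scalar_t>(sum);
}

// e.g. dot_naive_sketch(n, x, 1, y, 1, std::multiplies<float>{});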
@@ -770,7 +721,7 @@ scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y
 #endif
 }
 
-// Skip reinstantiating the explicitly specialized types `float`, `double`, `half` & `bfloat16`.
+// Skip reinstantiating the explicitly specialized types `float` and `double`.
 #define INSTANTIATE_DOT_IMPL(scalar_t) \
   template scalar_t dot_impl<scalar_t>( \
       int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
@@ -779,6 +730,8 @@ INSTANTIATE_DOT_IMPL(int8_t)
 INSTANTIATE_DOT_IMPL(int16_t)
 INSTANTIATE_DOT_IMPL(int)
 INSTANTIATE_DOT_IMPL(int64_t)
+INSTANTIATE_DOT_IMPL(c10::Half)
+INSTANTIATE_DOT_IMPL(c10::BFloat16)
 
 #define INSTANTIATE_VDOT_IMPL(scalar_t) \
   template scalar_t vdot_impl<scalar_t>( \
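After the revert, the reduced-precision instantiations come from the generic template again; per the macro defined above, INSTANTIATE_DOT_IMPL(c10::Half) expands to:

template c10::Half dot_impl<c10::Half>(
    int64_t n, const c10::Half* x, int64_t incx, const c10::Half* y, int64_t incy);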

aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp

Lines changed: 0 additions & 23 deletions
@@ -475,35 +475,12 @@ void bf16_gemv_trans(
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
   return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
 }
-
-float fp16_dot(
-    const int64_t n,
-    const at::Half* x,
-    const int64_t incx,
-    const at::Half* y,
-    const int64_t incy) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
-  return fp16_dot_with_fp32_arith(x, y, n);
-}
-
-float bf16_dot(
-    const int64_t n,
-    const at::BFloat16* x,
-    const int64_t incx,
-    const at::BFloat16* y,
-    const int64_t incy) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
-  return bf16_dot_with_fp32_arith(x, y, n);
-}
-
 #endif // !defined(C10_MOBILE)
 } // namespace CPU_CAPABILITY
 
 #if !defined(C10_MOBILE)
 REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
 REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
-REGISTER_DISPATCH(fp16_dot_stub, &fp16_dot)
-REGISTER_DISPATCH(bf16_dot_stub, &bf16_dot)
 #endif //!defined(C10_MOBILE)
 
 } // namespace at::native
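Note that the deleted fp16_dot and bf16_dot wrappers only debug-asserted incx == 1 && incy == 1 rather than handling general strides: fp16_dot_with_fp32_arith and bf16_dot_with_fp32_arith take no stride arguments, and the dot_impl specializations deleted from BlasKernel.cpp took this path only when both increments were 1.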

aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h

Lines changed: 0 additions & 6 deletions
@@ -13,12 +13,6 @@ DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
 using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
 DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
 
-using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
-DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)
-
-using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t);
-DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub)
-
 inline namespace CPU_CAPABILITY {
 float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
 float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
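The fp16_dot_with_fp32_arith and bf16_dot_with_fp32_arith declarations survive the revert: only the dot dispatch stubs are removed, and, judging by names such as bf16_gemv_trans_fp32_arith_by_dot_products in the kernel file above, the remaining gemv fast paths are still built on these dot-product primitives.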
