Unbreak fp16 dot issues caused by #137917 (#139262) · pytorch/pytorch@3495ef7 · GitHub

Commit 3495ef7

swolchok authored and pytorchmergebot committed
Unbreak fp16 dot issues caused by #137917 (#139262)
See comment for explanation. In short, doing the fixup in float.

Pull Request resolved: #139262
Approved by: https://github.com/huydhn
1 parent 4e5f9af commit 3495ef7

File tree

1 file changed: +4 −23 lines changed

aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp

Lines changed: 4 additions & 23 deletions
@@ -218,22 +218,6 @@ dot_with_fp32_arith_main_loop_no_bfdot(
   return reduce(sum);
 }
 
-template <typename T>
-struct half_to_float16 {
-  using type = T;
-};
-
-
-#ifdef __aarch64__
-template <>
-struct half_to_float16<Half> {
-  using type = float16_t;
-};
-#endif
-
-template <typename T>
-using half_to_float16_t = typename half_to_float16<T>::type;
-
 static_assert(
     (vec::Vectorized<Half>::size() & (vec::Vectorized<Half>::size() - 1)) == 0,
     "Below code expects power-of-2 vector register size!");
@@ -258,13 +242,10 @@ static_assert(
                                                                          \
     /* Second-tier tail fixup: handle all workloads. */                  \
     for (int j = len_aligned_vec; j < len; ++j) {                        \
-      /* We use half_to_float16_t here because changing to Half was */   \
-      /* causing arithmetic to happen at fp16 precision, but the */      \
-      /* necessary behavior to pass python test/test_mps.py -k */        \
-      /* test_output_grad_match_nn_functional_linear_cpu_float16 is */   \
-      /* fp32. (I'm not sure exactly why this fixes it.) */              \
-      half_to_float16_t<std::decay_t<decltype(vec1[j])>> x1 = vec1[j];   \
-      half_to_float16_t<std::decay_t<decltype(vec2[j])>> x2 = vec2[j];   \
+      /* Attempting to use Half here caused multiple test failures; */   \
+      /* using float to unbreak. (Suspect we need a scalar FMA.) */      \
+      float x1 = vec1[j];                                                \
+      float x2 = vec2[j];                                                \
       reduced_sum += x1 * x2;                                            \
     }                                                                    \
     return reduced_sum
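
For context on why the tail fixup is done in float: the sketch below is a hypothetical standalone demo, not part of this commit, assuming a compiler with _Float16 support (e.g. Clang or GCC on aarch64). It contrasts an fp16-accumulated scalar tail with the float-accumulated tail the change above restores.

// fp16_tail_demo.cpp -- hypothetical demo, not from the commit.
// Assumes _Float16 support; build with e.g. clang++ fp16_tail_demo.cpp
#include <cstdio>

int main() {
  // A short tail (fewer elements than one vector register) of values
  // just above 1.0; products of these exceed what repeated fp16
  // rounding can track once the running sum grows past a few units.
  constexpr int kTail = 7;
  _Float16 vec1[kTail];
  _Float16 vec2[kTail];
  for (int j = 0; j < kTail; ++j) {
    vec1[j] = (_Float16)1.001f;  // rounds to nearest fp16, 1 + 1/1024
    vec2[j] = (_Float16)1.001f;
  }

  // Broken variant: every product and every addition rounds to fp16's
  // 10-bit mantissa, so low-order bits are dropped as the sum grows.
  _Float16 sum_fp16 = 0;
  for (int j = 0; j < kTail; ++j) {
    sum_fp16 += vec1[j] * vec2[j];
  }

  // Fixed variant (the shape of the code above): widen each operand
  // to float and do the multiply-accumulate at fp32 precision.
  float sum_fp32 = 0;
  for (int j = 0; j < kTail; ++j) {
    float x1 = vec1[j];
    float x2 = vec2[j];
    sum_fp32 += x1 * x2;
  }

  printf("fp16-accumulated tail: %f\n", (float)sum_fp16);  // ~7.0078
  printf("fp32-accumulated tail: %f\n", sum_fp32);         // ~7.0137
  return 0;
}

On this input the fp16 accumulation loses roughly 0.006 of the true sum after only seven elements, which is the kind of drift that can fail tolerance checks like test_output_grad_match_nn_functional_linear_cpu_float16.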

0 commit comments