[MPSInductor] Cast halfs to floats (#151246) · pytorch/pytorch@46ce8f7
Commit 46ce8f7

malfet authored and pytorchmergebot committed
[MPSInductor] Cast halfs to floats (#151246)
To avoid accuracy issues when small reductions are unrolled, cast half to float during the `load` op, as `op_math_t<half>` is indeed float. This fixes `test_unroll_small_reduction` for reduced-precision types.

Pull Request resolved: #151246
Approved by: https://github.com/dcci
ghstack dependencies: #151224
1 parent 0a6e1d6 commit 46ce8f7
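
For context, a minimal sketch of the accuracy issue the cast avoids: when every partial sum of a reduction is rounded back to half precision, the result drifts from what a float32 accumulator produces. The shapes and values below are illustrative, not taken from the PR.

import torch

# 0.1 rounds to 0.0999755859375 in float16; 2048 of them sum to exactly 204.75.
vals = torch.full((2048,), 0.1, dtype=torch.float16)

acc16 = torch.tensor(0.0, dtype=torch.float16)
for v in vals:
    acc16 = acc16 + v  # every partial sum is rounded back to float16

acc32 = vals.float().sum()  # upcast once, accumulate in float32 (exact here)

# The float16 accumulator typically drifts visibly from 204.75.
print(acc16.item(), acc32.item())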

File tree: 2 files changed, +9 -1 lines changed


test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ def fn(a):
     "test_sum_int",
     "test_sum_keepdims",
     "test_tanh",
+    "test_unroll_small_reduction",
     "test_vectorized_ops_masked",
     "test_var_mean_tile_reduction_True",
     "test_view_as_complex",

torch/_inductor/codegen/mps.py

Lines changed: 8 additions & 1 deletion
@@ -487,8 +487,15 @@ def load(self, name: str, index: sympy.Expr) -> CSEVariable:
         """Codegen a load from an InputBuffer"""
         var = self.args.input(name)
         index = self.prepare_indexing(index)
+        dtype = V.graph.get_dtype(name)
         line = f"{var}[{self.index_to_str(index)}]"
-        return self.cse.generate(self.loads, line, dtype=V.graph.get_dtype(name))
+        if dtype in [torch.float16, torch.bfloat16]:
+            # TODO(NS): Figure out the right balance between optype casts
+            # op_math_t for half-precision floats should be float32
+            # Otherwise it can lead to correctness issues with eager
+            line = f"static_cast<float>({line})"
+            dtype = torch.float32
+        return self.cse.generate(self.loads, line, dtype=dtype)

     def store(
         self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
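
The effect on the emitted Metal source can be mirrored in isolation. A minimal sketch, assuming a float16 input buffer named in_ptr0 and index x0 (both names illustrative, and emit_load is a stand-in, not the actual MPSInductor method):

import torch

def emit_load(var: str, index: str, dtype: torch.dtype) -> tuple[str, torch.dtype]:
    # Mirror of the codegen decision above: wrap half-precision loads in a
    # cast so downstream arithmetic happens in float32, matching op_math_t.
    line = f"{var}[{index}]"
    if dtype in [torch.float16, torch.bfloat16]:
        line = f"static_cast<float>({line})"
        dtype = torch.float32
    return line, dtype

print(emit_load("in_ptr0", "x0", torch.float16))
# -> ('static_cast<float>(in_ptr0[x0])', torch.float32)
print(emit_load("in_ptr0", "x0", torch.float32))
# -> ('in_ptr0[x0]', torch.float32)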
