10000 [MPS] Fix float64 scalar tensor handling (#153582) · pytorch/pytorch@d5ddc5a · GitHub
[go: up one dir, main page]

Skip to content

Commit d5ddc5a

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Fix float64 scalar tensor handling (#153582)
Current implementation causes silent correction problem with torch.compile when someone tries to `torch.compile` function where one of the arguments is say `np.exp(.3)`, which will be represented as torch.float64 scalar tensor Add regssion test for this behavior Pull Request resolved: #153582 Approved by: https://github.com/dcci
1 parent 3e8bda4 commit d5ddc5a

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,12 @@ static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigne
467467
if (C10_UNLIKELY(t.device().type() == kCPU)) {
468468
if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
469469
TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
470+
// MPS does not support doubles, silently downcast CPU scalar to float
471+
if (C10_UNLIKELY(t.scalar_type() == kDouble)) {
472+
auto val = static_cast<float>(*reinterpret_cast<const double*>(t.const_data_ptr()));
473+
[encoder setBytes:&val length:sizeof(val) atIndex:idx];
474+
return;
475+
}
470476
[encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx];
471477
} else {
472478
TORCH_CHECK(false, "Passed CPU tensor to MPS op");

test/test_mps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12317,6 +12317,9 @@ def test_metal_arange_with_arg(self, start=3.14, step=.5):
1231712317
def test_metal_arange_with_arg_and_scalar_tensor(self):
1231812318
self.test_metal_arange_with_arg(step=torch.tensor(.5))
1231912319

12320+
def test_metal_arange_with_arg_and_scalar_tensor_float64(self):
12321+
self.test_metal_arange_with_arg(step=torch.tensor(.5, dtype=torch.float64))
12322+
1232012323
def test_metal_arange_with_arg_and_cast(self):
1232112324
x = torch.zeros(12, device="mps", dtype=torch.half)
1232212325
y = torch.zeros(12, device="mps", dtype=torch.half)

0 commit comments

Comments
 (0)
0