E578 [MPS] Fix index_add for complex + int64 (#160926) · pytorch/pytorch@a44a0d3 · GitHub
[go: up one dir, main page]

Skip to content

Commit a44a0d3

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Fix index_add for complex + int64 (#160926)
By re-using deterministic algorithm from https://github.com/pytorch/pytorch/blob/bbc7c03e936ab0fce69dfda1fdf4798523a2fbf8/aten/src/ATen/native/cuda/Indexing.cu#L1106-L1113 Fixes #160845 Pull Request resolved: #160926 Approved by: https://github.com/manuelcandales ghstack dependencies: #160850, #160889
1 parent 2f0cba9 commit a44a0d3

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,28 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
512512
return;
513513
}
514514

515-
TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source.");
515+
bool use_deterministic_algorithm = globalContext().deterministicAlgorithms();
516+
517+
// TODO: Do not use deterministic algorithm for long/complex but rather implement it as Metal shader
518+
use_deterministic_algorithm |= source.scalar_type() == ScalarType::Long;
519+
use_deterministic_algorithm |= c10::isComplexType(source.scalar_type());
520+
521+
if (use_deterministic_algorithm) {
522+
if (!result.is_same(self)) {
523+
result.copy_(self);
524+
}
525+
torch::List<std::optional<Tensor>> indices;
526+
indices.reserve(dim + 1);
527+
for (const auto i : c10::irange(dim)) {
528+
indices.emplace_back();
529+
}
530+
indices.emplace_back(index.to(at::kLong));
531+
const Tensor result_ = (result.dim() == 0) ? result.view(1) : result;
532+
const Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
533+
result_.index_put_(indices, source_.mul(alpha), true);
534+
return;
535+
}
536+
516537
auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int;
517538

518539
struct CachedGraph : public MPSCachedGraph {

torch/testing/_internal/common_mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def mps_ops_modifier(
7474
"H",
7575
"hsplit",
7676
"imag",
77+
"index_add",
7778
"index_copy",
7879
"index_select",
7980
"index_put",
@@ -419,7 +420,6 @@ def mps_ops_modifier(
419420
],
420421
# Unsupported dtypes
421422
"histc": [torch.float16, torch.bfloat16],
422-
"index_add": [torch.int64],
423423
# GEMM on MPS is not supported for integral types
424424
"nn.functional.linear": [
425425
torch.int16,

0 commit comments

Comments
 (0)
0