aten/src/ATen/native/mps/operations — 2 files changed: +23 −2 lines

@@ -512,7 +512,28 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
     return;
   }
 
-  TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source.");
+  bool use_deterministic_algorithm = globalContext().deterministicAlgorithms();
+
+  // TODO: Do not use deterministic algorithm for long/complex but rather implement it as Metal shader
+  use_deterministic_algorithm |= source.scalar_type() == ScalarType::Long;
+  use_deterministic_algorithm |= c10::isComplexType(source.scalar_type());
+
+  if (use_deterministic_algorithm) {
+    if (!result.is_same(self)) {
+      result.copy_(self);
+    }
+    torch::List<std::optional<Tensor>> indices;
+    indices.reserve(dim + 1);
+    for (const auto i : c10::irange(dim)) {
+      indices.emplace_back();
+    }
+    indices.emplace_back(index.to(at::kLong));
+    const Tensor result_ = (result.dim() == 0) ? result.view(1) : result;
+    const Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
+    result_.index_put_(indices, source_.mul(alpha), true);
+    return;
+  }
+
   auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int;
 
   struct CachedGraph : public MPSCachedGraph {
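For context, the fallback above rewrites index_add as an in-place index_put_ with accumulation. Below is a minimal Python sketch of that equivalence, not part of the diff; dim is fixed to 0 for brevity, whereas the C++ path inserts an empty optional index for every leading dimension when dim > 0.

import torch

# Emulate x.index_add_(0, idx, src, alpha=2) via index_put_ with accumulation.
x = torch.zeros(5, 3, dtype=torch.int64)
src = torch.ones(4, 3, dtype=torch.int64)
idx = torch.tensor([0, 2, 2, 4])  # duplicate index: contributions must sum

via_index_add = x.clone().index_add_(0, idx, src, alpha=2)

via_index_put = x.clone()
# accumulate=True sums colliding writes, matching index_add semantics.
via_index_put.index_put_((idx,), src.mul(2), accumulate=True)

assert torch.equal(via_index_add, via_index_put)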
@@ -74,6 +74,7 @@ def mps_ops_modifier(
7474 "H" ,
7575 "hsplit" ,
7676 "imag" ,
77+ "index_add" ,
7778 "index_copy" ,
7879 "index_select" ,
7980 "index_put" ,
@@ -419,7 +420,6 @@ def mps_ops_modifier(
         ],
         # Unsupported dtypes
         "histc": [torch.float16, torch.bfloat16],
-        "index_add": [torch.int64],
         # GEMM on MPS is not supported for integral types
         "nn.functional.linear": [
             torch.int16,
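With "index_add" added to the op list in the first hunk and its int64 unsupported-dtype entry removed, the op is expected to run for long tensors on MPS, both on the default path and under deterministic mode. A hypothetical smoke test, not taken from the PR:

import torch

if torch.backends.mps.is_available():
    x = torch.zeros(4, dtype=torch.int64, device="mps")
    idx = torch.tensor([0, 1, 1], device="mps")
    src = torch.tensor([1, 2, 3], dtype=torch.int64, device="mps")

    # Per the C++ change above, int64 sources always take the index_put_
    # fallback; deterministic mode forces it for other dtypes as well.
    out_default = x.clone().index_add_(0, idx, src)

    torch.use_deterministic_algorithms(True)
    out_det = x.clone().index_add_(0, idx, src)
    torch.use_deterministic_algorithms(False)

    assert torch.equal(out_default.cpu(), torch.tensor([1, 5, 0, 0]))
    assert torch.equal(out_default, out_det)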