8000 [MPSInductor] Implement `prod` reduction (#146396) · pytorch/pytorch@5d81bc3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5d81bc3

Browse files
malfetpytorchmergebot
authored andcommitted
[MPSInductor] Implement prod reduction (#146396)
Mostly reusing `sum` reduction logic Pull Request resolved: #146396 Approved by: https://github.com/dcci ghstack dependencies: #146369, #146370, #146380, #146389
1 parent bbe9534 commit 5d81bc3

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

c10/metal/reduction_utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@ opmath_t<T> threadgroup_sum(threadgroup T* data, unsigned size) {
1818
return rc;
1919
}
2020

21+
template <typename T>
22+
opmath_t<T> threadgroup_prod(threadgroup T* data, unsigned size) {
23+
opmath_t<T> rc = 1;
24+
// TODO: This should be moved to the callee
25+
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
26+
// TODO: Use `simd_shuffle_down`
27+
for (auto idx = 0; idx < size; ++idx) {
28+
rc *= data[idx];
29+
}
30+
return rc;
31+
}
32+
2133
template <typename T>
2234
T threadgroup_max(threadgroup T* data, unsigned size) {
2335
// TODO: This should be moved to the callee

test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def inc_(x):
154154
"test_min_max_reduction_nan",
155155
"test_nan_to_num",
156156
"test_pow2",
157+
"test_prod",
157158
"test_randint_int64_mod",
158159
"test_randn_generator",
159160
"test_remainder",

torch/_inductor/codegen/mps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,12 +428,12 @@ def reduction(
428428
"""
429429
)
430430
return acc
431-
if reduction_type == "sum":
431+
if reduction_type in ["prod", "sum"]:
432432
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
433433
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
434434
return self.cse.generate(
435435
self.body,
436-
f"c10::metal::threadgroup_sum({acc_buf}, {reduction_dim.numel})",
436+
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
437437
dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
438438
)
439439
if reduction_type in ["max", "min"]:

0 commit comments

Comments
 (0)
0