8000 [MPSInductor] Implement Welford reduction (#146703) · pytorch/pytorch@2328dcc · GitHub 10000
[go: up one dir, main page]

Skip to content

Commit 2328dcc

Browse files
malfetpytorchmergebot
authored andcommitted
[MPSInductor] Implement Welford reduction (#146703)
Still work in progress, though fallback works as expected, but custom shader is not Pull Request resolved: #146703 Approved by: https://github.com/jansel, https://github.com/dcci
1 parent 69feef5 commit 2328dcc

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

c10/metal/reduction_utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ opmath_t<T> threadgroup_prod(threadgroup T* data, unsigned size) {
2929
return rc;
3030
}
3131

32+
template <typename T>
33+
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
34+
float m = data[0];
35+
float m2 = 0;
36+
for (unsigned idx = 1; idx < size; ++idx) {
37+
float delta = data[idx] - m;
38+
m += delta / (idx + 1);
39+
m2 += delta * (data[idx] - m);
40+
}
41+
return float2(m, m2);
42+
}
43+
3244
template <typename T>
3345
T threadgroup_max(threadgroup T* data, unsigned size) {
3446
// TODO: This should be moved to the callee

test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def fn(a):
164164
"test_inf",
165165
"test_isinf",
166166
"test_isinf2",
167+
"test_layer_norm",
167168
"test_lgamma",
168169
"test_linear_float64",
169170
"test_log_fp64",

torch/_inductor/codegen/mps.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.utils._sympy.value_ranges import ValueRanges
1313

1414
from ..utils import get_bounds_index_expr, get_kernel_metadata
15-
from ..virtualized import ops, V
15+
from ..virtualized import ops, OpsWrapper, V
1616
from .common import (
1717
CSEVariable,
1818
DeferredLine,
@@ -463,6 +463,16 @@ def reduction(
463463
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
464464
dtype=dtype,
465465
)
466+
if reduction_type == "welford_reduce":
467+
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
468+
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
469+
wf_res = self.cse.generate(
470+
self.body,
471+
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
472+
)
473+
return OpsWrapper._unwrap(
474+
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
475+
)
466476
raise NotImplementedError(reduction_type)
467477

468478
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:

0 commit comments

Comments
 (0)
0