[MPS] Gather sliced inputs to batch norm (#134121) · pytorch/pytorch@eba4d08

Commit eba4d08

pytorchbot and hvaara authored
[MPS] Gather sliced inputs to batch norm (#134121)
[MPS] Gather sliced inputs to batch norm (#133610)

This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide if gathering is necessary. It's not the most efficient way to solve this issue, but it ensures correctness for sliced inputs.

### Performance impact

#### With fix

```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 282 usec per loop

python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 448 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 705 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 1.11 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.16 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 11.7 msec per loop
```

#### Without fix

```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 284 usec per loop

python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 265 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 715 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 675 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.19 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 7.13 msec per loop
```

Please feel free to push back or request changes.

Fixes #133520

Pull Request resolved: #133610
Approved by: https://github.com/malfet

(cherry picked from commit 43f78bf)

Co-authored-by: Roy Hvaara <roy@lightyear.no>
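For background on why a contiguity check alone was insufficient here: slicing along the batch dimension yields a tensor that still reports as contiguous, yet it no longer starts at the beginning of its storage, so a fast path keyed only on `is_contiguous()` can end up reading from the wrong region of the underlying buffer. A minimal illustration (not part of the change itself):

```
import torch

x = torch.randn(100, 100, 35, 45)
y = x[5:]  # slice off the first five batch entries

print(y.is_contiguous())   # True: the strides still match a dense layout
print(y.storage_offset())  # 787500 (= 5 * 100 * 35 * 45): the view starts mid-buffer
```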
1 parent 2213c07 · commit eba4d08

File tree

2 files changed (+14, -7 lines)


aten/src/ATen/native/mps/operations/Normalization.mm

Lines changed: 1 addition & 7 deletions

@@ -153,12 +153,6 @@ static void get_shapes(MPSShape* input_shape_readonly,
   else
     channelsDim = num_input_dims - 1;
 
-  bool executeGatherOp = true;
-  if (self.is_contiguous(memory_format)) {
-    memory_format = MemoryFormat::Contiguous;
-    executeGatherOp = false;
-  }
-
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape);
     MPSGraphTensor* weightTensor = nil;
@@ -302,7 +296,7 @@ Check if running mean exists (maybe do this check before making graph)
     newCachedGraph->runningVarInplaceUpdate_ = runningVarInplaceUpdate;
   });
 
-  auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, executeGatherOp);
+  auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
   auto weightPlaceholder = Placeholder();
   if (has_weight)
     weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape);
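With the flag gone, the gather decision rests entirely with the Placeholder-side logic referenced in the commit message. As a rough Python sketch of the kind of condition involved (an illustration only; the authoritative check lives in OperationUtils.mm, and `needs_gather` is a hypothetical name):

```
import torch

def needs_gather(t: torch.Tensor) -> bool:
    # Sketch only: non-contiguous layouts, and views that start partway
    # into their storage, must be materialized (gathered) into a dense,
    # offset-free buffer before an MPS graph can consume them.
    return not t.is_contiguous() or t.storage_offset() != 0

x = torch.randn(100, 100, 35, 45)
print(needs_gather(x))                  # False: dense tensor, offset 0
print(needs_gather(x[5:]))              # True: contiguous view, nonzero offset
print(needs_gather(x.transpose(0, 1)))  # True: non-contiguous view
```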

test/test_mps.py

Lines changed: 13 additions & 0 deletions

@@ -2547,6 +2547,19 @@ def test_batch_norm_backward(self):
         # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
         outputs.sum().backward()
 
+    # Regression test for https://github.com/pytorch/pytorch/issues/133520
+    def test_batch_norm_slices(self):
+        bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
+        bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
+
+        x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
+        x_mps = x_cpu.to('mps')
+
+        res_cpu = bn_cpu(x_cpu[5:])
+        res_mps = bn_mps(x_mps[5:])
+
+        self.assertEqual(res_cpu, res_mps)
+
     def test_layer_norm_backward(self):
         inputs = torch.rand(4, 4, device="mps", requires_grad=True)
         x = torch.nn.LayerNorm(4).to("mps")
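To exercise just this regression test locally (assuming a machine with an MPS device; the `-k` filter is the standard unittest one, which PyTorch's test runner passes through):

```
python test/test_mps.py -k test_batch_norm_slices
```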
