-
Notifications
You must be signed in to change notification settings - Fork 24.2k
[MPS] Gather sliced inputs to batch norm #133610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133610
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 4b16953 with merge base 546c53b ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
bool executeGatherOp = true; | ||
if (self.is_contiguous(memory_format)) { | ||
memory_format = MemoryFormat::Contiguous; | ||
executeGatherOp = false; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does my assumption that the decision to gather is sufficiently handled in
if (needsGather(src) && gatherTensorData) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Considering #94760, the issue is tricky I would say. The test expects input layout is preserved in output. Let's look at CI signals.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like everything passed in CI.
bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu') | ||
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps') | ||
|
||
x_cpu = torch.randn(100, 100, 35, 45).to('cpu') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these values fine, or should I consider lowering them to cover the minimal case? timeit results suggests a single loop is sub-millisecond.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's sub-millisecond this is fine
@pytorchbot label "module: mps" "ciflow/mps" |
Can't add following labels to PR: ciflow/mps. Please ping one of the reviewers for help. |
cc @bauerwer for visibility |
@hvaara Quick check passed on my system. I integrated the Normalization.mm changes into my torch main, built locally and was able to run flux without noise in the image. I ran a few other checkpoints and workflows as well without problems. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the fix here, one situation is always my concern: what if output is an alias of self, i.e. they share the same storage. I think to address it thoroughly, we need #128393.
bool executeGatherOp = true; | ||
if (self.is_contiguous(memory_format)) { | ||
memory_format = MemoryFormat::Contiguous; | ||
executeGatherOp = false; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Considering #94760, the issue is tricky I would say. The test expects input layout is preserved in output. Let's look at CI signals.
@kulinseth @malfet @qqaatw are either of you available for a review? 🙏 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, let's cherry-pick it into 2.4.1, but I think the problem (and therefore the fix, should be more generic that this one, as all other ops are likely similarly affected
bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu') | ||
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps') | ||
|
||
x_cpu = torch.randn(100, 100, 35, 45).to('cpu') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's sub-millisecond this is fine
@pytorchbot merge -f "MPS tests are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Thanks for the review! Following up in #133520 to address comments and future work. |
@pytorchbot cherry-pick --onto release/2.4 -c critical --fixes #133520 |
This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide if gathering is necessary. It's not the most efficient way to solve this issue, but it assures correctness for sliced inputs. ### Performance impact #### With fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 282 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 448 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 705 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 1.11 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.16 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 l A93C oops, best of 5: 11.7 msec per loop ``` #### Without fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 284 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 265 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 715 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 675 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.19 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 7.13 msec per loop ``` Please feel free to push back or request changes. Fixes #133520 Pull Request resolved: #133610 Approved by: https://github.com/malfet (cherry picked from commit 43f78bf)
Cherry picking #133610The cherry pick PR is at #134121 and it is linked with issue #133520. The following tracker issues are updated: Details for Dev Infra teamRaised by workflow job |
[MPS] Gather sliced inputs to batch norm (#133610) This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide if gathering is necessary. It's not the most efficient way to solve this issue, but it assures correctness for sliced inputs. ### Performance impact #### With fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 282 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 448 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 705 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 1.11 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.16 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 11.7 msec per loop ``` #### Without fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 284 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 265 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 715 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 675 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.19 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 7.13 msec per loop ``` Please feel free to push back or request changes. Fixes #133520 Pull Request resolved: #133610 Approved by: https://github.com/malfet (cherry picked from commit 43f78bf) Co-authored-by: Roy Hvaara <roy@lightyear.no>
This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide if gathering is necessary. It's not the most efficient way to solve this issue, but it assures correctness for sliced inputs. ### Performance impact #### With fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 282 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 448 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 705 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 1.11 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.16 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 11.7 msec per loop ``` #### Without fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 284 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 265 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 715 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 675 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.19 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 7.13 msec per loop ``` Please feel free to push back or request changes. Fixes #133520 Pull Request resolved: #133610 Approved by: https://github.com/malfet
This PR removes the
executeGatherOp
flag from batch norm in favor of relying on the logic inpytorch/aten/src/ATen/native/mps/OperationUtils.mm
Line 372 in 4aa66f6
It's not the most efficient way to solve this issue, but it assures correctness for sliced inputs.
Performance impact
With fix
Without fix
Please feel free to push back or request changes.
Fixes #133520