[MPS] Gather sliced inputs to batch norm by hvaara · Pull Request #133610 · pytorch/pytorch · GitHub

[MPS] Gather sliced inputs to batch norm #133610


Closed · hvaara wants to merge 1 commit from the sliced-batch-norm branch

Conversation

@hvaara (Contributor) commented Aug 15, 2024

This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 (`if (needsGather(src) && gatherTensorData)`) to decide if gathering is necessary.

It's not the most efficient way to solve this issue, but it ensures correctness for sliced inputs.
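For context, a minimal sketch of the kind of mismatch reported in #133520 (the shapes mirror the benchmarks below; the tolerance is illustrative):

```python
import torch
import torch.nn as nn

bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')

x_cpu = torch.randn(100, 100, 35, 45)
x_mps = x_cpu.to('mps')

# Slicing the batch dimension produces a view that starts at a nonzero
# storage offset; without a gather, the MPS kernel read the wrong region.
y_cpu = bn_cpu(x_cpu[5:])
y_mps = bn_mps(x_mps[5:])

# With the fix, the MPS result matches CPU within (illustrative) tolerance.
print(torch.allclose(y_cpu, y_mps.cpu(), atol=1e-5))
```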

### Performance impact

#### With fix

```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 282 usec per loop

python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 448 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 705 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 1.11 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.16 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 11.7 msec per loop
```

#### Without fix

```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 284 usec per loop

python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 265 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 715 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 675 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.19 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 7.13 msec per loop
```

Please feel free to push back or request changes.

Fixes #133520

pytorch-bot (bot) commented Aug 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133610

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4b16953 with merge base 546c53b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the `release notes: mps` (Release notes category) label Aug 15, 2024
Comment on lines -156 to -161 (Normalization.mm):

```cpp
bool executeGatherOp = true;
if (self.is_contiguous(memory_format)) {
  memory_format = MemoryFormat::Contiguous;
  executeGatherOp = false;
}
```
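The removed check looked only at contiguity. Roughly, a slice along the batch dimension stays contiguous but gains a nonzero storage offset, which is what `needsGather` also accounts for. An illustrative sketch (the printed offset is specific to these example shapes):

```python
import torch

x = torch.randn(100, 100, 35, 45)
s = x[5:]                    # slice along the batch dimension

print(s.is_contiguous())     # True  -> the removed check skipped the gather
print(s.storage_offset())    # 787500 (nonzero): the view does not start at
                             # the base of the buffer, so without a gather
                             # the MPS kernel would read the wrong region
```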

@hvaara (Contributor, Author) Aug 15, 2024

Does my assumption hold that the decision to gather is sufficiently handled in `if (needsGather(src) && gatherTensorData)`?

@qqaatw (Collaborator)

Considering #94760, the issue is tricky, I would say. The test expects the input layout to be preserved in the output. Let's look at CI signals.
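A minimal sketch of what such a layout-preservation check could look like (hypothetical; whether #94760's test is exactly this channels-last check is an assumption):

```python
import torch
import torch.nn as nn

# Hypothetical expectation: a channels_last input yields a channels_last output.
bn = nn.BatchNorm2d(3, device='mps')
x = torch.randn(2, 3, 4, 5, device='mps').to(memory_format=torch.channels_last)
y = bn(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # expected: True
```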

@hvaara (Contributor, Author)

Looks like everything passed in CI.

Comment on lines +2542 to +2545:

```python
bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')

x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
```
@hvaara (Contributor, Author)

Are these values fine, or should I consider lowering them to cover the minimal case? The timeit results suggest a single loop is sub-millisecond.

@malfet (Contributor)

If it's sub-millisecond this is fine.

@hvaara (Contributor, Author) commented Aug 15, 2024

@pytorchbot label "module: mps" "ciflow/mps"

pytorch-bot (bot) commented Aug 15, 2024

Can't add following labels to PR: ciflow/mps. Please ping one of the reviewers for help.

@hvaara (Contributor, Author) commented Aug 15, 2024

cc @bauerwer for visibility

@bauerwer commented Aug 16, 2024

@hvaara Quick check passed on my system. I integrated the Normalization.mm changes into my torch main, built locally, and was able to run flux without noise in the image. I ran a few other checkpoints and workflows as well without problems.

(Screenshots attached: 2024-08-16 02:03 and 02:17.)

@bdhirsh added the `triaged` label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 16, 2024
@qqaatw (Collaborator) left a comment

Following the fix here, one situation always concerns me: what if `output` is an alias of `self`, i.e. they share the same storage? I think to address it thoroughly, we need #128393.
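To make the concern concrete, a small sketch of what "sharing the same storage" means (illustrative; the tensors here are arbitrary):

```python
import torch

x = torch.randn(4, 4)
y = x.view(16)  # an alias: same underlying storage, different shape

# Both views point at the same buffer, so writing the normalized result
# into one while reading the gathered input from the other could clobber data.
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())  # True
y[0] = 0.0
print(x[0, 0])  # tensor(0.) -- the write through y is visible through x
```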


@qqaatw added the `ciflow/mps` (Run MPS tests (subset of trunk)) and `ciflow/trunk` (Trigger trunk jobs on your pull request) labels Aug 17, 2024
@hvaara (Contributor, Author) commented Aug 20, 2024

@kulinseth @malfet @qqaatw are any of you available for a review? 🙏

@malfet (Contributor) left a comment

Looks good to me; let's cherry-pick it into 2.4.1. But I think the problem (and therefore the fix) should be more generic than this one, as all other ops are likely similarly affected.


@malfet added this to the 2.4.1 milestone Aug 20, 2024
@malfet (Contributor) commented Aug 20, 2024

@pytorchbot merge -f "MPS tests are green"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort, and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced debugging: check the merge workflow status here.

@hvaara (Contributor, Author) commented Aug 20, 2024

Thanks for the review! Following up in #133520 to address comments and future work.

@hvaara deleted the sliced-batch-norm branch August 20, 2024 18:30
@atalman (Contributor) commented Aug 21, 2024

@pytorchbot cherry-pick --onto release/2.4 -c critical --fixes #133520

pytorchbot pushed a commit that referenced this pull request Aug 21, 2024:
[MPS] Gather sliced inputs to batch norm (#133610) (cherry picked from commit 43f78bf)
@pytorchbot (Collaborator)

Cherry picking #133610: the cherry-pick PR is at #134121 and it is linked with issue #133520.

atalman pushed a commit that referenced this pull request Aug 22, 2024:
[MPS] Gather sliced inputs to batch norm (#133610) (cherry picked from commit 43f78bf; co-authored-by: Roy Hvaara <roy@lightyear.no>)
pytorch-bot pushed a commit that referenced this pull request Sep 13, 2024:
[MPS] Gather sliced inputs to batch norm (#133610)
Labels

ciflow/mps (Run MPS tests (subset of trunk)) · ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · open source · release notes: mps (Release notes category) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Development

Successfully merging this pull request may close these issues:

[MPS] Incorrect result from batch norm with sliced inputs (#133520)

8 participants