MPS: Datatype precedence in binary ops [WIP] by lhoenig · Pull Request #78319 · pytorch/pytorch

MPS: Datatype precedence in binary ops [WIP] #78319


Closed
wants to merge 23 commits

Conversation

lhoenig
Contributor
@lhoenig lhoenig commented May 25, 2022

Issue #78020 revealed that boolean Tensors on MPS don't behave as expected when subjected to arithmetic operations with other Tensors. In fact, this problem extends to some degree to all other dtypes. The underlying reason for the unexpected behavior is that there are currently no rules determining the dominant dtype in a binary operation. For example, when multiplying an int Tensor with a float Tensor, the output Tensor should be a float one, and the int Tensor needs to be cast before the binary operation takes place. Otherwise, wrong results or crashes can happen.

This PR adds these rules and the corresponding behavior, and generally aims to correct any discrepancies in the outputs of binary operations between the CPU and MPS backends (currently only add, sub, mul, div). The precedence list is modeled after the CPU backend's behavior. An extensive test is provided, which revealed additional issues; these are still in the process of being fixed.

The assumption made here is that neither the dimensionality nor the order of the two arguments matters for casting. Basically, the output dtype is determined only by the set of the two input dtypes, except for division, where the output is float32 regardless of the input dtypes (this is already taken care of before the Tensors enter binaryOpTensor, so no special case is needed there). All of these behaviors appear to be consistent with the CPU backend, which is what the test function compares against.
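
For reference, a minimal CPU-side sketch of the promotion behavior being matched here (torch.result_type and the float32 result of integer true division are standard PyTorch behavior, not new API from this PR):

import torch

a = torch.ones(3, dtype=torch.int32)
b = torch.ones(3, dtype=torch.float32)
print(torch.result_type(a, b))  # torch.float32 -- float dominates int
print((a * b).dtype)            # torch.float32
print((a / a).dtype)            # torch.float32 -- true division of integer Tensors yields float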

The PR code will evolve a bit more, but the approach can already be reviewed. The long if/else chain that determines the output dtype is not pretty at first sight, but it is a solution with minimal overhead.

Lastly, similar or identical logic might be needed in other places. This should be investigated. If that turns out to be the case it might be good to encapsulate the core logic introduced here, as it should be universal.

Fixes #78020

@facebook-github-bot
Contributor
facebook-github-bot commented May 25, 2022


✅ No Failures (0 Pending)

As of commit 4db40e8 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@lhoenig
Contributor Author
lhoenig commented May 27, 2022

I'm kind of stuck at two points now, although more than two problems remain to be fixed.

  1. In the Python test, I check that the order of the arguments should not matter, and I test different combinations of dimensionality of both inputs (a sketch of this kind of check is shown after this list). I use torch.full there to create an input of dimension 1 with length 20, filled with a value. These tensors are created differently from the usual ones, using [MPSGraph constantWithScalar:shape:dataType] in ConstantOps.mm. Apparently they behave differently under certain conditions, namely when cast, and the test revealed this. Bool, int8 and int16 are affected. See the long comments in the code changes and also my gist, where I reproduced the issue outside of PyTorch: https://gist.github.com/lhoenig/5759db313c1ba919f9baad52221c7917. I really don't know what I'm doing wrong.
  2. Another mysterious bug that seems to depend on the operations that ran before. Suddenly, in the middle of the test, torch.full((20,), False, device='mps') would return tensor([False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]), whereas in a fresh run it always works. It's completely unclear why this happens, but it could be related to the first issue.
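
For context, here is a minimal sketch (not the actual test code) of the kind of order/dimensionality check described in point 1, assuming an MPS-capable build:

import torch

a = torch.full((20,), 2, dtype=torch.int32, device='mps')  # dim-1 input created via torch.full
b = torch.tensor(3.0, device='mps')                        # dim-0 input
# The result should not depend on the argument order, and its contents and dtype
# should match what the CPU backend produces.
assert torch.equal((a * b).cpu(), (b * a).cpu())
assert (a * b).dtype == (a.cpu() * b.cpu()).dtype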

@lhoenig
Contributor Author
lhoenig commented May 29, 2022

About the second problem: I was able to isolate it outside of the tests. A minimal example is:

import torch
torch.tensor(True, device='mps')
print(torch.full((5,), False, device='mps'))

Output: tensor([ True, False, False, False, False], device='mps:0')

Output with some memory debugging enabled - it all looks correct to me:

Allocated small private heap of size 8.00 MB (#heaps: 1, free memory: 21.33 GB)
Allocated private buffer #1 with aligned size 256 bytes (requested size: 1 bytes, heap size: 7.98 MB, total allocated: 256 bytes)
Reusing private buffer #1 with aligned size 256 bytes (requested size: 5 bytes)
tensor([ True, False, False, False, False], device='mps:0')
Released buffer #1 of size 256 bytes (heap size: 8.00 MB, total allocated: 0 bytes)
Released heap of size 8.00 MB (free memory: 21.33 GB)

The issue happens when an MPS tensor is released immediately before another MPS tensor is allocated using torch.full and filled with a bool value. I have no idea why the old memory doesn't get overwritten in this case.

Here are more examples and experiments:

import torch as pt
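# (Presumably, each snippet below was run as a separate prefix before the final
#  pt.full call at the bottom; the "x will be ..." comments record what x looked
#  like in that particular run.)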

print(hex(pt.tensor(1, device='mps').data_ptr()))
print(hex(pt.tensor(True, device='mps').data_ptr()))
# x will be tensor([ True, False, False, False, False, False, False, False, False, False], device='mps:0')
# ==> First element corrupted in both cases!

print(hex(pt.tensor(0.1, device='mps').data_ptr()))
# x will be tensor([ True, False, False,  True, False, False, False, False, False, False], device='mps:0')
# ==> Corruption happens for any datatype, memory seems "mixed"

s = pt.tensor(1, device='mps')
print(hex(s.data_ptr()), s)
# x will not be corrupted
# ==> Memory is retained by binding to `s`, and no corruption happens!
#     This could mean that Tensors are not freed correctly.

s = pt.tensor(1, device='mps')
print(hex(s.data_ptr()), s)
del s
# x will be tensor([ True, False, False, False, False, False, False, False, False, False], device='mps:0')
# ==> Corruption again! Freeing scalar tensors causes corruption.

print(hex(pt.tensor(True).data_ptr()))
# x will not be corrupted
# ==> Corruption only happens via MPS scalar tensors

x = pt.full((10,), False, device="mps")  # this one gets corrupted
# x = pt.full((10,), 0, device="mps")    # no corruption (dtype != bool)
# x = pt.full((10,), 0.1, device="mps")  # no corruption (dtype != bool)
print(hex(x.data_ptr()), x)  # identical memory location as the scalar tensor!
if any(x):
    print('Tensor corrupted')

Before this PR, pt.full((10,), False, device="mps") would crash. This issue, too, points to one or more problems with [MPSGraph constantWithScalar], although I'm really not sure.

@lhoenig
Contributor Author
lhoenig commented May 31, 2022

@kulinseth Sorry to bother you - maybe you can take a look when you find the time? I may have found one or two bugs in MPSGraph related to Bool tensors, or I'm just using it wrong, no idea :)
To summarize, the two problems are:

  1. A constantWithScalar-created Tensor with value 1.0 (the dtype doesn't matter) can't be cast to Bool: the result is always 255 (-1), whereas it should be 1 (True). The cast works correctly for Tensor data passed to the graph from outside, see https://gist.github.com/lhoenig/5759db313c1ba919f9baad52221c7917. Any other value works fine. That's why I currently use the hacky workaround of substituting 1.1 as the scalar value for "True", as 1.0 gives "False".
  2. Memory from an old, freed Tensor doesn't get overwritten when using torch.full, which internally also uses constantWithScalar. Again, this issue only happens for Bool tensors. See the comment above for more details.

The two problems may or may not be linked.

@kulinseth
Collaborator


Hi @lhoenig, sorry for the delay in responding. Do you still see this issue?
We now use type promotion, c10::promoteTypes(self.scalar_type(), other.scalar_type()), to handle the type mismatches in the graph. Does the issue still exist for Boolean tensors?

@lhoenig
Contributor Author
lhoenig commented Jul 3, 2022

Hi @kulinseth, no problem, thanks for taking a look! I think there are still multiple issues related to MPS type precedence in the current master. The test case I wrote doesn't get very far with the current master code, whereas most things were working with my PR here. There seem to be problems at least with int32, bool, and division.

The test I wrote is quite useful for getting an overview of what works, I think. It runs each binary operation over all data types and representative values of them, and asserts that the result is equal to the CPU result.
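
Schematically, the test does something like the following (a simplified sketch, not the actual test code; it assumes an MPS-capable build):

import itertools
import torch

dtypes = [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32,
          torch.int64, torch.float16, torch.float32]
for op, d1, d2 in itertools.product(['add', 'sub', 'mul', 'div'], dtypes, dtypes):
    a = torch.tensor([2], dtype=d1)
    b = torch.tensor([3], dtype=d2)
    try:
        ref = getattr(a, op)(b)           # CPU reference result
    except RuntimeError:
        continue                          # e.g. sub is not defined for bool Tensors
    out = getattr(a.to('mps'), op)(b.to('mps')).cpu()
    assert out.dtype == ref.dtype
    assert torch.allclose(out.float(), ref.float())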

@lhoenig
Contributor Author
lhoenig commented Jul 3, 2022

@kulinseth I made a script for you that contains my test function in isolated form and that continues if something fails, i.e. it lists all failing cases: https://gist.github.com/lhoenig/b528b1dc6fc65bdfe020f1269c7dc121

The test runs through without issue on my PR branch. Many of the commented-out cases also work, and my comments above about where I got stuck were focused on getting those to work fully as well.

On master I actually get immediate crashes in MPS (broadcast incompatibility; it could be that the type promotion happens too late). When I was running my test function in the test harness before, I somehow didn't get the crashes, but I still saw many different errors depending on how much I commented out.

Oh my, I was checking out the wrong master (the classic). So I was still looking at an old version. I am compiling the real master from scratch again and will report results shortly.

@lhoenig
Contributor Author
lhoenig commented Jul 3, 2022

@kulinseth Now it looks much, much better - almost everything works, and you even got float16 mostly working, nice!

One thing that is fixed in my PR here but not on master is the special handling of division: I figured out that for division, both tensors always need to be cast to float32 before performing the division; that is how the CPU backend does it. This special handling cannot be captured by a type promotion function that does not depend on the operation.
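
To illustrate the point on the CPU side (standard PyTorch behavior with the default dtype settings, not code from this PR):

import torch

a = torch.tensor([3], dtype=torch.int64)
b = torch.tensor([2], dtype=torch.int64)
print(torch.promote_types(torch.int64, torch.int64))  # torch.int64 -- operation-independent promotion keeps int
print((a / b).dtype)                                   # torch.float32 -- true division still produces float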

May I suggest the following fix, which makes all division-related tests pass:

diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 9ea4a16aac..c4a53d31bf 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -146,8 +146,15 @@ void div_mode_template(const Tensor& self, const Tensor& other,
 {
   BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* divTensor =  [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
-                                                     secondaryTensor:secondaryCastTensor
+    // Division needs float inputs, result should also always be float32
+    MPSGraphTensor* primaryFloat32 = [mpsGraph castTensor:primaryCastTensor
+                                                   toType:MPSDataTypeFloat32
+                                                     name:@"castPrimary"];
+    MPSGraphTensor* secondaryFloat32 = [mpsGraph castTensor:secondaryCastTensor
+                                                     toType:MPSDataTypeFloat32
+                                                       name:@"castSecondary"];
+    MPSGraphTensor* divTensor =  [mpsGraph divisionWithPrimaryTensor:primaryFloat32
+                                                     secondaryTensor:secondaryFloat32
                                                                 name:nil];
     if (!rounding_mode.has_value()) {
       return divTensor;

With this fix applied, everything works as well as my PR branch here.

Regarding float16, there are still some failures with multiplication and division (that are not fixed by the fix above). The division failures are too many to list in full here, but the multiplication failures are these:

torch.float16,torch.float32: (-234.5).mul(0.1)
FAIL for dim1 <op> dim0, results:
mps: tensor([-23.4531], device='mps:0', dtype=torch.float16)
cpu: tensor([-23.4375], dtype=torch.float16)

torch.float32,torch.float16: (0.1).mul(-234.5)
FAIL for dim0 <op> dim1, results:
mps: tensor([-23.4531], device='mps:0', dtype=torch.float16)
cpu: tensor([-23.4375], dtype=torch.float16)

torch.float16,torch.float32: (-234.5).mul(111.99)
FAIL for dim1 <op> dim0, results:
mps: tensor([-26256.], device='mps:0', dtype=torch.float16)
cpu: tensor([-26272.], dtype=torch.float16)

torch.float32,torch.float16: (111.99).mul(-234.5)
FAIL for dim0 <op> dim1, results:
mps: tensor([-26256.], device='mps:0', dtype=torch.float16)
cpu: tensor([-26272.], dtype=torch.float16)

This could just be an expected error due to a different order of operations in the two backends, but it is interesting that it apparently only happens for certain combinations of dimensionalities: it fails only in the cases where exactly one of the Tensors has dimension 1 and the other dimension 0. If both have dimension 1 or both have dimension 0, there is no discrepancy.
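
The first failing case above should be reproducible with something like the following (a guess at the exact call pattern; the CPU value is taken from the output above):

import torch

a16 = torch.tensor([-234.5], dtype=torch.float16)  # dim-1 float16
b32 = torch.tensor(0.1, dtype=torch.float32)       # dim-0 float32
print(a16 * b32)  # CPU reference: tensor([-23.4375], dtype=torch.float16)
# On MPS, a16.to('mps') * b32.to('mps') reportedly yields -23.4531 instead.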

@lhoenig
Contributor Author
lhoenig commented Jul 3, 2022

@kulinseth I updated my test script, removing completed TODOs, and enabled the last part: the one that tests torch.full combined with a scalar (and the other way around) and checks that both results are equal: https://gist.github.com/lhoenig/b528b1dc6fc65bdfe020f1269c7dc121

This last part also works great now, except for Bools. If you uncomment line 28, enabling Bool dtype testing, you immediately get the following crash:

torch.bool,torch.bool: (False).add(False)
/AppleInternal/Library/BuildRoots/b6051351-c030-11ec-96e9-3e7866fcf3a1/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:343: failed assertion `unsupported datatype for constant'
[1]    95044 abort      python3 ~/test_mps_dtypes_precedence.py
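
A guess at a minimal repro for this crash (assuming the Bool scalar operand is what ends up going through constantWithScalar):

import torch

x = torch.full((20,), False, device='mps')
print(x.add(False))  # reportedly hits "unsupported datatype for constant" in MPSGraph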

@github-actions
Contributor
github-actions bot commented Sep 1, 2022

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Sep 1, 2022
@kulinseth
Collaborator

@lhoenig , thanks.

One thing that is fixed in my PR here but not on master is the special handling of division: I figured out that for division, both tensors always need to be cast to float32 before performing the division; that is how the CPU backend does it. This special handling cannot be captured by a type promotion function that does not depend on the operation.

This is indeed true. We also encountered this in our local testing and made a fix here:
#84742

Please provide feedback if this looks okay.

@kulinseth
Collaborator

Also, @lhoenig, can you please go ahead and create a PR with your test case, and note whether there is anything still missing on master?

@pytorch-bot
pytorch-bot bot commented Sep 27, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/78319

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@linux-foundation-easycla
linux-foundation-easycla bot commented Oct 3, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

@lhoenig
Contributor Author
lhoenig commented Oct 22, 2022

@kulinseth Apologies for taking so long!
Division works perfectly now, and everything else that my testcase covers as well!
I made a PR with the testcase here: #87545

@lhoenig lhoenig closed this Oct 22, 2022
pytorchmergebot pushed a commit that referenced this pull request Nov 17, 2022
See #84742 and #78319.

The test case tests that
- for the binary operations (add, sub, mul, div),
- for all data types (dtypes),
- for a range of representative values and their combinations,
- for various shapes and ways of creating the test tensors,

the contents and dtype of the result tensor are identical for the MPS and CPU backends.

It adds about 15-18s runtime to `test_mps.py`.
Pull Request resolved: #87545
Approved by: https://github.com/kit1980