MPS: Datatype precedence in binary ops [WIP] #78319
Conversation
I'm kind of stuck at two points now, although there are more than two problems remaining that need to be fixed.
About the second problem: I was able to isolate it outside of the tests. A minimal example is:

```python
import torch

torch.tensor(True, device='mps')
print(torch.full((5,), False, device='mps'))
```

The output with some memory debugging enabled all looks correct to me.
The issue happens when an MPS tensor is released immediately before another MPS tensor is allocated using `torch.full`. Here are more examples and experiments:

```python
import torch as pt

print(hex(pt.tensor(1, device='mps').data_ptr()))
print(hex(pt.tensor(True, device='mps').data_ptr()))
# x will be tensor([ True, False, False, False, False, False, False, False, False, False], device='mps:0')
# ==> First element corrupted in both cases!

print(hex(pt.tensor(0.1, device='mps').data_ptr()))
# x will be tensor([ True, False, False, True, False, False, False, False, False, False], device='mps:0')
# ==> Corruption happens for any datatype, memory seems "mixed"

s = pt.tensor(1, device='mps')
print(hex(s.data_ptr()), s)
# x will not be corrupted
# ==> Memory is retained by binding to `s`, and no corruption happens!
#     This could mean that Tensors are not freed correctly.

s = pt.tensor(1, device='mps')
print(hex(s.data_ptr()), s)
del s
# x will be tensor([ True, False, False, False, False, False, False, False, False, False], device='mps:0')
# ==> Corruption again! Freeing scalar tensors causes corruption.

print(hex(pt.tensor(True).data_ptr()))
# x will not be corrupted
# ==> Corruption only happens via MPS scalar tensors

x = pt.full((10,), False, device="mps")  # this one gets corrupted
# x = pt.full((10,), 0, device="mps")    # no corruption (dtype != bool)
# x = pt.full((10,), 0.1, device="mps")  # no corruption (dtype != bool)
print(hex(x.data_ptr()), x)  # identical memory location as the scalar tensor!
if any(x):
    print('Tensor corrupted')
```

Before this PR, the …
@kulinseth Sorry to bother you, but maybe you can take a look when you find the time? I may have found one or two bugs in MPSGraph related to Bool tensors, or I'm just using it wrong, no idea :) The two problems may or may not be linked.
Hi @lhoenig, sorry for the delay in responding. Do you still see this issue?
Hi @kulinseth, no problem, thanks for taking a look! I think there are still multiple issues at present. The test I wrote is quite useful for getting an overview of what works, I think. It tests the combination of different binary operations with all data types and representative values of them, and asserts that the result is equal to the CPU result.
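To illustrate the shape of such a test (a minimal sketch under assumed value sets, dtypes, and tolerance handling; not the actual test from the PR):

```python
import itertools
import torch

ops = [torch.add, torch.sub, torch.mul, torch.div]
dtypes = [torch.bool, torch.uint8, torch.int8, torch.int16,
          torch.int32, torch.int64, torch.float16, torch.float32]

def check(op, dt1, dt2):
    a = torch.tensor([0, 1, 2], dtype=dt1)
    b = torch.tensor([1, 1, 2], dtype=dt2)
    try:
        ref = op(a, b)  # CPU reference; some combos (e.g. bool subtraction) raise
    except RuntimeError:
        return
    out = op(a.to('mps'), b.to('mps'))
    # Result dtype and contents must match the CPU backend
    assert out.dtype == ref.dtype, (op.__name__, dt1, dt2, out.dtype, ref.dtype)
    assert torch.allclose(out.cpu().float(), ref.float()), (op.__name__, dt1, dt2)

for op, dt1, dt2 in itertools.product(ops, dtypes, dtypes):
    check(op, dt1, dt2)
```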
@kulinseth I made a script for you that contains my test function in isolated form and continues if something fails, i.e. it lists all failing cases: https://gist.github.com/lhoenig/b528b1dc6fc65bdfe020f1269c7dc121 The test runs through without issue on my PR branch. Many of the commented-out things also work, and my comments above where I got stuck were focused on getting these to work fully as well.
Oh my, I was checking out the wrong branch.
@kulinseth Now it looks much, much better: almost everything works. One thing that is fixed in my PR here but not on master is division. May I suggest the following fix, with which all tests related to division pass:

```diff
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 9ea4a16aac..c4a53d31bf 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -146,8 +146,15 @@ void div_mode_template(const Tensor& self, const Tensor& other,
 {
   BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
-                                                    secondaryTensor:secondaryCastTensor
+    // Division needs float inputs, result should also always be float32
+    MPSGraphTensor* primaryFloat32 = [mpsGraph castTensor:primaryCastTensor
+                                                   toType:MPSDataTypeFloat32
+                                                     name:@"castPrimary"];
+    MPSGraphTensor* secondaryFloat32 = [mpsGraph castTensor:secondaryCastTensor
+                                                     toType:MPSDataTypeFloat32
+                                                       name:@"castSecondary"];
+    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryFloat32
+                                                    secondaryTensor:secondaryFloat32
                                                                name:nil];
     if (!rounding_mode.has_value()) {
       return divTensor;
```

With this fix applied, everything works as well as on my PR branch here.
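For context, a small snippet showing the user-visible behavior this cast is meant to restore (illustrative only, not code from the PR; the printed values assume the fix is in place):

```python
import torch

a = torch.tensor([3, 7], dtype=torch.int64, device='mps')
b = torch.tensor([2, 2], dtype=torch.int64, device='mps')

# With the cast in place, true division of integer tensors should
# produce float32 and match the CPU backend:
print((a / b).dtype)      # expected: torch.float32
print((a / b).cpu())      # expected: tensor([1.5000, 3.5000])
print(a.cpu() / b.cpu())  # CPU reference: tensor([1.5000, 3.5000])
```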
Regarding the remaining discrepancy: it could be just an expected error due to a different order of operations in the backends, but it is interesting that it apparently only happens for certain combinations of dimensionalities: it only fails in the cases where exactly one of the Tensors has dimension 1 and the other dimension 0. If both have dimension 1 or 0, there is no discrepancy.
@kulinseth I updated my test script, removing done TODOs, and enabled the last part. This last part also works great now, except for Bools, actually. If you uncomment line 28, enabling Bool dtype testing, you immediately get a crash.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as `Stale`.
@lhoenig, thanks. This is indeed true. We also encountered this in our local testing and made a fix. Please provide feedback if it looks okay.
Also @lhoenig, can you please go ahead and create a PR with your test case, plus anything else that is missing in master?
@kulinseth Please excuse me taking so long!
See #84742 and #78319. The test case tests that

- for the binary operations (add, sub, mul, div),
- for all data types (dtypes),
- for a range of representative values and their combinations,
- for various shapes and ways of creating the test tensors,

the contents and dtype of the result tensor are identical for the MPS and CPU backends. It adds about 15-18s of runtime to `test_mps.py`.

Pull Request resolved: #87545
Approved by: https://github.com/kit1980
Issue #78020 revealed that boolean Tensors on MPS don't behave as expected when subjected to arithmetic operations with other Tensors. This problem in fact extends to all other dtypes to some degree. The underlying reason for the unexpected behavior is that there are currently no rules for what the dominating dtype in a binary operation should be. For example, when multiplying two Tensors with types `int` and `float` together, the output Tensor should be a `float` one, and the `int` Tensor needs to be cast before the binary operation takes place. Otherwise, wrong results or crashes can happen.

This PR adds these rules and behavior, and generally aims to correct any discrepancies in the outputs of binary operations between the CPU and MPS backends (currently working only on `add`, `sub`, `mul`, `div`). The precedence list is modeled after the CPU backend behavior. An extensive test is provided that revealed additional issues. These are still in the process of being fixed.

The assumption is made here that the dimensionality or order of both arguments does not matter for casting. Basically, the output dtype is determined only by the set of the two input dtypes, except for division, where the output is `float32` regardless of the input dtypes (this is already taken care of before the Tensors enter `binaryOpTensor`, so no special case is needed there). All these behaviors seem to be consistent with the CPU backend, as this is what the test function tests against.

The PR code will evolve a bit more, but the approach can already be reviewed. The long if/else chain to determine the output dtype is not pretty at first sight, but it is a solution with minimal overhead.
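For reference, a small snippet (illustrative only; these are standard PyTorch APIs, not code from this PR) showing the CPU promotion behavior the precedence list is modeled after:

```python
import torch

# Promotion rules the MPS backend should match:
print(torch.promote_types(torch.bool, torch.int64))     # torch.int64
print(torch.promote_types(torch.int32, torch.float32))  # torch.float32

a = torch.tensor([1, 2], dtype=torch.int32)
b = torch.tensor([0.5, 1.5], dtype=torch.float32)
print(torch.result_type(a, b))  # torch.float32
print((a * b).dtype)            # torch.float32: the int operand is cast first

# True division is the exception: float output regardless of input dtypes
print((torch.tensor(3) / torch.tensor(2)).dtype)  # torch.float32
```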
Lastly, similar or identical logic might be needed in other places; this should be investigated. If that turns out to be the case, it might be good to encapsulate the core logic introduced here, as it should be universal.
Fixes #78020