-
Notifications
You must be signed in to change notification settings - Fork 24.3k
[dynamic shapes] guard_or_false for computeStorageNbytes #150483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
6b76bf6
init
pianpwk b6f0964
lint
pianpwk 88da25c
fix test
pianpwk 15dae80
test
pianpwk 6201d42
mark_as_oblivious
pianpwk c96f389
mark expected fail
pianpwk 173d186
Update test_recompiles.py
pianpwk 56d047b
Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw…
pianpwk 90aea1b
try with >=1 check
pianpwk ccc45c3
Update EmptyTensor.cpp
pianpwk a06fa83
remove >= 1 check
pianpwk 8cae0b8
Update EmptyTensor.cpp
pianpwk 3d9fff4
size-oblivious test
pianpwk a389dca
lint
pianpwk 7683a45
Update EmptyTensor.cpp
pianpwk 5a9c277
Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw…
pianpwk 06f1f71
Update test_export.py
pianpwk 6c91597
Update expected_results.csv
pianpwk d6ad2b8
Update EmptyTensor.cpp
pianpwk ec21175
Merge branch 'main' into pianpwk/oblivious_storagenbytes
pianpwk abc6beb
Update expected_results.csv
pianpwk 5a9460e
merge
pianpwk File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,10 +160,15 @@ SymInt computeStorageNbytes( | |
// of the last element according to stride | ||
SymInt size = 1; | ||
for (const auto i : c10::irange(sizes.size())) { | ||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_eq(0))) { | ||
if (TORCH_GUARD_OR_FALSE(sizes[i].sym_eq(0))) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
return 0; | ||
} | ||
|
||
// NOTE: while this can technically return negative sizes for | ||
// 0-element tensors, there's a check in TensorShape:set_storage_meta__symint | ||
// that skips setting nbytes with unbacked expressions. | ||
// Would probably be safer to wrap this with a max(*, 0), | ||
// once our min/max symbolic reasoning improves. | ||
size += strides[i] * (sizes[i] - 1); | ||
} | ||
return itemsize_bytes * (storage_offset + size); | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it make sense to add a runtime assert in this case
something like
torch._check(sizes[i].sym_neq(0), f"We assumes that unbacked size {sizes[i]} is not 0 but it turn out to be zero at rumtime,}",)
alternatively if its ok for this function to return something that is greater than the actual storage number of bytes.
at line 169 we can do to make sure we do not return negative.
return Max(itemsize_bytes * (storage_offset + size),0) ;
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't it ok in this case? Like if the tensor numel is 0 at runtime, aren't we literally storing 0 bytes for the tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the output if we skip this wont be zero though?
like if we return false, but the numel is actually 0 then at line 167
size += strides[i] * (sizes[i] - 1);
size would be
1 + strides[i]*(-1) is the strides in that case guaranteed to be 1? if not we will get.
1-strides[0] which is not always 0 no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess my questions did you check what would be the actual output at runtime if we skip this one when the size is 0. and if the break things?
one way to see what could fail is this.
#151172
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now I understand that we are already doing this right now, and that well, we are just extending it briefly in this case by making it work for u0-u2 .. etc instead of just u0, u0+u2...
but lets wait for the test above, i think if things fail probably better to fix the soundness and insert a runtime assert on dde that its not actually zero.
if nothing fail we can debug it to understand what happen and decide.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it caused a lot of issues, see CI before I removed the check: https://hud.pytorch.org/pytorch/pytorch/pull/150483?sha=ccc45c38c60db864098da4ce31647bad4f7eee45
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a comment as the following:
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also can we make
size += strides[i] * (sizes[i] - 1);
--->
size += strides[i] * max(0, (sizes[i] - 1));
I just do not want to this to every possibly return a negative
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe this is why this is safe
https://www.internalfb.com/diff/D56851139?whitespace=SHOW_ALL
See the comment bellow.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice find, sounds like this was never really safe, we just avoided it here. I'm worried the
Max(0, *)
will cause problems for our weak min/max reasoning, but let's see what CI says