[dynamic shapes] guard_or_false for computeStorageNbytes by pianpwk · Pull Request #150483 · pytorch/pytorch · GitHub

[dynamic shapes] guard_or_false for computeStorageNbytes #150483


Closed
wants to merge 22 commits into from

Conversation

Contributor
@pianpwk pianpwk commented Apr 1, 2025

pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150483

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5a9460e with merge base cbcb57d:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -1701,6 +1702,9 @@ def _process_sym_expr(sym: sympy.Expr, hint: Optional[Union[int, bool, float]] =
compiler_min=vr.lower, # type: ignore[arg-type]
compiler_max=vr.upper, # type: ignore[arg-type]
)
# ShapeEnv meta
if isinstance(sym, sympy.Symbol):
self.shape_env.var_to_stack[sym] = CapturedTraceback.extract(skip=1)
Contributor Author

fixes one-off deserialization issue

@@ -160,7 +160,7 @@ SymInt computeStorageNbytes(
// of the last element according to stride
SymInt size = 1;
for (const auto i : c10::irange(sizes.size())) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_eq(0))) {
if (TORCH_GUARD_OR_FALSE(sizes[i].sym_eq(0))) {
Contributor

Would it make sense to add a runtime assert in this case? Something like:

torch._check(sizes[i].sym_neq(0), f"We assume that unbacked size {sizes[i]} is not 0, but it turned out to be zero at runtime")

Alternatively, if it's OK for this function to return something greater than the actual storage number of bytes, at line 169 we can make sure we do not return a negative:
return Max(itemsize_bytes * (storage_offset + size), 0);
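The overflow concern above can be sketched with a small Python mirror of the C++ loop (an illustrative stand-in, not the real implementation; `compute_storage_nbytes` and `assume_nonzero` are hypothetical names for this sketch):

```python
def compute_storage_nbytes(sizes, strides, itemsize_bytes, storage_offset=0,
                           assume_nonzero=False):
    # Mirrors the loop under discussion: compute the extent (in elements)
    # of the last reachable element according to stride.
    size = 1
    for i in range(len(sizes)):
        if sizes[i] == 0 and not assume_nonzero:
            return 0  # an empty tensor stores no bytes
        size += strides[i] * (sizes[i] - 1)
    return itemsize_bytes * (storage_offset + size)

# With the zero check skipped (as guard_or_false would do for an unbacked
# size that turns out to be 0 at runtime), the result can go negative:
print(compute_storage_nbytes([0], [4], 8, assume_nonzero=True))  # 8 * (1 - 4) = -24
```

This is exactly the `1 + strides[i] * (-1)` scenario raised later in the thread.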

Contributor Author
@pianpwk pianpwk Apr 4, 2025

Isn't it ok in this case? Like if the tensor numel is 0 at runtime, aren't we literally storing 0 bytes for the tensor?

Contributor

The output won't be zero if we skip this, though? If we return false but the numel is actually 0, then at line 167
size += strides[i] * (sizes[i] - 1);
size would be 1 + strides[i]*(-1). Is the stride guaranteed to be 1 in that case? If not, we get 1 - strides[i], which is not always 0, no?

Contributor

I guess my question is: did you check what the actual output would be at runtime if we skip this when the size is 0, and whether that breaks things?
One way to see what could fail is this:
#151172

Contributor

Now I understand that we are already doing this, and we are just extending it slightly in this case by making it work for u0-u2, etc., instead of just u0, u0+u2, ...
But let's wait for the test above. I think if things fail, it's probably better to fix the soundness and insert a runtime assert on DDE that it's not actually zero.
If nothing fails, we can debug it to understand what happens and decide.

Contributor Author

it caused a lot of issues, see CI before I removed the check: https://hud.pytorch.org/pytorch/pytorch/pull/150483?sha=ccc45c38c60db864098da4ce31647bad4f7eee45

Contributor

Can you add a comment like the following:

# This used to be TORCH_GUARD_SIZE_OBLIVIOUS, but since any size is always >= 0, and assuming that TORCH_GUARD_SIZE_OBLIVIOUS was safe, we extended the assumption to all other unbacked expressions.

Contributor
@laithsakka laithsakka Apr 16, 2025

Also, can we make
size += strides[i] * (sizes[i] - 1);
--->
size += strides[i] * max(0, (sizes[i] - 1));
I just do not want this to ever possibly return a negative.
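The suggested clamp can be sketched as follows (an illustrative Python mirror of the loop; `max_reachable_extent` is a hypothetical name, not the real C++ function):

```python
def max_reachable_extent(sizes, strides):
    # Extent (in elements) of the last reachable element, with the
    # suggested clamp: a dim whose size is 0 at runtime contributes
    # nothing instead of driving the result negative.
    size = 1
    for s, st in zip(sizes, strides):
        size += st * max(0, s - 1)
    return size

print(max_reachable_extent([0], [4]))     # 1, instead of -3 without the clamp
print(max_reachable_extent([2, 3], [3, 1]))  # 6
```

Note the clamped result for a 0-size dim is 1 element rather than 0, i.e. an over-approximation, consistent with the "greater than the actual storage number of bytes" fallback proposed earlier in the thread.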

Contributor
@laithsakka laithsakka Apr 16, 2025

Maybe this is why this is safe:
https://www.internalfb.com/diff/D56851139?whitespace=SHOW_ALL

See the comment below.

```
c10::SymInt new_size_bytes = result.is_contiguous()
    ? at::detail::computeStorageNbytesContiguous(
          size, itemsize, std::move(storage_offset))
    : at::detail::computeStorageNbytes(
          size, stride, itemsize, std::move(storage_offset));
// TODO: When there are unbacked SymInts, we unconditionally skip the
// setter.  This is technically wrong, but we cannot conveniently test
// the real condition in many cases, because a lot of people are using
// set_ just to swizzle metadata on a tensor, they didn't actually want
// to see if they need to resize the storage.
//
// The old behavior was to unconditionally set_nbytes, but I think not
// setting it is more safe.
if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
    TORCH_GUARD_SIZE_OBLIVIOUS(
        new_size_bytes.sym_gt(storage.sym_nbytes()))) {
  storage.set_nbytes(std::move(new_size_bytes));
}
```

Or at least it looks like this was not safe, and the diff above tried to make it safer but did not want to remove the guard size oblivious?

Contributor Author
@pianpwk pianpwk Apr 16, 2025

Nice find, sounds like this was never really safe, we just avoided it here. I'm worried the Max(0, *) will cause problems for our weak min/max reasoning, but let's see what CI says

@pianpwk
Contributor Author
pianpwk commented Apr 11, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased pianpwk/oblivious_storagenbytes onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout pianpwk/oblivious_storagenbytes && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the pianpwk/oblivious_storagenbytes branch from 12a6643 to 173d186 on April 11, 2025 21:53
@pianpwk pianpwk marked this pull request as ready for review April 11, 2025 22:27
@pianpwk pianpwk requested review from bdhirsh and laithsakka April 11, 2025 22:27
@@ -393,6 +394,9 @@ def f(x):

self.assertEqual(counter.frame_count, 2) # not three or four!

# TODO(laithsakka): guard_or_false fallback should occur before oblivious/unbacked hints
# maybe we can deprecate this option with backed_size_oblivious?
@unittest.expectedFailure
Contributor Author

Happens because falling back to the oblivious hint can add an additional guard (e.g. u0 != 0), which is avoidable if we had used the False fallback earlier.

Contributor

I have a fix for this

@@ -160,7 +160,7 @@ SymInt computeStorageNbytes(
// of the last element according to stride
SymInt size = 1;
for (const auto i : c10::irange(sizes.size())) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_eq(0))) {
if (TORCH_GUARD_OR_FALSE(sizes[i].sym_eq(0))) {
Contributor

# this is kind of closer to the current guard_size_obl(sz.sym_eq(0))
def func(sz):
    if guard_or_false(sz >= 0):
        return guard_or_false(sz.sym_eq(0))
    else:
        return sz == 0
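A runnable sketch of the guard_or_false semantics that pseudocode relies on (plain-Python stand-ins only; the real implementation lives in torch.fx.experimental.symbolic_shapes and operates on SymBool/ShapeEnv, and `Unbacked` here is a hypothetical class for illustration):

```python
class Unbacked:
    """Stand-in for an unbacked symbolic bool whose truth value is unknown."""
    def __bool__(self):
        raise TypeError("data-dependent expression; cannot be guarded on")

def guard_or_false(cond):
    # Evaluate the condition if it is statically known; otherwise fall
    # back to False instead of raising a data-dependent error.
    try:
        return bool(cond)
    except TypeError:
        return False

print(guard_or_false(3 == 3))      # True: statically decidable
print(guard_or_false(Unbacked()))  # False: unknown, falls back
```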

pytorchmergebot pushed a commit that referenced this pull request May 1, 2025
We know the output shape, and we know this always produces a clone. Avoids data-dependent errors from the decomposition.

along with #150483, should fix #123855
Pull Request resolved: #152129
Approved by: https://github.com/laithsakka
@aorenste
Contributor
aorenste commented May 5, 2025

FYI: This change is blocking #152662

@aorenste
Contributor
aorenste commented May 5, 2025

After rebasing I needed this patch to make this work: P1803805396

@pianpwk
Contributor Author
pianpwk commented May 9, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
