[dynamic shapes] guard_or_false for computeStorageNbytes by pianpwk · Pull Request #150483 · pytorch/pytorch

Closed · pianpwk wants to merge 22 commits

7 changes: 6 additions & 1 deletion aten/src/ATen/EmptyTensor.cpp
@@ -160,10 +160,15 @@ SymInt computeStorageNbytes(
   // of the last element according to stride
   SymInt size = 1;
   for (const auto i : c10::irange(sizes.size())) {
-    if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_eq(0))) {
+    if (TORCH_GUARD_OR_FALSE(sizes[i].sym_eq(0))) {
Contributor

Would it make sense to add a runtime assert in this case? Something like:

torch._check(sizes[i].sym_neq(0), f"We assume that unbacked size {sizes[i]} is not 0, but it turned out to be zero at runtime")

Alternatively, if it's OK for this function to return something greater than the actual storage number of bytes, at line 169 we could do the following to make sure we do not return a negative value:

return Max(itemsize_bytes * (storage_offset + size), 0);
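
For context, a minimal Python-level sketch of the kind of runtime assert being discussed, using torch._check on an unbacked value produced by .item(); this is only an illustration, not the C++ change itself:

```python
import torch

# capture_scalar_outputs lets .item() produce an unbacked SymInt under compile
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    n = x.item()  # unbacked SymInt at trace time
    # Instead of assuming n != 0 while tracing, record a deferred runtime
    # assert that fires if the assumption turns out to be wrong.
    torch._check(n != 0)
    return torch.empty(n)

fn(torch.tensor([3]))    # fine
# fn(torch.tensor([0]))  # would trip the recorded runtime assert
```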

@pianpwk (Contributor, Author) · Apr 4, 2025

Isn't it OK in this case? If the tensor numel is 0 at runtime, aren't we literally storing 0 bytes for the tensor?

Contributor

The output won't be zero if we skip this, though? If we return false but the numel is actually 0, then at line 167

size += strides[i] * (sizes[i] - 1);

size would be 1 + strides[i] * (-1). Is the stride in that case guaranteed to be 1? If not, we get 1 - strides[i], which is not always 0, no?
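
Plugging in concrete numbers (a tiny illustration of the arithmetic in question, not PyTorch code):

```python
# If the size-0 short-circuit is skipped, a zero-sized dimension contributes
#   size = 1 + stride * (0 - 1) = 1 - stride,
# which is only 0 when stride == 1.
for stride in (1, 4, 16):
    print(stride, 1 + stride * (0 - 1))  # 1 -> 0, 4 -> -3, 16 -> -15
```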

Contributor

I guess my question is: did you check what the actual output would be at runtime if we skip this branch when the size is 0, and whether that breaks things? One way to see what could fail is #151172.

Contributor

Now I understand that we are already doing this, and that we are just extending it slightly here by making it work for expressions like u0 - u2, etc., instead of just u0, u0 + u2, ...
But let's wait for the test above; if things fail, it's probably better to fix the soundness issue and insert a runtime assert on the data-dependent expression that it is not actually zero.
If nothing fails, we can debug it to understand what happens and decide.

@pianpwk (Contributor, Author)

It caused a lot of issues; see CI before I removed the check: https://hud.pytorch.org/pytorch/pytorch/pull/150483?sha=ccc45c38c60db864098da4ce31647bad4f7eee45

Contributor

Can you add a comment like the following:

// This used to be TORCH_GUARD_SIZE_OBLIVIOUS, but since any size is always >= 0,
// and assuming that TORCH_GUARD_SIZE_OBLIVIOUS was safe, we extended the
// assumption to all other unbacked expressions.

@laithsakka (Contributor) · Apr 16, 2025

Also, can we change

size += strides[i] * (sizes[i] - 1);

to

size += strides[i] * max(0, (sizes[i] - 1));

I just don't want this to ever possibly return a negative value.

@laithsakka (Contributor) · Apr 16, 2025

Maybe this is why this is safe: https://www.internalfb.com/diff/D56851139?whitespace=SHOW_ALL

See the comment below.


```cpp
c10::SymInt new_size_bytes = result.is_contiguous()
    ? at::detail::computeStorageNbytesContiguous(
          size, itemsize, std::move(storage_offset))
    : at::detail::computeStorageNbytes(
          size, stride, itemsize, std::move(storage_offset));
// TODO: When there are unbacked SymInts, we unconditionally skip the
// setter.  This is technically wrong, but we cannot conveniently test
// the real condition in many cases, because a lot of people are using
// set_ just to swizzle metadata on a tensor, they didn't actually want
// to see if they need to resize the storage.
//
// The old behavior was to unconditionally set_nbytes, but I think not
// setting it is more safe.
if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
    TORCH_GUARD_SIZE_OBLIVIOUS(
        new_size_bytes.sym_gt(storage.sym_nbytes()))) {
  storage.set_nbytes(std::move(new_size_bytes));
}
```

Or at least it looks like this was not safe, and the diff above tried to make it safer but did not want to remove the guard_size_oblivious?

@pianpwk (Contributor, Author) · Apr 16, 2025

Nice find. Sounds like this was never really safe; we just avoided it here. I'm worried the Max(0, *) will cause problems for our weak min/max reasoning, but let's see what CI says.

Contributor

```python
# this is kind of closer to the current guard_size_oblivious(sz.sym_eq(0))
def func(sz):
    if guard_or_false(sz >= 0):
        return guard_or_false(sz.sym_eq(0))
    else:
        return sz == 0
```
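
For reference, a small sketch of how the Python-side counterpart behaves; this assumes torch.fx.experimental.symbolic_shapes.guard_or_false, which returns the condition's value when it is statically known and False when it is data-dependent, and that it passes plain Python bools through unchanged:

```python
from torch.fx.experimental.symbolic_shapes import guard_or_false

def storage_is_zero_sized(sizes):
    # Mirrors the C++ check in the diff above: if we cannot tell whether a
    # size is zero (unbacked), treat it as "not zero" rather than guarding.
    return any(guard_or_false(s == 0) for s in sizes)

print(storage_is_zero_sized([3, 0, 5]))  # True: zero is statically known here
print(storage_is_zero_sized([3, 4, 5]))  # False
```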

       return 0;
     }
 
+    // NOTE: while this can technically return negative sizes for
+    // 0-element tensors, there's a check in TensorShape:set_storage_meta__symint
+    // that skips setting nbytes with unbacked expressions.
+    // Would probably be safer to wrap this with a max(*, 0),
+    // once our min/max symbolic reasoning improves.
     size += strides[i] * (sizes[i] - 1);
   }
   return itemsize_bytes * (storage_offset + size);
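
For orientation, a minimal Python mirror of the loop above (illustrative only, not the ATen implementation), marking where the zero-size short-circuit and the suggested max(..., 0) clamp would sit:

```python
def compute_storage_nbytes(sizes, strides, itemsize_bytes, storage_offset=0):
    # Branch guarded by TORCH_GUARD_OR_FALSE in the C++ version: a tensor
    # with any zero-sized dimension stores zero bytes.
    if any(s == 0 for s in sizes):
        return 0
    # Storage must reach one element past the offset of the last element.
    size = 1
    for sz, st in zip(sizes, strides):
        # This term can go negative if a zero size slips through, hence the
        # suggestion to clamp it as max(0, sz - 1).
        size += st * (sz - 1)
    return itemsize_bytes * (storage_offset + size)

print(compute_storage_nbytes([2, 3], [3, 1], 4))  # 4 * (1 + 3*1 + 1*2) = 24
print(compute_storage_nbytes([0, 3], [3, 1], 4))  # 0
```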
2 changes: 1 addition & 1 deletion benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
@@ -34,7 +34,7 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,97140000



-update_hint_regression,compile_time_instruction_count,1622000000,0.02
+update_hint_regression,compile_time_instruction_count,1677500000,0.02



9 changes: 9 additions & 0 deletions test/dynamo/test_misc.py
@@ -7633,6 +7633,15 @@ def dyn_fn(x):
         torch.compile(dyn_fn, backend="eager")(y)
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_unbacked_empty_tensor(self):
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(x):
+            n = x.item()
+            return torch.empty((n - 1) // 2)
+
+        self.assertEqual(fn(torch.tensor([4])).size(0), 1)
+        self.assertEqual(fn(torch.tensor([1])).size(0), 0)
 
     def test_unbacked_2d_expand(self):
         @torch.compile(fullgraph=True, dynamic=True, backend="inductor")
         def func(a, b):
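
A quick eager-mode sanity check of what the new test_unbacked_empty_tensor above expects (plain PyTorch, no compilation):

```python
import torch

# (4 - 1) // 2 == 1 and (1 - 1) // 2 == 0, so the compiled function should
# return a 1-element tensor and a 0-element tensor respectively.
print(torch.empty((4 - 1) // 2).size(0))  # 1
print(torch.empty((1 - 1) // 2).size(0))  # 0
```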
3 changes: 0 additions & 3 deletions test/export/test_export.py
@@ -4389,8 +4389,6 @@ def forward(self, x):
         ep.module()(torch.tensor([5]))
         ep.module()(torch.tensor([1]))
 
-    @testing.expectedFailureTrainingIRToRunDecomp
-    @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_unbacked_pad(self):
         class Foo(torch.nn.Module):
             def forward(self, xs, pad):
@@ -13654,7 +13652,6 @@ def forward(self, input1: torch.Tensor):
         inps = (torch.randn(1, 224, 768, device="cpu"),)
         export(Foo(), inps)
 
-    @testing.expectedFailureCppSerDes  # TODO(pianpwk): PowByNatural valuerange deserialization
     def test_dim_dynamic(self):
         dynamic = Dim.DYNAMIC
