8000 remove guard_size_oblivious from unbind. by laithsakka · Pull Request #148815 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

remove guard_size_oblivious from unbind. #148815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/laithsakka/117/base
Choose a base branch
from

Conversation

laithsakka
Copy link
Contributor
@laithsakka laithsakka commented Mar 8, 2025

Stack from ghstack (oldest at bottom):

unbind will always specialize on dim, because it determine the number of output tensors.
guard_size_oblivious is not useful there and more confusing probably for code readers
added a comment and a test that verifies the specialization.

Copy link
pytorch-bot bot commented Mar 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148815

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 4 New Failures

As of commit 8260058 with merge base c4a0b31 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

laithsakka added a commit that referenced this pull request Mar 8, 2025
ghstack-source-id: d88a7cb
Pull Request resolved: #148815
Copy link
Contributor
github-actions bot commented Mar 8, 2025

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@laithsakka laithsakka changed the title remove guard_size_oblivious from unbind [DRAFT] remove guard_size_oblivious from unbind, but anyway seems like its useless there? Mar 8, 2025
@laithsakka laithsakka changed the title [DRAFT] remove guard_size_oblivious from unbind, but anyway seems like its useless there? [DRAFT] remove guard_size_oblivious from unbind. Mar 8, 2025
This was interesting why?
i initially thought this is the best rewrite for this
```
   
    # Looking at this what will happen because of guard_size_oblivious/guard_or_false 
    # At least for normal tesnors we would throw during runtime when t.shape[dim] == 0 the following error. 
    # RuntimeError: number of sections must be larger than 0, got 0
    # Because  of calling torch.tensor_split(t, t.shape[dim], dim) 

    # Maybe a better error msg alternative here, is to throw in such case an error
    # RuntimeError: unbind unbacked semantics does not support 0 inputs during runtime.
    def fall_back(): 
        torch._check(t.shape[dim] == 0,
         lambda: "unbind unbacked semantics does not support 0 inputs during runtime.")
        return False
     if guard_or_else(t.shape[dim] == 0, fall_back):
        return ()
    else:
        return tuple(
            torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
        )


```
but then i tried to add examples
ex:
```
from torch.fx.experimental.symbolic_shapes import sym_or
torch.compile(fullgraph=True)
def func(y):
      return y.unbind(dim=2)

z = torch.ones(1,1,1,1)
func(z)
torch._dynamo.decorators.mark_unbacked(z, 2)
```
but we actually do hit another guard 

Specifically guard_int() bellow

```
std::vector<Tensor> tensor_split_sections_symint(
    const Tensor& self,
    c10::SymInt sym_sections,
    int64_t dim) {
  TORCH_CHECK(
      self.dim() > 0,
      "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ",
      self.dim(),
      " dims");
  int64_t dim_ = maybe_wrap_dim(dim, self.dim());
  // NB: intentional, sections specifies number of output tensors, which
  // cannot be polymorphic
  int64_t sections = sym_sections.guard_int(__FILE__, __LINE__);
 ```
  so my thoughts, is i do not know what is being avoid by adding this, it was added in its on PR so assume for a reason?
  #124959
  I will keep the current semantics of guard_size_oblivious and just use guard_or_false.
  
  maybe for other types of tensors  tensor_split_sections_symint does not throw runtime error or re-guard?
  idk

[ghstack-poisoned]
Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
This was interesting why?
i initially thought this is the best rewrite for this
```
   
    # Looking at this what will happen because of guard_size_oblivious/guard_or_false 
    # At least for normal tesnors we would throw during runtime when t.shape[dim] == 0 the following error. 
    # RuntimeError: number of sections must be larger than 0, got 0
    # Because  of calling torch.tensor_split(t, t.shape[dim], dim) 

    # Maybe a better error msg alternative here, is to throw in such case an error
    # RuntimeError: unbind unbacked semantics does not support 0 inputs during runtime.
    def fall_back(): 
        torch._check(t.shape[dim] == 0,
         lambda: "unbind unbacked semantics does not support 0 inputs during runtime.")
        return False
     if guard_or_else(t.shape[dim] == 0, fall_back):
        return ()
    else:
        return tuple(
            torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
        )


```
but then i tried to add examples
ex:
```
from torch.fx.experimental.symbolic_shapes import sym_or
torch.compile(fullgraph=True)
def func(y):
      return y.unbind(dim=2)

z = torch.ones(1,1,1,1)
func(z)
torch._dynamo.decorators.mark_unbacked(z, 2)
```
but we actually do hit another guard 

Specifically guard_int() bellow

```
std::vector<Tensor> tensor_split_sections_symint(
    const Tensor& self,
    c10::SymInt sym_sections,
    int64_t dim) {
  TORCH_CHECK(
      self.dim() > 0,
      "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ",
      self.dim(),
      " dims");
  int64_t dim_ = maybe_wrap_dim(dim, self.dim());
  // NB: intentional, sections specifies number of output tensors, which
  // cannot be polymorphic
  int64_t sections = sym_sections.guard_int(__FILE__, __LINE__);
 ```
  so my thoughts, is i do not know what is being avoid by adding this, it was added in its on PR so assume for a reason?
  #124959
  I will keep the current semantics of guard_size_oblivious and just use guard_or_false.
  
  maybe for other types of tensors  tensor_split_sections_symint does not throw runtime error or re-guard?
  idk

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request May 2, 2025
ghstack-source-id: 41f4725
Pull Request resolved: #148815
@laithsakka laithsakka requested a review from aorenste May 2, 2025 17:02
@laithsakka laithsakka changed the title [DRAFT] remove guard_size_oblivious from unbind. remove guard_size_oblivious from unbind. May 2, 2025
@laithsakka laithsakka requested review from bobrenjc93 and pianpwk May 2, 2025 17:04

# Note: t.shape[dim] can't be dynamic or unbacked, even if we use guard_or_false here we will fail
# later in the split since t.shape[dim] control the number of output tensors.
if t.shape[dim] == 0:
Copy link
Contributor
@pianpwk pianpwk May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't run this locally, but it seems the data-dependent error the user would see would now originate from here? Their first instinct might then be to try to slap a guard_size_oblivious on again.

Is there some way we can prevent that? Not that it's a big deal.
e.g. by adding

n_out = int(t.shape[dim])  # unbind returns a list, so we force-specialize on the output length anyways; this cannot remain symbolic
...
return tuple(
    torch.squeeze(s, dim) for s in torch.tensor_split(t, n_out, dim)
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeh I added this comment so people avoid that, good idea i can do that to make it explicit.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants
0