mark_unbacked for strides. #153204

Open · laithsakka opened this issue May 8, 2025 · 0 comments
Labels: module: dynamic shapes · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

laithsakka (Contributor) commented May 8, 2025

We want to avoid 0/1 specialization when a stride is 1, and we have no way to do that right now. The test below always compiles twice: once for the non-contiguous input and once for the contiguous one, even though the first compiled version is fine for both. See the comment in the test: passing in a contiguous tensor recompiles due to the stride being 1, unfortunately. This is on top of #148872.
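
For context, mark_unbacked today can only be applied to a tensor dimension (a size), not to a stride. A minimal standalone sketch of the existing size-only usage, outside the test harness (the function f below is illustrative, not from the repro):

import torch

def f(x):
    return x * 2

compiled = torch.compile(f, fullgraph=True, dynamic=True)

x = torch.randn(8)
# Mark dim 0 as unbacked: no 0/1 specialization on the *size* of dim 0.
# There is no analogous knob for strides, which is what this issue asks for.
torch._dynamo.decorators.mark_unbacked(x, 0)
compiled(x)

The repro test: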

    # Requires: import math, torch; skipIfTorchDynamo comes from
    # torch.testing._internal.common_utils.
    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_non_contiguous_reshape1(self):
        cnt = torch._dynamo.testing.CompileCounter()

        # Reshape u1 -> (u0, u0); this results in the tensor
        # i64[u0, u0][s7*u0, s7]. The reshape happens in place (no clone),
        # so either reshape or view would work here.
        def func(x, y):
            f = y.item()
            t1 = x.view((f, f))
            t2 = x.reshape((f, f))
            # TODO: avoid needing _check_is_size here.
            torch._check_is_size(f)
            return t1 * 10, t2 * 10

        compiled_func = torch.compile(
            fullgraph=True,
            backend=cnt,
            dynamic=True,
        )(func)

        # Create a non-contiguous tensor whose data is the even numbers in
        # [0, 2*cnt), then reshape it into sqrt(cnt) x sqrt(cnt).
        def make_non_contiguous_tensor_and_test(cnt):
            # Create a non-contiguous tensor x that skips the odd indices.
            x = torch.arange(cnt * 2)
            x = x.as_strided((x.size()[0] // 2,), (2,))

            torch._dynamo.decorators.mark_unbacked(x, 0)
            sz = torch.tensor([int(math.sqrt(cnt))])
            compiled_result = compiled_func(x, sz)
            eager_result = func(x, sz)
            self.assertEqual(compiled_result, eager_result)

        # Both non-contiguous inputs are served by a single compiled graph.
        make_non_contiguous_tensor_and_test(4)
        make_non_contiguous_tensor_and_test(49)
        self.assertEqual(cnt.frame_count, 1)
        
        # Pass in a contiguous tensor: it recompiles due to the stride
        # being 1, unfortunately. Marking the strides unbacked would
        # probably have fixed that.
        x = torch.arange(100)
        compiled_result = compiled_func(x, torch.tensor([10]))
        eager_result = func(x, torch.tensor([10]))
        self.assertEqual(compiled_result, eager_result)
        self.assertEqual(cnt.frame_count, 2)
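
For reference, the recompile is driven entirely by the input strides: both non-contiguous inputs above have stride (2,), while the contiguous one has stride (1,), and stride 1 is what gets 0/1-specialized. A standalone sketch of the inputs involved:

import torch

nc = torch.arange(8).as_strided((4,), (2,))   # non-contiguous: stride (2,)
c = torch.arange(4)                           # contiguous: stride (1,)
print(nc.stride(), c.stride())                # (2,) vs (1,)
print(nc.is_contiguous(), c.is_contiguous())  # False, True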

cc @chauhang @penguinwu @ezyang @bobrenjc93

masnesral added the triaged label on May 12, 2025