Data dependent free reshape. by laithsakka · Pull Request #153198 · pytorch/pytorch · GitHub
Data dependent free reshape. #153198


Open · wants to merge 15 commits into base: gh/laithsakka/172/base

Conversation

@laithsakka (Contributor) commented May 8, 2025

Stack from ghstack (oldest at bottom):

Change 1:

Let's consider the most general case: if torch.compile is asked to reshape [u0, u1][u3, u4] -> [u5, u6], what should it do?
This shape is general enough to cover both contiguous and non-contiguous tensors, i.e. tensors where a clone-free reshape is possible and others where it is not. The current algorithm fails with data-dependent errors.

The general idea is: if it is impossible to tell whether the reshape can happen in place (because for some concrete inputs it can and for others it cannot), then it is fine to take the general path and clone, instead of failing or asking the user for hints.

With this change, reshape works as follows (a rough sketch is given after the side note below):

  1. If we know the input is contiguous, we convert the reshape into a view.
  2. If compute_strides succeeds, we use a view. (compute_strides was changed so that it no longer fails when unbacked symbols are present; instead it returns nullptr when it cannot compute the strides, meaning we should clone.)
  3. If neither 1 nor 2 works, we clone and then view.

Side note: producing a view does not mean that Inductor will not clone; Inductor has a pass that converts all views back to reshapes.
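
For illustration only, here is a minimal eager-mode Python sketch of the three-step decision above. It is not the PR's implementation: compute_strides_or_none is a deliberately simplified stand-in for at::detail::computeStride, and only the control flow mirrors the description.

```python
import torch


def compute_strides_or_none(sizes, strides, new_sizes):
    # Simplified stand-in for at::detail::computeStride: it only recognizes the
    # already-contiguous case and returns None otherwise, signalling that a
    # clone-free view cannot be proven possible. The real helper also handles
    # many non-contiguous layouts.
    if list(torch.empty(sizes).stride()) == list(strides):
        return list(torch.empty(new_sizes).stride())
    return None


def reshape_sketch(t: torch.Tensor, new_sizes):
    # 1. Known contiguous: the reshape becomes a view.
    if t.is_contiguous():
        return t.view(new_sizes)
    # 2. Strides computable without copying: still a view.
    if compute_strides_or_none(list(t.shape), list(t.stride()), new_sizes) is not None:
        return t.view(new_sizes)
    # 3. Otherwise take the general path: clone (via contiguous) and then view.
    return t.contiguous().view(new_sizes)


x = torch.arange(12).reshape(3, 4).t()   # non-contiguous input
print(reshape_sketch(x, (2, 6)).shape)   # falls into the clone path
```

In the unbacked case the point is simply that step 3 is always a safe answer when steps 1 and 2 cannot be decided symbolically.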

Change 2:

We trace _reshape_view_helper when doing fake tensor tracing, but not during proxy tracing; hence this tracing does not affect the graph (it only computes output shapes for several operations). We should not fail there, because for a reshape it should always be possible to get through:
by the time reshape_symint was called we would have either cloned or compute_strides would have succeeded, so the view should pass.

What I did is the following: we run _reshape_view_helper, and if it fails due to unbacked symbols we call _view_simple, which always succeeds for reshapes (it might fail for views when the view is impossible; in that case we re-throw the DDE that was thrown by the original algorithm). A sketch follows.
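
A minimal sketch of that try/except path, assuming the PR's helpers: _reshape_view_helper already lives in torch/_refs/__init__.py, while _view_simple is the helper added by this PR, so the exact call used here is an assumption.

```python
from torch._refs import _reshape_view_helper
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode


def _reshape_view_meta_sketch(a, shape, allow_copy):
    try:
        # Existing algorithm; on unbacked symints it can raise a
        # data-dependent (DDE) guard error during fake tensor tracing.
        return _reshape_view_helper(a, *shape, allow_copy=allow_copy)
    except GuardOnDataDependentSymNode as dde:
        try:
            # _view_simple is the helper added by this PR (signature assumed
            # here). It always succeeds for reshapes, since reshape_symint has
            # already either cloned or computed valid strides by this point.
            return _view_simple(a, shape)  # noqa: F821
        except Exception:
            # A plain view that is genuinely impossible: surface the
            # data-dependent error raised by the original algorithm.
            raise dde
```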

Ideally I would register _view_simple as the meta and avoid calling _reshape_view_helper entirely, but I am running into issues with the dispatcher and subclasses and do not have time to debug them; namely, one test ends up calling view instead of view_symint during meta dispatch when I register a meta decomposition:
`python test/dynamo/test_subclasses.py SubclassTests.test_subclass_views_dynamic_True`
#153303 will follow up with that change in a separate PR. cc @bdhirsh

Two other alternatives to registering _view_simple as the meta, and to the try/catch approach in this PR, are:

  1. Call _view_simple if any input is dynamic; see [DRAFT] avoidance strategy for reshape_view_helper guards in compile: call _view_simple if inputs has dynamic dimensions. #153521
  2. If we make is_compiling work for framework code tracing (it does not right now), we can call _view_simple if is_compiling.

Note:

Reshape can still fail when is_contiguous is called; the next PR will handle that by calling is_known_contiguous.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented May 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153198

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor module: dynamo module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels May 8, 2025
laithsakka added a commit that referenced this pull request May 8, 2025
ghstack-source-id: c64eb2d
Pull Request resolved: #153198
@laithsakka laithsakka changed the title from new_reshape to Data dependent free reshape. on May 8, 2025
laithsakka added a commit that referenced this pull request May 8, 2025
ghstack-source-id: f9a865c
Pull Request resolved: #153198
@laithsakka laithsakka requested a review from bobrenjc93 as a code owner May 10, 2025 02:34
laithsakka added a commit that referenced this pull request May 10, 2025
ghstack-source-id: 9c36c39
Pull Request resolved: #153198

@laithsakka laithsakka changed the title from Data dependent free reshape. to [draft] Data dependent free reshape. on May 10, 2025
laithsakka added a commit that referenced this pull request May 10, 2025
ghstack-source-id: 57a4858
Pull Request resolved: #153198

laithsakka added a commit that referenced this pull request May 12, 2025
ghstack-source-id: 599f371
Pull Request resolved: #153198

laithsakka added a commit that referenced this pull request May 12, 2025
ghstack-source-id: 7fd5afc
Pull Request resolved: #153198

@laithsakka laithsakka changed the title from [draft] Data dependent free reshape. to Data dependent free reshape. on May 12, 2025
laithsakka added a commit that referenced this pull request May 13, 2025
ghstack-source-id: 1ba1ef0
Pull Request resolved: #153198

@laithsakka laithsakka marked this pull request as draft May 13, 2025 18:52
laithsakka added a commit that referenced this pull request May 13, 2025
ghstack-source-id: 821c2ec
Pull Request resolved: #153198

@laithsakka (Contributor Author) commented:

I will split the meta registration into its own function, since it is hitting issues in the dispatcher for some tests: for subclasses we end up calling view instead of view_symint and crash. I do not have time to debug that, so I will split the change into a later diff to unblock and revert back to the try/catch approach here.
This is not ideal, since the reshape could still add guards that are not needed and cause undesired recompilations, but we will move step by step.

laithsakka added a commit that referenced this pull request May 13, 2025
ghstack-source-id: fbd66fc
Pull Request resolved: #153198



def _reshape_view_helper_core_alg(
a: TensorLikeType, shape, allow_copy: bool
@laithsakka (Contributor Author):

moved as is

laithsakka added a commit that referenced this pull request May 13, 2025
ghstack-source-id: 82ee9a8
Pull Request resolved: #153198

laithsakka added a commit that referenced this pull request May 13, 2025
ghstack-source-id: 5686bfe
Pull Request resolved: #153198

laithsakka added a commit that referenced this pull request May 13, 2025
ghstack-source-id: 6cdac7a
Pull Request resolved: #153198

@laithsakka laithsakka marked this pull request as ready for review May 13, 2025 23:18
# Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
return _reshape_view_helper_core_alg(a, shape, allow_copy)
except GuardOnDataDependentSymNode as e:
# dynamic shapes do not show up in eager. For compile this function is on
@laithsakka (Contributor Author):

For compile this function is what?? Fix this comment.

auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
TORCH_INTERNAL_ASSERT(stride.has_value());
if (! stride.has_value()){
Contributor:

is the space intended?

@laithsakka (Contributor Author):

No, I can address this before landing; I just ran the linter.

Labels
ciflow/inductor module: dynamo module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category
3 participants