[sparse][semi-structured] Fix RuntimeError when passing in non-contiguous input to SparseSemiStructured linear #114593
Summary:
This PR also brings in changes from #105595, which are needed for the changes in #110420.
Currently, PyTorch incorrectly calculates the size of the returned matrix when we pass a non-contiguous batched (>2d) input to the semi-structured sparse subclass.
This is most common in MLP layers, where we have 2 linear layers back to back.
This will lead to an error like the following:
In this error, the size of the sparse matmul result is off because the output shape is inferred from the wrong tensor shape.
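To make the shape inference concrete, here is a minimal pure-Python sketch of how the output shape of a batched matmul follows from its operand shapes. The helper names `flatten_batch` and `matmul_output_shape` and the example sizes are hypothetical, chosen for illustration; this is not the PyTorch implementation.

```python
from math import prod
from typing import Tuple


def flatten_batch(input_shape: Tuple[int, ...]) -> Tuple[int, int]:
    """Collapse all leading batch dims of a (>2d) input into a single row dim."""
    *batch, in_features = input_shape
    return (prod(batch), in_features)


def matmul_output_shape(
    input_shape: Tuple[int, ...], rhs_shape: Tuple[int, int]
) -> Tuple[int, ...]:
    """Infer the shape of `input @ rhs` for a 2d right-hand side."""
    *batch, inner = input_shape
    rhs_rows, rhs_cols = rhs_shape
    if inner != rhs_rows:
        raise ValueError(f"inner dims differ: {inner} vs {rhs_rows}")
    return (*batch, rhs_cols)


# Hypothetical sizes: batch 4, sequence 16, in_features 64, out_features 128.
x_shape = (4, 16, 64)
flat_shape = flatten_batch(x_shape)                      # (64, 64)
good = matmul_output_shape(x_shape, (64, 128))           # transposed weight, correct shape

# Feeding in a weight shape that was never updated after the transpose
# makes the inferred sizes inconsistent.
try:
    matmul_output_shape(x_shape, (128, 64))
    stale_ok = True
except ValueError:
    stale_ok = False
```

With a correct right-hand shape the batch dims carry through unchanged; with a shape that does not reflect the transpose, the inference no longer matches the data that is actually multiplied.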
This happens because of a bug where we did not update the subclass tensor shape when doing transpose.
For semi-structured sparsity, transposing is a no-op where we just set the boolean flag, but we forgot to also update the tensor shape.
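The flag-versus-shape mismatch can be sketched with a toy class. `ToySparseTensor` and its methods are hypothetical stand-ins written for this illustration, not the actual subclass in PyTorch; the sketch only assumes what the description states, namely that transpose flips a boolean and must also swap the reported shape.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToySparseTensor:
    """Hypothetical stand-in for a sparse tensor subclass with a lazy transpose."""

    shape: Tuple[int, int]
    transposed: bool = False

    def t_buggy(self) -> "ToySparseTensor":
        # Flips the flag but leaves the reported shape stale.
        return replace(self, transposed=not self.transposed)

    def t_fixed(self) -> "ToySparseTensor":
        # Flips the flag and swaps the two dimensions of the reported shape.
        rows, cols = self.shape
        return replace(self, shape=(cols, rows), transposed=not self.transposed)


weight = ToySparseTensor(shape=(128, 64))
buggy = weight.t_buggy()   # still reports (128, 64)
fixed = weight.t_fixed()   # reports (64, 128)
```

Any downstream code that reads `shape` to size its output sees the stale value in the buggy variant, even though the flag says the tensor is transposed.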
Note that this error goes away in inference mode, since we avoid decomposing the aten.linear op and handle shape folding ourselves, which changes the execution path.
An alternative workaround is to set TORCH_FLATTEN_LINEAR_3D=True, which also avoids this error.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: #110420 Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch