10000 guard reshape for contiguous tesnors · pytorch/pytorch@2a9f1e4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a9f1e4

Browse files
committed
guard reshape for contiguous tesnors
ghstack-source-id: 26372e1 Pull Request resolved: #148742
1 parent d789c22 commit 2a9f1e4

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

torch/_refs/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3731,14 +3731,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
37313731
return _a
37323732

37333733
if a.is_contiguous():
3734-
# Special-cases for nd_to_1d
3735-
if len(shape) == 1 and a.ndim > 1:
3736-
return torch.as_strided(a, [a.numel()], [1])
3737-
# Special-cases for 1d_to_2d
3738-
if len(shape) == 2 and a.ndim == 1:
3739-
dim0 = shape[0]
3740-
dim1 = shape[1]
3741-
return torch.as_strided(a, [dim0, dim1], [dim1, 1])
3734+
if len(shape) >= 1 and a.ndim >= 1:
3735+
strides = [1]
3736+
for x in reversed(shape[1:]):
3737+
strides.append(strides[-1] * x)
3738+
strides.reverse()
3739+
return torch.as_strided(a, shape, strides)
37423740

37433741
# Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
37443742

0 commit comments

Comments
 (0)
0