8000 make reshape work for reshaping 1dim unbacked to non-contig to anything · pytorch/pytorch@a1d3af5 · GitHub
[go: up one dir, main page]

Skip to content

Commit a1d3af5

Browse files
committed
make reshape work for reshaping 1dim unbacked to non-contig to anything
ghstack-source-id: 9b5e433 Pull Request resolved: #148899
1 parent f8a6d88 commit a1d3af5

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

aten/src/ATen/TensorUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,11 +370,12 @@ inline std::optional<ResultVec> computeStride_impl(
370370
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
371371
tensor_numel *= oldshape[tensor_d];
372372
// if end of tensor size chunk, check view
373-
if ((tensor_d == 0) ||
374-
(TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldshape[tensor_d - 1], 1)) &&
375-
TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
373+
if ( tensor_d == 0 ||
374+
TORCH_GUARD_SIZE_OBLIVIOUS(
375+
sym_ne(oldshape[tensor_d - 1], 1) &&
376+
sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride))) {
376377
while (view_d >= 0 &&
377-
(TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(newshape[view_d], 1)))) {
378+
TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)|| sym_eq(newshape[view_d], 1) )) {
378379
newstride[view_d] = view_numel * chunk_base_stride;
379380
view_numel *= newshape[view_d];
380381
view_d--;

0 commit comments

Comments
 (0)
0