new_reshape · pytorch/pytorch@fbf8e90 · GitHub

Commit fbf8e90

new_reshape
ghstack-source-id: fbd66fc
Pull Request resolved: #153198
modified: torch/_refs/__init__.py
1 parent 50a283c commit fbf8e90

File tree

8 files changed: +407 -94 lines changed


aten/src/ATen/FunctionalizeFallbackKernel.cpp (+10 -1)

@@ -315,8 +315,17 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
   // See Note [Propagating strides in the functionalization pass]
   // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
   auto inferred_size = at::infer_size_dv(size, self.sym_numel());
+
   auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
-  TORCH_INTERNAL_ASSERT(stride.has_value());
+  if (! stride.has_value()){
+    // See if the view is valid. If it's not, then we copy.
+    // It's OK to copy, because _unsafe_view(x) guarantees that x isn't used
+    // anymore.
+    if (!stride.has_value()) {
+      auto tmp = self_.contiguous();
+      stride = tmp.sym_strides();
+    }
+  }
   out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
   return out;
 }
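Illustrative note: the fallback above mirrors what plain reshape already does in eager mode when a requested shape cannot be expressed as strides over the existing storage, namely copy instead of view. A minimal eager-mode sketch (plain PyTorch, not part of this diff):

import torch

# A transposed tensor is non-contiguous: shape (3, 2), strides (1, 3).
x = torch.arange(6).reshape(2, 3).t()

# No single set of strides over x's storage describes a flat length-6 view,
# so stride inference (computeStride) fails and reshape falls back to a copy.
y = x.reshape(6)
print(y.data_ptr() != x.data_ptr())  # True: y owns fresh, contiguous memory

# view() refuses to copy and raises instead.
try:
    x.view(6)
except RuntimeError as err:
    print("view failed:", err)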

aten/src/ATen/InferSize.h (+12 -6)

@@ -25,17 +25,23 @@ inline void infer_size_impl(
   // N.B. this is an index, not a sym dim!
   std::optional<int64_t> infer_dim;
   for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
-    // We can avoid failing on unbacked shape[dim] and assert that it is >=0
-    // following python behaviour.
-    if (shape[dim] == -1) {
+    if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) {
       if (infer_dim) {
         throw std::runtime_error("only one dimension can be inferred");
       }
       infer_dim = dim;
-    } else if (shape[dim] >= 0) {
-      newsize *= shape[dim];
     } else {
-      TORCH_CHECK(false, "invalid shape dimension ", shape[dim]);
+      // In case of unbacked shape[dim] we assume it is not -1 and add a runtime
+      // assertion.
+      TORCH_MAYBE_SYM_CHECK(
+          sym_gt(shape[dim], -1),
+          "invalid shape dimension ",
+          shape[dim],
+          " at index ",
+          dim,
+          " of shape ",
+          shape);
+      newsize *= shape[dim];
     }
   }
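For intuition, a minimal Python sketch of the shape-inference rule this header implements, restricted to plain (backed) integers. infer_size here is a hypothetical helper for illustration; the real code additionally has to reason about unbacked SymInts, which is what TORCH_GUARD_OR_FALSE and TORCH_MAYBE_SYM_CHECK handle.

def infer_size(shape, numel):
    # Resolve at most one -1 entry so the product of sizes matches numel.
    newsize, infer_dim = 1, None
    for dim, s in enumerate(shape):
        if s == -1:
            if infer_dim is not None:
                raise RuntimeError("only one dimension can be inferred")
            infer_dim = dim
        elif s >= 0:
            newsize *= s
        else:
            raise RuntimeError(f"invalid shape dimension {s} at index {dim} of shape {shape}")
    out = list(shape)
    if infer_dim is not None:
        if newsize == 0 or numel % newsize != 0:
            raise RuntimeError(f"shape {shape} is invalid for input of size {numel}")
        out[infer_dim] = numel // newsize
    elif newsize != numel:
        raise RuntimeError(f"shape {shape} is invalid for input of size {numel}")
    return out

print(infer_size([2, -1, 4], 24))  # [2, 3, 4]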

aten/src/ATen/TensorUtils.cpp (+11 -4)

@@ -367,19 +367,26 @@ inline static std::optional<ResultVec> computeStride_impl(
   // numel in current chunk
   Numel tensor_numel = 1;
   Numel view_numel = 1;
+
+  // The usages of TORCH_GUARD_OR_TRUE/TORCH_GUARD_OR_FALSE below could result in returning std::nullopt, which has the effect of falling
+  // back to a clone when unbacked symints are present. It will not result in different or wrong results.
   for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
     tensor_numel *= oldshape[tensor_d];
     // if end of tensor size chunk, check view
     if ((tensor_d == 0) ||
-        (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldshape[tensor_d - 1], 1)) &&
-         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
+        (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) &&
+         TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
+      // We want to keep accumulating into view_numel until view_numel == tensor_numel; if we do not yet know whether they are equal,
+      // we keep accumulating, since if view_numel < tensor_numel then view_numel == tensor_numel would fail anyway, so it is better to look ahead.
+      // We use TORCH_GUARD_OR_FALSE when comparing newshape[view_d] == 1 because if we know view_numel < tensor_numel is false,
+      // we want to stop, unless we know for sure newshape[view_d] == 1, in which case we would stop in the next iteration anyway.
       while (view_d >= 0 &&
-             (TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(newshape[view_d], 1)))) {
+             (TORCH_GUARD_OR_TRUE(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_OR_FALSE(sym_eq(newshape[view_d], 1)))) {
         newstride[view_d] = view_numel * chunk_base_stride;
         view_numel *= newshape[view_d];
         view_d--;
       }
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(view_numel, tensor_numel))) {
+      if (TORCH_GUARD_OR_TRUE(sym_ne(view_numel, tensor_numel))) {
         return std::nullopt;
       }
       if (tensor_d > 0) {
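As a rough reference for the loop above, a Python sketch of what computeStride decides for plain integer shapes. compute_stride here is a hypothetical helper: edge cases such as empty or zero-numel shapes are omitted, and the real implementation works on SymInts, using the guard macros above to stay sound when sizes are unbacked.

import math

def compute_stride(oldshape, oldstride, newshape):
    # Return strides that make newshape a view over the same storage,
    # or None if the reshape has to fall back to a copy.
    assert math.prod(oldshape) == math.prod(newshape)
    newstride = [0] * len(newshape)
    view_d = len(newshape) - 1
    chunk_base_stride = oldstride[-1]
    tensor_numel = view_numel = 1
    for tensor_d in range(len(oldshape) - 1, -1, -1):
        tensor_numel *= oldshape[tensor_d]
        # At the end of a contiguous chunk of the old tensor, try to cover the
        # same number of elements with trailing dims of the new shape.
        if tensor_d == 0 or (
            oldshape[tensor_d - 1] != 1
            and oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride
        ):
            while view_d >= 0 and (view_numel < tensor_numel or newshape[view_d] == 1):
                newstride[view_d] = view_numel * chunk_base_stride
                view_numel *= newshape[view_d]
                view_d -= 1
            if view_numel != tensor_numel:
                return None
            if tensor_d > 0:
                chunk_base_stride = oldstride[tensor_d - 1]
                tensor_numel = view_numel = 1
    return newstride if view_d == -1 else None

print(compute_stride([2, 3], [3, 1], [6]))  # [1]  -> contiguous, view is possible
print(compute_stride([3, 2], [1, 3], [6]))  # None -> transposed, needs a copy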

test/dynamo/test_repros.py (+1 -1)

@@ -1281,7 +1281,7 @@ def test_longformer_chunk(self):
             self.assertExpectedInline(cnt.op_count, """4""")
         else:
             self.assertExpectedInline(cnt.frame_count, """2""")
-            self.assertExpectedInline(cnt.op_count, """19""")
+            self.assertExpectedInline(cnt.op_count, """20""")

     def test_hf_t5_forward(self):
         input = torch.randn([1, 2048, 512])

test/export/test_export.py (+3 -40)

@@ -4429,32 +4429,9 @@ class M_v0(torch.nn.Module):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
                 return r.view(items[0], items[2])

         M = M_v0
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}.*\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
-            f".*{re.escape('or r.shape[1], `u2` with items[2] in Eq(Mod(u1, u2), 0) and its negation.')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v1(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # TODO(pianpwk): this isn't the suggested fixes.
-                # fix issue with % being interpreted as PythonMod instead of Mod
-                torch._check(items[1] == items[2])
-                return r.view(items[0], items[2])
-
-        M = M_v1
         export(N(), (t,), strict=strict)

     def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):

@@ -12409,19 +12386,13 @@ def forward(self, xs):
         ):
             ep.module()(torch.tensor([1, 5]))

-    def test_reshape_view_helper(self):
-        # see: https://github.com/pytorch/pytorch/issues/126607
+    def test_view(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()

             def forward(self, x):
                 x = x.view(x.size(1), -1)
-                # torch/_refs/__init__/_reshape_view_helper() will generate guards on reshape kernel(?)
-                # Ne(s0, 20), so that reshape isn't no-op
-                # Ne(Mod(s0, 20), 0), so that reshape needs to first flatten [s0, 20, 16] -> [s0*20, 16]
-                # then split_dim -> [20, s0, 16]
-                # check that these show up in graph
                 return torch.nn.functional.softmax(
                     x, dim=0
                 )  # don't think softmax actually creates any issues, just part of original test

@@ -12435,16 +12406,8 @@ def forward(self, x):
             dynamic_shapes=dynamic_shapes,
             allow_complex_guards_as_runtime_asserts=True,
         )
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Runtime assertion failed for expression Ne\(s77, 20\)",
-        ):
-            ep.module()(torch.randn(20, 20, 16))
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Runtime assertion failed for expression Ne\(Mod\(s77, 20\), 0\)",
-        ):
-            ep.module()(torch.randn(400, 20, 16))
+        ep.module()(torch.randn(20, 20, 16))
+        ep.module()(torch.randn(400, 20, 16))
         ep.module()(torch.randn(42, 20, 16))

     def test_full_on_scalar_tensor(self):
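For reference on the shapes these tests exercise: for an input of shape (s, 20, 16), x.view(x.size(1), -1) produces (20, 16*s), which is valid for any s. The removed assertRaisesRegex blocks reflected guards (Ne(s77, 20) and Ne(Mod(s77, 20), 0)) that the old _reshape_view_helper decomposition emitted; with this change they are no longer generated. A quick eager-mode check, independent of export:

import torch

for s in (20, 400, 42):
    x = torch.randn(s, 20, 16)
    y = x.view(x.size(1), -1)  # shape (20, 16 * s)
    assert y.shape == (20, 16 * s)
    assert torch.nn.functional.softmax(y, dim=0).shape == y.shape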

test/test_dynamic_shapes.py (+214)

@@ -15,6 +15,8 @@
 import torch.nn.functional as F
 from torch import sym_int, SymBool, SymFloat, SymInt
 from torch._C import _disabled_torch_function_impl
+from torch._dynamo.testing import CompileCounterWithBackend
+from torch._inductor.utils import fresh_inductor_cache
 from torch.fx.experimental import sym_node
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node

@@ -42,6 +44,7 @@
     skipIfTorchDynamo,
     TestCase,
 )
+from torch.testing._internal.logging_utils import logs_to_string
 from torch.utils import _pytree as pytree
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._sympy.functions import (

@@ -3050,6 +3053,217 @@ def func(a, b):
         with self.assertRaises(RuntimeError):
             func(a, torch.rand(2, 1))

+    @fresh_inductor_cache()
+    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_unbacked_reshape1(self):
+        cnt = CompileCounterWithBackend("inductor")
+
+        # Reshape happens in place (no clone):
+        # reshape u1 -> (u0*u0)
+        def func(x, y):
+            f = y.item()
+            t1 = x.view((f, f))
+            t2 = x.reshape((f, f))
+            # TODO avoid _check_is_size here.
+            torch._check_is_size(f)
+            return t1 * 10, t2 * 10
+
+        compiled_func = torch.compile(
+            fullgraph=True,
+            backend=cnt,
+            dynamic=True,
+        )(func)
+
+        # Create a non-contiguous tensor with data being the even numbers in [0:cnt-1]
+        # and reshape it into sqrt(cnt)*sqrt(cnt).
+        def make_non_contiguous_tensor_and_test(cnt):
+            # create a non-contiguous tensor x that is skipping odd indices.
+            x = torch.arange(cnt * 2)
+            x = x.as_strided((x.size()[0] // 2,), (2,))
+
+            torch._dynamo.decorators.mark_unbacked(x, 0)
+            sz = torch.tensor([int(math.sqrt(cnt))])
+            compiled_result = compiled_func(x, sz)
+            eager_result = func(x, sz)
+            self.assertEqual(compiled_result, eager_result)
+
+        log_stream, ctx = logs_to_string(
+            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
+        )
+        with ctx():
+            make_non_contiguous_tensor_and_test(4)
+        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+        self.assertExpectedInline(
+            aot_graphs,
+            """\
+def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
+    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
+    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
+    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
+    ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
+    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
+    pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
+    eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
+    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
+    view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
+    view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
+    mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
+    mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
+    return (mul_9, mul_12)""", # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
+        make_non_contiguous_tensor_and_test(49)
+        self.assertEqual(cnt.frame_count, 1)
+
+        # Pass in a contiguous tensor; it will recompile due to the stride being 1 (0/1 specialization).
+        # Marking strides unbacked would have avoided the recompilation here.
+        x = torch.arange(100)
+        torch._dynamo.decorators.mark_unbacked(x, 0)
+
+        log_stream, ctx = logs_to_string(
+            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
+        )
+        with ctx():
+            compiled_result = compiled_func(x, torch.tensor([10]))
+            eager_result = func(x, torch.tensor([10]))
+            self.assertEqual(compiled_result, eager_result)
+        self.assertEqual(cnt.frame_count, 2)
+
+        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+        self.assertExpectedInline(
+            aot_graphs,
+            """\
+def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
+    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
+    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
+    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
+    ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
+    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
+    pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
+    eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
+    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
+    view: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
+    view_1: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
+    mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
+    mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
+    return (mul_4, mul_7)""", # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
+        x = torch.arange(25)
+        compiled_result = compiled_func(x, torch.tensor([5]))
+        eager_result = func(x, torch.tensor([5]))
+        self.assertEqual(cnt.frame_count, 2)
+
+    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_unbacked_reshape2(self):
+        cnt = CompileCounterWithBackend("inductor")
+
+        # This reshape requires a clone when the input is not contiguous and we can't compute strides:
+        # reshape (u2, u3) -> (u0, u1)
+        def func(x, y, with_view=False):
+            u0, u1 = y.tolist()
+            torch._check_is_size(u0)
+            torch._check_is_size(u1)
+
+            result1 = torch.reshape(x, (u0, u1))
+            result2 = None
+            if with_view:
+                result2 = x.view((u0, u1)) * 10
+            return result1 * 10, result2
+
+        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
+
+        x = torch.randn(10, 10)
+        # make x not contiguous.
+        x = x.t_()
+        torch._dynamo.decorators.mark_unbacked(x, 0)
+        torch._dynamo.decorators.mark_unbacked(x, 1)
+
+        log_stream, ctx = logs_to_string(
+            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
+        )
+        with ctx():
+            result_eager = func(x, torch.tensor([5, 20]))
+            result_compiled = compiled_func(x, torch.tensor([5, 20]))
+            self.assertEqual(result_compiled, result_eager)
+        self.assertEqual(cnt.frame_count, 1)
+
+        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+        self.assertExpectedInline(
+            aot_graphs,
+            """\
+def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
+    ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
+    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
+    ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
+    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
+    select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
+    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
+    ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
+    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
+    select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
+    _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
+    ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
+    _assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None
+    mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
+    mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
+    eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
+    _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
+    clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
+    view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
+    mul_16: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
+    return (mul_16,)""", # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
+        result_eager = func(x, torch.tensor([2, 50]))
+        result_compiled = compiled_func(x, torch.tensor([2, 50]))
+        self.assertEqual(result_compiled, result_eager)
+        self.assertEqual(cnt.frame_count, 1)
+
+        x = torch.randn(4, 4).t_()
+        result_eager = func(x, torch.tensor([2, 8]))
+        result_compiled = compiled_func(x, torch.tensor([2, 8]))
+        self.assertEqual(result_compiled, result_eager)
+        self.assertEqual(cnt.frame_count, 1)
+
+    @unittest.skip("this test fails due to inductor/autograd issue #153041")
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_unbacked_non_contigious_reshape_failing(self):
+        # reshape u1 -> (u0*u0)
+        # this results in the tensor "i64[u0, u0][s7*u0, s7]".
+        # reshape happens in place (no clone)
+        def func(x, y):
+            f = y.item()
+            t1 = x.view((f, f))
+            t2 = x.reshape((f, f))
+            return t1, t2
+
+        # create a non-contiguous tensor with data being the even numbers in [0:cnt-1]
+        def make_non_contiguous_tensor(cnt):
+            # create a non-contiguous tensor x that is skipping odd indices.
+            x = torch.arange(cnt * 2, device="cuda")
+            x = x.as_strided((x.size()[0] // 2,), (2,))
+            return x
+
+        x = make_non_contiguous_tensor(4)
+        torch._dynamo.decorators.mark_unbacked(x, 0)
+        compiled_func = torch.compile(
+            fullgraph=True,
+            backend="inductor",
+        )(func)
+
+        compiled_result = compiled_func(x, torch.tensor([2]))
+        eager_result = func(x, torch.tensor([2]))
+        self.assertEqual(compiled_result, eager_result)
+

 if __name__ == "__main__":
     run_tests()
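As a quick eager-mode companion to these tests (plain PyTorch, no compilation), one way to observe the view-vs-clone distinction that the expected AOT graphs above assert on:

import torch

def shares_storage(a, b):
    # True when b aliases a's storage, i.e. reshape produced a view, not a clone.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

# Strided-but-viewable case, as in test_unbacked_reshape1: even elements of a
# larger buffer reshaped to (2, 2); strides (2*s, s) still describe the result.
x = torch.arange(8).as_strided((4,), (2,))
print(shares_storage(x, x.reshape(2, 2)))   # True: no clone needed

# Transposed case, as in test_unbacked_reshape2: no strides fit the new shape,
# so reshape clones (the aten.clone + aten.view pair in the expected graph).
y = torch.randn(10, 10).t()
print(shares_storage(y, y.reshape(5, 20)))  # False: reshape fell back to a copy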
