|
15 | 15 | import torch.nn.functional as F
|
16 | 16 | from torch import sym_int, SymBool, SymFloat, SymInt
|
17 | 17 | from torch._C import _disabled_torch_function_impl
|
| 18 | +from torch._dynamo.testing import CompileCounterWithBackend |
| 19 | +from torch._inductor.utils import fresh_inductor_cache |
18 | 20 | from torch.fx.experimental import sym_node
|
19 | 21 | from torch.fx.experimental.proxy_tensor import make_fx
|
20 | 22 | from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
|
|
42 | 44 | skipIfTorchDynamo,
|
43 | 45 | TestCase,
|
44 | 46 | )
|
| 47 | +from torch.testing._internal.logging_utils import logs_to_string |
45 | 48 | from torch.utils import _pytree as pytree
|
46 | 49 | from torch.utils._python_dispatch import TorchDispatchMode
|
47 | 50 | from torch.utils._sympy.functions import (
|
@@ -3050,6 +3053,217 @@ def func(a, b):
|
3050 | 3053 | with self.assertRaises(RuntimeError):
|
3051 | 3054 | func(a, torch.rand(2, 1))
|
3052 | 3055 |
|
| 3056 | + @fresh_inductor_cache() |
| 3057 | + @skipIfTorchDynamo("not allowed to trace mark_unbacked") |
| 3058 | + @torch._dynamo.config.patch("capture_scalar_outputs", True) |
| 3059 | + def test_unbacked_reshape1(self): |
| 3060 | + cnt = CompileCounterWithBackend("inductor") |
| 3061 | + |
| 3062 | + # Reshape happens in place (no clone). |
| 3063 | + # reshape u1 -> (u0*u0) |
| 3064 | + def func(x, y): |
| 3065 | + f = y.item() |
| 3066 | + t1 = x.view((f, f)) |
| 3067 | + t2 = x.reshape((f, f)) |
| 3068 | + # TODO avoid _check_is_size here. |
| 3069 | + torch._check_is_size(f) |
| 3070 | + return t1 * 10, t2 * 10 |
| 3071 | + |
| 3072 | + compiled_func = torch.compile( |
| 3073 | + fullgraph=True, |
| 3074 | + backend=cnt, |
| 3075 | + dynamic=True, |
| 3076 | + )(func) |
| 3077 | + |
| 3078 | + # create a non-contiguous tensor whose data is the even numbers in [0, 2*cnt) |
| 3079 | + # and reshape it into sqrt(cnt)*sqrt(cnt) |
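| | + # e.g. for cnt=4: x = arange(8).as_strided((4,), (2,)) == [0, 2, 4, 6], |
| | + # which is reshaped to (2, 2) with sz = sqrt(4) = 2. |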
| 3080 | + def make_non_contiguous_tensor_and_test(cnt): |
| 3081 | + # create a non-contiguous tensor x that skips the odd indices. |
| 3082 | + x = torch.arange(cnt * 2) |
| 3083 | + x = x.as_strided((x.size()[0] // 2,), (2,)) |
| 3084 | + |
| 3085 | + torch._dynamo.decorators.mark_unbacked(x, 0) |
| 3086 | + sz = torch.tensor([int(math.sqrt(cnt))]) |
| 3087 | + compiled_result = compiled_func(x, sz) |
| 3088 | + eager_result = func(x, sz) |
| 3089 | + self.assertEqual(compiled_result, eager_result) |
| 3090 | + |
| 3091 | + log_stream, ctx = logs_to_string( |
| 3092 | + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" |
| 3093 | + ) |
| 3094 | + with ctx(): |
| 3095 | + make_non_contiguous_tensor_and_test(4) |
| 3096 | + aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() |
| 3097 | + self.assertExpectedInline( |
| 3098 | + aot_graphs, |
| 3099 | + """\ |
| 3100 | +def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"): |
| 3101 | + ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 |
| 3102 | + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None |
| 3103 | + _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None |
| 3104 | + ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 |
| 3105 | + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None |
| 3106 | + pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 |
| 3107 | + eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None |
| 3108 | + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None |
| 3109 | + view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]) |
| 3110 | + view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None |
| 3111 | + mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None |
| 3112 | + mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None |
| 3113 | + return (mul_9, mul_12)""", # noqa: B950 |
| 3114 | + ignore_comments=True, |
| 3115 | + ignore_empty_lines=True, |
| 3116 | + ) |
| 3117 | + |
| 3118 | + make_non_contiguous_tensor_and_test(49) |
| 3119 | + self.assertEqual(cnt.frame_count, 1) |
| 3120 | + |
| 3121 | + # Pass in a contiguous tensor; it will recompile because its stride is 1 (0/1 specialization). |
| 3122 | + # Marking the strides unbacked would have avoided the recompilation here. |
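| | + # (the earlier non-contiguous inputs had stride 2, so the stride stayed dynamic as s7; |
| | + # a stride of exactly 1 gets specialized, forcing a new graph) |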
| 3123 | + x = torch.arange(100) |
| 3124 | + torch._dynamo.decorators.mark_unbacked(x, 0) |
| 3125 | + |
| 3126 | + log_stream, ctx = logs_to_string( |
| 3127 | + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" |
| 3128 | + ) |
| 3129 | + with ctx(): |
| 3130 | + compiled_result = compiled_func(x, torch.tensor([10])) |
| 3131 | + eager_result = func(x, torch.tensor([10])) |
| 3132 | + self.assertEqual(compiled_result, eager_result) |
| 3133 | + self.assertEqual(cnt.frame_count, 2) |
| 3134 | + |
| 3135 | + aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() |
| 3136 | + self.assertExpectedInline( |
| 3137 | + aot_graphs, |
| 3138 | + """\ |
| 3139 | +def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"): |
| 3140 | + ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 |
| 3141 | + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None |
| 3142 | + _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None |
| 3143 | + ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 |
| 3144 | + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None |
| 3145 | + pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 |
| 3146 | + eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None |
| 3147 | + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None |
| 3148 | + view: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) |
| 3149 | + view_1: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None |
| 3150 | + mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None |
| 3151 | + mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None |
| 3152 | + return (mul_4, mul_7)""", # noqa: B950 |
| 3153 | + ignore_comments=True, |
| 3154 | + ignore_empty_lines=True, |
| 3155 | + ) |
| 3156 | + |
| 3157 | + x = torch.arange(25) |
| 3158 | + compiled_result = compiled_func(x, torch.tensor([5])) |
| 3159 | + eager_result = func(x, torch.tensor([5])) |
| 3160 | + self.assertEqual(cnt.frame_count, 2) |
| 3161 | + |
| 3162 | + @skipIfTorchDynamo("not allowed to trace mark_unbacked") |
| 3163 | + @torch._dynamo.config.patch("capture_scalar_outputs", True) |
| 3164 | + def test_unbacked_reshape2(self): |
| 3165 | + cnt = CompileCounterWithBackend("inductor") |
| 3166 | + |
| 3167 | + # This reshape requires a clone when the input is not contiguous and we can't compute strides. |
| 3168 | + # reshape (u2, u3) -> (u0, u1) |
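| | + # e.g. below, a transposed (non-contiguous) 10x10 tensor is reshaped to (5, 20). |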
| 3169 | + def func(x, y, with_view=False): |
| 3170 | + u0, u1 = y.tolist() |
| 3171 | + torch._check_is_size(u0) |
| 3172 | + torch._check_is_size(u1) |
| 3173 | + |
| 3174 | + result1 = torch.reshape(x, (u0, u1)) |
| 3175 | + result2 = None |
| 3176 | + if with_view: |
| 3177 | + result2 = x.view((u0, u1)) * 10 |
| 3178 | + return result1 * 10, result2 |
| 3179 | + |
| 3180 | + compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) |
| 3181 | + |
| 3182 | + x = torch.randn(10, 10) |
| 3183 | + # make x not contiguous. |
| 3184 | + x = x.t_() |
| 3185 | + torch._dynamo.decorators.mark_unbacked(x, 0) |
| 3186 | + torch._dynamo.decorators.mark_unbacked(x, 1) |
| 3187 | + |
| 3188 | + log_stream, ctx = logs_to_string( |
| 3189 | + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" |
| 3190 | + ) |
| 3191 | + with ctx(): |
| 3192 | + result_eager = func(x, torch.tensor([5, 20])) |
| 3193 | + result_compiled = compiled_func(x, torch.tensor([5, 20])) |
| 3194 | + self.assertEqual(result_compiled, result_eager) |
| 3195 | + self.assertEqual(cnt.frame_count, 1) |
| 3196 | + |
| 3197 | + aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() |
| 3198 | + self.assertExpectedInline( |
| 3199 | + aot_graphs, |
| 3200 | + """\ |
| 3201 | +def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"): |
| 3202 | + ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0 |
| 3203 | + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None |
| 3204 | + ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0 |
| 3205 | + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None |
| 3206 | + select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) |
| 3207 | + _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None |
| 3208 | + ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 |
| 3209 | + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None |
| 3210 | + select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None |
| 3211 | + _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None |
| 3212 | + ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0 |
| 3213 | + _assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None |
| 3214 | + mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None |
| 3215 | + mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1 |
| 3216 | + eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None |
| 3217 | + _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None |
| 3218 | + clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None |
| 3219 | + view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None |
| 3220 | + mul_16: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None |
| 3221 | + return (mul_16,)""", # noqa: B950 |
| 3222 | + ignore_comments=True, |
| 3223 | + ignore_empty_lines=True, |
| 3224 | + ) |
| 3225 | + |
| 3226 | + result_eager = func(x, torch.tensor([2, 50])) |
| 3227 | + result_compiled = compiled_func(x, torch.tensor([2, 50])) |
| 3228 | + self.assertEqual(result_compiled, result_eager) |
| 3229 | + self.assertEqual(cnt.frame_count, 1) |
| 3230 | + |
| 3231 | + x = torch.randn(4, 4).t_() |
| 3232 | + result_eager = func(x, torch.tensor([2, 8])) |
| 3233 | + result_compiled = compiled_func(x, torch.tensor([2, 8])) |
| 3234 | + self.assertEqual(result_compiled, result_eager) |
| 3235 | + self.assertEqual(cnt.frame_count, 1) |
| 3236 | + |
| 3237 | + @unittest.skip("this test fails due to inductor/autograd issue #153041") |
| 3238 | + @torch._dynamo.config.patch("capture_scalar_outputs", True) |
| 3239 | + def test_unbacked_non_contiguous_reshape_failing(self): |
| 3240 | + # reshape u1 -> (u0*u0) |
| 3241 | + # this results in the tensor "i64[u0, u0][s7*u0, s7]". |
| 3242 | + # reshape happens in place (no clone) |
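| | + # e.g. x = [0, 2, 4, 6] (stride 2, non-contiguous) viewed as (2, 2) below. |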
| 3243 | + def func(x, y): |
| 3244 | + f = y.item() |
| 3245 | + t1 = x.view((f, f)) |
| 3246 | + t2 = x.reshape((f, f)) |
| 3247 | + return t1, t2 |
| 3248 | + |
| 3249 | + # create a non-contiguous tensor whose data is the even numbers in [0, 2*cnt) |
| 3250 | + def make_non_contiguous_tensor(cnt): |
| 3251 | + # create a non-contiguous tensor x that skips the odd indices. |
| 3252 | + x = torch.arange(cnt * 2, device="cuda") |
| 3253 | + x = x.as_strided((x.size()[0] // 2,), (2,)) |
| 3254 | + return x |
| 3255 | + |
| 3256 | + x = make_non_contiguous_tensor(4) |
| 3257 | + torch._dynamo.decorators.mark_unbacked(x, 0) |
| 3258 | + compiled_func = torch.compile( |
| 3259 | + fullgraph=True, |
| 3260 | + backend="inductor", |
| 3261 | + )(func) |
| 3262 | + |
| 3263 | + compiled_result = compiled_func(x, torch.tensor([2])) |
| 3264 | + eager_result = func(x, torch.tensor([2])) |
| 3265 | + self.assertEqual(compiled_result, eager_result) |
| 3266 | + |
3053 | 3267 |
|
3054 | 3268 | if __name__ == "__main__":
|
3055 | 3269 | run_tests()
|