@@ -4317,7 +4317,7 @@ class M_v0(torch.nn.Module):
4317
4317
def forward(self, t):
4318
4318
items = [t[i].item() for i in range(t.numel())]
4319
4319
r = torch.randn([items[0], items[1]])
4320
- # Could not guard on data-dependent expression Eq(u2, -1)
4320
+ # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
4321
4321
return r.view(items[0], items[2])
4322
4322
4323
4323
M = M_v0
@@ -4326,69 +4326,23 @@ def forward(self, t):
4326
4326
"The following call raised this error(.*\n)+"
4327
4327
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
4328
4328
"To fix the error, insert one of the following checks before this call.*:\n"
4329
- f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
4330
- f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
4331
- f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
4329
+ f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
4330
+ f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
4331
+ f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
4332
+ f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
4332
4333
):
4333
4334
export(N(), (t,), strict=strict)
4334
4335
4335
4336
class M_v1(torch.nn.Module):
4336
4337
def forward(self, t):
4337
4338
items = [t[i].item() for i in range(t.numel())]
4338
4339
r = torch.randn([items[0], items[1]])
4339
- # Could not guard on data-dependent expression Eq(u2, -1)
4340
- torch._check(items[2] != -1)
4341
- # Could not guard on data-dependent expression u2 >= 0
4340
+ # TODO(pianpwk): this isn't the suggested fixes.
4341
+ # fix issue with % being interpreted as PythonMod instead of Mod
4342
+ torch._check(items[1] == items[2])
4342
4343
return r.view(items[0], items[2])
4343
4344
4344
4345
M = M_v1
4345
- with self.assertRaisesRegex(
4346
- error_type,
4347
- "The following call raised this error(.*\n)+"
4348
- f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
4349
- "To fix the error, insert one of the following checks before this call.*:\n"
4350
- f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
4351
- f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
4352
- f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
4353
- ):
4354
- export(N(), (t,), strict=strict)
4355
-
4356
- class M_v2(torch.nn.Module):
4357
- def forward(self, t):
4358
- items = [t[i].item() for i in range(t.numel())]
4359
- r = torch.randn([items[0], items[1]])
4360
- # Could not guard on data-dependent expression Eq(u2, -1)
4361
- torch._check(items[2] != -1)
4362
- # Could not guard on data-dependent expression u2 >= 0
4363
- torch._check(items[2] >= 0)
4364
- # Could not guard on data-dependent expression Eq(u1, u2)
4365
- return r.view(items[0], items[2])
4366
-
4367
- M = M_v2
4368
- with self.assertRaisesRegex(
4369
- error_type,
4370
- "The following call raised this error(.*\n)+"
4371
- f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
4372
- "To fix the error, insert one of the following checks before this call.*:\n"
4373
- f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
4374
- f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
4375
- f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
4376
- ):
4377
- export(N(), (t,), strict=strict)
4378
-
4379
- class M_v3(torch.nn.Module):
4380
- def forward(self, t):
4381
- items = [t[i].item() for i in range(t.numel())]
4382
- r = torch.randn([items[0], items[1]])
4383
- # Could not guard on data-dependent expression Eq(u2, -1)
4384
- torch._check(items[2] != -1)
4385
- # Could not guard on data-dependent expression u2 >= 0
4386
- torch._check(items[2] >= 0)
4387
- # Could not guard on data-dependent expression Eq(u1, u2)
4388
- torch._check(items[2] == r.shape[1])
4389
- return r.view(items[0], items[2])
4390
-
4391
- M = M_v3
4392
4346
export(N(), (t,), strict=strict)
4393
4347
4394
4348
def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4500,6 +4454,29 @@ def forward(self, x, offsets_t, fixes):
4500
4454
fixes=[], # nothing to fix!
4501
4455
)
4502
4456
4457
+ def test_simple_unbacked_view(self):
4458
+ class Foo(torch.nn.Module):
4459
+ def forward(self, x):
4460
+ u0 = x.item()
4461
+ y = torch.empty(5, u0)
4462
+ return y.view(u0, 5) # [5, u0] -> [u0, 5]
4463
+
4464
+ ep = export(Foo(), (torch.tensor([9]),))
4465
+ self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
4466
+ self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
4467
+
4468
+ class Foov2(torch.nn.Module):
4469
+ def forward(self, xs):
4470
+ xsl = xs.tolist()
4471
+ a, b = xsl
4472
+ x = torch.zeros(a)
4473
+ return x.reshape(b)
4474
+
4475
+ xs = torch.tensor([4, 4])
4476
+ ep = export(Foov2(), (xs,))
4477
+ self.assertEqual(ep.module()(xs).size(0), 4)
4478
+ self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
4479
+
4503
4480
def test_no_suggested_fixes_for_data_dependent_errors(self):
4504
4481
# suggested fixes for data-dependent errors only work in non-strict mode
4505
4482
strict = False
@@ -7422,22 +7399,19 @@ def forward(self, xs, y):
7422
7399
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
7423
7400
)
7424
7401
7425
- def test_check_is_size_error(self):
7402
+ def test_no_check_is_size_error(self):
7426
7403
class Module(torch.nn.Module):
7427
7404
def forward(self, x):
7428
7405
a = x.item()
7429
- # We cannot automatically infer a is a size here because view
7430
- # accepts -1
7431
7406
return torch.randn(24).view(a, 4)
7432
7407
7433
7408
f = Module()
7434
- if is_non_strict_test(self._testMethodName):
7435
- error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
7436
- else:
7437
- error = torch._dynamo.exc.UserError
7438
- error_msg = r"Could not guard on data-dependent expression"
7439
- with self.assertRaisesRegex(error, error_msg):
7440
- _ = export(f, (torch.tensor(6),))
7409
+ ep = export(f, (torch.tensor(6),))
7410
+ ep.module()(torch.tensor(6))
7411
+ with self.assertRaisesRegex(
7412
+ RuntimeError, r"Runtime assertion failed for .* u.* 6"
7413
+ ):
7414
+ ep.module()(torch.tensor(5))
7441
7415
7442
7416
def test_is_non_negative_check_function(self):
7443
7417
import sympy as sp
@@ -13281,7 +13255,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
13281
13255
node.target == torch.ops.aten._assert_scalar.default
13282
13256
for node in ep.graph.nodes
13283
13257
].count(True)
13284
- self.assertEqual(num_asserts, 1)
13258
+ self.assertEqual(num_asserts, 2)
13285
13259
with self.assertRaises(RuntimeError):
13286
13260
ep.module()(torch.randn(4, 2))
13287
13261
0 commit comments