@@ -4306,7 +4306,7 @@ class M_v0(torch.nn.Module):
4306
4306
def forward(self, t):
4307
4307
items = [t[i].item() for i in range(t.numel())]
4308
4308
r = torch.randn([items[0], items[1]])
4309
- # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0 )
4309
+ # Could not guard on data-dependent expression Eq(u2, -1 )
4310
4310
return r.view(items[0], items[2])
4311
4311
4312
4312
M = M_v0
@@ -4315,23 +4315,69 @@ def forward(self, t):
4315
4315
"The following call raised this error(.*\n)+"
4316
4316
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
4317
4317
"To fix the error, insert one of the following checks before this call.*:\n"
4318
- f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
4319
- f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
4320
- f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
4321
- f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
4318
+ f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
4319
+ f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
4320
+ f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
4322
4321
):
4323
4322
export(N(), (t,), strict=strict)
4324
4323
4325
4324
class M_v1(torch.nn.Module):
4326
4325
def forward(self, t):
4327
4326
items = [t[i].item() for i in range(t.numel())]
4328
4327
r = torch.randn([items[0], items[1]])
4329
- # TODO(pianpwk): this isn't the suggested fixes.
4330
- # fix issue with % being interpreted as PythonMod instead of Mod
4331
- torch._check(items[1] == items[2])
4328
+ # Could not guard on data-dependent expression Eq(u2, -1)
4329
+ torch._check(items[2] != -1)
4330
+ # Could not guard on data-dependent expression u2 >= 0
4332
4331
return r.view(items[0], items[2])
4333
4332
4334
4333
M = M_v1
4334
+ with self.assertRaisesRegex(
4335
+ error_type,
4336
+ "The following call raised this error(.*\n)+"
4337
+ f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
4338
+ "To fix the error, insert one of the following checks before this call.*:\n"
4339
+ f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
4340
+ f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
4341
+ f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
4342
+ ):
4343
+ export(N(), (t,), strict=strict)
4344
+
4345
+ class M_v2(torch.nn.Module):
4346
+ def forward(self, t):
4347
+ items = [t[i].item() for i in range(t.numel())]
4348
+ r = torch.randn([items[0], items[1]])
4349
+ # Could not guard on data-dependent expression Eq(u2, -1)
4350
+ torch._check(items[2] != -1)
4351
+ # Could not guard on data-dependent expression u2 >= 0
4352
+ torch._check(items[2] >= 0)
4353
+ # Could not guard on data-dependent expression Eq(u1, u2)
4354
+ return r.view(items[0], items[2])
4355
+
4356
+ M = M_v2
4357
+ with self.assertRaisesRegex(
4358
+ error_type,
4359
+ "The following call raised this error(.*\n)+"
4360
+ f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
4361
+ "To fix the error, insert one of the following checks before this call.*:\n"
4362
+ f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
4363
+ f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
4364
+ f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
4365
+ ):
4366
+ export(N(), (t,), strict=strict)
4367
+
4368
+ class M_v3(torch.nn.Module):
4369
+ def forward(self, t):
4370
+ items = [t[i].item() for i in range(t.numel())]
4371
+ r = torch.randn([items[0], items[1]])
4372
+ # Could not guard on data-dependent expression Eq(u2, -1)
4373
+ torch._check(items[2] != -1)
4374
+ # Could not guard on data-dependent expression u2 >= 0
4375
+ torch._check(items[2] >= 0)
4376
+ # Could not guard on data-dependent expression Eq(u1, u2)
4377
+ torch._check(items[2] == r.shape[1])
4378
+ return r.view(items[0], items[2])
4379
+
4380
+ M = M_v3
4335
4381
export(N(), (t,), strict=strict)
4336
4382
4337
4383
def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4443,29 +4489,6 @@ def forward(self, x, offsets_t, fixes):
4443
4489
fixes=[], # nothing to fix!
4444
4490
)
4445
4491
4446
- def test_simple_unbacked_view(self):
4447
- class Foo(torch.nn.Module):
4448
- def forward(self, x):
4449
- u0 = x.item()
4450
- y = torch.empty(5, u0)
4451
- return y.view(u0, 5) # [5, u0] -> [u0, 5]
4452
-
4453
- ep = export(Foo(), (torch.tensor([9]),))
4454
- self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
4455
- self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
4456
-
4457
- class Foov2(torch.nn.Module):
4458
- def forward(self, xs):
4459
- xsl = xs.tolist()
4460
- a, b = xsl
4461
- x = torch.zeros(a)
4462
- return x.reshape(b)
4463
-
4464
- xs = torch.tensor([4, 4])
4465
- ep = export(Foov2(), (xs,))
4466
- self.assertEqual(ep.module()(xs).size(0), 4)
4467
- self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
4468
-
4469
4492
def test_no_suggested_fixes_for_data_dependent_errors(self):
4470
4493
# suggested fixes for data-dependent errors only work in non-strict mode
4471
4494
strict = False
@@ -7388,19 +7411,22 @@ def forward(self, xs, y):
7388
7411
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
7389
7412
)
7390
7413
7391
- def test_no_check_is_size_error (self):
7414
+ def test_check_is_size_error (self):
7392
7415
class Module(torch.nn.Module):
7393
7416
def forward(self, x):
7394
7417
a = x.item()
7418
+ # We cannot automatically infer a is a size here because view
7419
+ # accepts -1
7395
7420
return torch.randn(24).view(a, 4)
7396
7421
7397
7422
f = Module()
7398
- ep = export(f, (torch.tensor(6),))
7399
- ep.module()(torch.tensor(6))
7400
- with self.assertRaisesRegex(
7401
- RuntimeError, r"Runtime assertion failed for .* u.* 6"
7402
- ):
7403
- ep.module()(torch.tensor(5))
7423
+ if is_non_strict_test(self._testMethodName):
7424
+ error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
7425
+ else:
7426
+ error = torch._dynamo.exc.UserError
7427
+ error_msg = r"Could not guard on data-dependent expression"
7428
+ with self.assertRaisesRegex(error, error_msg):
7429
+ _ = export(f, (torch.tensor(6),))
7404
7430
7405
7431
def test_is_non_negative_check_function(self):
7406
7432
import sympy as sp
@@ -13244,7 +13270,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
13244
13270
node.target == torch.ops.aten._assert_scalar.default
13245
13271
for node in ep.graph.nodes
13246
13272
].count(True)
13247
- self.assertEqual(num_asserts, 2 )
13273
+ self.assertEqual(num_asserts, 1 )
13248
13274
with self.assertRaises(RuntimeError):
13249
13275
ep.module()(torch.randn(4, 2))
13250
13276
0 commit comments