|
12 | 12 | import torch
|
13 | 13 | import torch._export
|
14 | 14 | import torch._inductor
|
| 15 | +import torch._inductor.config |
15 | 16 | import torch.nn as nn
|
16 | 17 | from torch._dynamo.testing import rand_strided, same
|
17 | 18 | from torch._dynamo.utils import counters
|
@@ -1313,14 +1314,19 @@ def fn(a, b, alpha=1.0):
|
1313 | 1314 | with self.assertRaises(RuntimeError):
|
1314 | 1315 | torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0})
|
1315 | 1316 |
|
1316 |
| - so_path = torch._export.aot_compile( |
1317 |
| - torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False |
1318 |
| - ) |
1319 |
| - kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) |
1320 |
| - res = kernel_runner.run([a, b]) |
1321 |
| - self.assertTrue(isinstance(res, list)) |
1322 |
| - self.assertTrue(len(res) == 1) |
1323 |
| - self.assertEqual(fn(a, b, alpha=2.0), res[0]) |
| 1317 | + for simdlen in [0, None]: |
| 1318 | + with torch._inductor.config.patch({"cpp.simdlen": simdlen}): |
| 1319 | + so_path = torch._export.aot_compile( |
| 1320 | + torch.ops.aten.add, |
| 1321 | + args=(a, b), |
| 1322 | + kwargs={"alpha": 2.0}, |
| 1323 | + same_signature=False, |
| 1324 | + ) |
| 1325 | + kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) |
| 1326 | + res = kernel_runner.run([a, b]) |
| 1327 | + self.assertTrue(isinstance(res, list)) |
| 1328 | + self.assertTrue(len(res) == 1) |
| 1329 | + self.assertEqual(fn(a, b, alpha=2.0), res[0]) |
1324 | 1330 |
|
1325 | 1331 | def test_buffer_mutation_2(self):
|
1326 | 1332 | class Model(torch.nn.Module):
|
|
0 commit comments