9 | 9 | from torch._dynamo import reset
10 | 10 | from torch._dynamo.exc import BackendCompilerFailed
11 | 11 | from torch._dynamo.testing import rand_strided, reset_rng_state
| 12 | +from torch._dynamo.utils import same
12 | 13 | from torch._inductor import config
13 | 14 | from torch._inductor.autotune_process import (
14 | 15 | BenchmarkRequest,
@@ -979,6 +980,49 @@ def mock_lookup(self, *args, **kwargs):
979 | 980 | torch.compile(lambda a, b: a.matmul(b))(a, b)
980 | 981 | self.assertIn("NoValidChoicesError", str(context.exception))
981 | 982 |
| 983 | +    @unittest.skipIf(
| 984 | +        not torch.cuda.is_available()
| 985 | +        or torch.cuda.get_device_properties().total_memory < 2e10,
| 986 | +        "Requires a GPU with at least 20GB of memory to be safe",
| 987 | +    )
| 988 | +    @config.patch(force_shape_pad=True, max_autotune=True)
| 989 | +    def test_linear_and_cel(self):
| 990 | +        """
| 991 | +        Simulate a GPU without enough SMs. Make sure max-autotune still
| 992 | +        works even when the MultiTritonTemplate encapsulates just extern
| 993 | +        kernels.
| 994 | +        """
| 995 | +
| 996 | +        def mock_is_big_gpu(*args, **kwargs):
| 997 | +            return False
| 998 | +
| 999 | +        B, T, C, V = 32, 1024, 768, 50257
| 1000 | +
| 1001 | +        linear = nn.Linear(C, V).bfloat16().to(device=GPU_TYPE)
| 1002 | +        ce = torch.nn.CrossEntropyLoss()
| 1003 | +
| 1004 | +        def f(x, y):
| 1005 | +            x.grad = None
| 1006 | +            linear.weight.grad = None
| 1007 | +            linear.bias.grad = None
| 1008 | +
| 1009 | +            loss = ce(linear(x), y)
| 1010 | +            loss.backward()
| 1011 | +            return loss
| 1012 | +
| 1013 | +        x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16()
| 1014 | +        x.retain_grad()
| 1015 | +        y = torch.randint(0, V, (B * T,)).cuda()
| 1016 | +
| 1017 | +        import torch._inductor.utils as inductor_utils
| 1018 | +
| 1019 | +        with unittest.mock.patch.object(inductor_utils, "is_big_gpu", mock_is_big_gpu):
| 1020 | +            opt_f = torch.compile(f)
| 1021 | +
| 1022 | +            expect = (f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
| 1023 | +            actual = (opt_f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
| 1024 | +            assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}"
| 1025 | +
982 | 1026 |
983 | 1027 | @instantiate_parametrized_tests
984 | 1028 | class TestMaxAutotuneRemoteCache(TestCase):
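
Below is a minimal standalone sketch of the pattern the new test exercises: patching `torch._inductor.utils.is_big_gpu` to return False so that, under max-autotune, the Triton GEMM templates are ruled out and only extern kernels remain as choices. The shapes, the helper name `check_small_gpu_autotune`, and the simplified forward-only check are illustrative assumptions, not part of the patch.

```python
# Illustrative sketch (assumes a CUDA device); mirrors the "small GPU" mocking
# pattern from test_linear_and_cel above, with smaller shapes and no backward pass.
import unittest.mock

import torch
from torch import nn
from torch._dynamo.utils import same
from torch._inductor import config
import torch._inductor.utils as inductor_utils


def check_small_gpu_autotune():
    linear = nn.Linear(64, 128).cuda().bfloat16()
    ce = nn.CrossEntropyLoss()

    def f(x, y):
        return ce(linear(x), y)

    x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
    y = torch.randint(0, 128, (32,), device="cuda")

    # Pretend the GPU is "small": max-autotune then cannot pick Triton GEMM
    # templates, so the multi-template choice contains only extern kernels.
    with unittest.mock.patch.object(
        inductor_utils, "is_big_gpu", lambda *args, **kwargs: False
    ), config.patch(force_shape_pad=True, max_autotune=True):
        opt_f = torch.compile(f)
        assert same(f(x, y), opt_f(x, y), tol=1e-2)
```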