[inductor] fix MA on poor gpu · pytorch/pytorch@69006be · GitHub

Commit 69006be
[inductor] fix MA on poor gpu
[ghstack-poisoned]
1 parent 0f051ea commit 69006be
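
For context, "MA" in the commit title is inductor's max-autotune mode, which benchmarks several kernel candidates per op and picks the fastest. A minimal, illustrative way to exercise that path (not part of this commit; assumes a CUDA build):

import torch

def f(a, b):
    return a.matmul(b)

a = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)

# mode="max-autotune" enables torch._inductor.config.max_autotune,
# so inductor benchmarks Triton template and extern kernel candidates.
opt_f = torch.compile(f, mode="max-autotune")
print(opt_f(a, b).shape)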

File tree

2 files changed (+49, -3 lines)


test/inductor/test_max_autotune.py

Lines changed: 44 additions & 0 deletions
@@ -9,6 +9,7 @@
 from torch._dynamo import reset
 from torch._dynamo.exc import BackendCompilerFailed
 from torch._dynamo.testing import rand_strided, reset_rng_state
+from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.autotune_process import (
     BenchmarkRequest,
@@ -979,6 +980,49 @@ def mock_lookup(self, *args, **kwargs):
             torch.compile(lambda a, b: a.matmul(b))(a, b)
         self.assertIn("NoValidChoicesError", str(context.exception))
 
+    @unittest.skipIf(
+        not torch.cuda.is_available()
+        or torch.cuda.get_device_properties().total_memory < 2e10,
+        "Only if the GPU has at least 20GB memory to be safe",
+    )
+    @config.patch(force_shape_pad=True, max_autotune=True)
+    def test_linear_and_cel(self):
+        """
+        Simulate a GPU without enough SMs. Make sure max-autotune still
+        works even when the MultiTritonTemplate encapsulates just extern
+        kernels.
+        """
+
+        def mock_is_big_gpu(*args, **kwargs):
+            return False
+
+        B, T, C, V = 32, 1024, 768, 50257
+
+        linear = nn.Linear(C, V).bfloat16().to(device=GPU_TYPE)
+        ce = torch.nn.CrossEntropyLoss()
+
+        def f(x, y):
+            x.grad = None
+            linear.weight.grad = None
+            linear.bias.grad = None
+
+            loss = ce(linear(x), y)
+            loss.backward()
+            return loss
+
+        x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16()
+        x.retain_grad()
+        y = torch.randint(0, V, (B * T,)).cuda()
+
+        import torch._inductor.utils as inductor_utils
+
+        with unittest.mock.patch.object(inductor_utils, "is_big_gpu", mock_is_big_gpu):
+            opt_f = torch.compile(f)
+
+            expect = (f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
+            actual = (opt_f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
+            assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}"
+
 
 @instantiate_parametrized_tests
 class TestMaxAutotuneRemoteCache(TestCase):
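
The new test forces the small-GPU path by patching torch._inductor.utils.is_big_gpu, the check inductor uses to decide whether Triton templates are enabled; when it returns False, max-autotune is left with only extern (ATen/cuBLAS) choices, which is the situation this commit fixes, as the test's docstring suggests. A standalone sketch of the same patching pattern outside the test harness (illustrative shapes; assumes a CUDA build):

import unittest.mock

import torch
import torch._inductor.utils as inductor_utils

# Pretend every GPU is "small" so max-autotune can only pick extern kernels.
with unittest.mock.patch.object(inductor_utils, "is_big_gpu", lambda *a, **k: False):
    opt_mm = torch.compile(lambda a, b: a @ b, mode="max-autotune")
    a = torch.randn(256, 256, device="cuda")
    b = torch.randn(256, 256, device="cuda")
    out = opt_mm(a, b)  # first call compiles while the patch is active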

torch/_inductor/ir.py

Lines changed: 5 additions & 3 deletions
@@ -1150,9 +1150,11 @@ def inner_reduction_splits(
             # No need to split.
             return ReductionHint.INNER, split
         if input_node is not None and isinstance(input_node, TensorBox):
-            new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
-                input_node
-            )
+            with patch.object(FlexibleLayout, "allow_indexing", True):
+                (
+                    new_ranges,
+                    new_reduction_ranges,
+                ) = extract_input_node_reduction_ranges(input_node)
             if new_ranges is not None and new_reduction_ranges is not None:
                 extracted_numel_hint = V.graph.sizevars.symbolic_hint(
                     sympy_product(new_ranges + new_reduction_ranges)
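
The ir.py change wraps the range extraction in patch.object(FlexibleLayout, "allow_indexing", True), which temporarily flips a class attribute and restores it when the block exits. The same unittest.mock pattern on a toy class (names below are illustrative, not inductor's):

from unittest.mock import patch

class Layout:
    allow_indexing = False  # class-level flag, normally off

def can_index():
    return Layout.allow_indexing

# patch.object sets the attribute for the duration of the with-block and
# restores the original value afterwards, even if an exception is raised.
with patch.object(Layout, "allow_indexing", True):
    assert can_index() is True
assert can_index() is False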

0 commit comments