Disabling amp context when invoking compiler (#138624) · pytorch/pytorch@a1e8334 · GitHub

Commit a1e8334

eellison authored and pytorchbot committed
Disabling amp context when invoking compiler (#138624)
Fix for #133974

Pull Request resolved: #138624
Approved by: https://github.com/bdhirsh, https://github.com/drisspg

(cherry picked from commit 5942b29)
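For context, the patch wraps AOTAutograd's forward and backward compiler invocations in torch._C._DisableAutocast() so that an autocast context active at the call site is not re-applied while the already-traced graphs are compiled. Below is a minimal sketch of that pattern; invoke_backend_without_amp, compiler_fn, gm, and example_inputs are illustrative stand-ins, not names from the actual AOTAutograd code.

import torch

def invoke_backend_without_amp(compiler_fn, gm, example_inputs):
    # Hypothetical stand-in for the aot_config.fw_compiler / bw_compiler calls
    # in the diff below: AMP casts are already baked into the traced joint
    # graph, so the backend compiler runs with autocast disabled to avoid
    # applying them a second time.
    with torch._C._DisableAutocast():
        return compiler_fn(gm, example_inputs)

In the diff below, the same effect is obtained by adding torch._C._DisableAutocast() to the existing track_graph_compiling(aot_config, "forward") and track_graph_compiling(aot_config, "backward") with-statements.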
1 parent f31b8bb commit a1e8334

File tree

2 files changed: +63 -24 lines


test/inductor/test_cpu_repro.py

Lines changed: 41 additions & 0 deletions
@@ -3941,6 +3941,47 @@ def forward(self, x):
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    @parametrize("is_inference", (True, False))
+    def test_disabled_amp(self, is_inference):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.all_head_size = 12 * 64
+                self.dense = nn.Linear(self.all_head_size, self.all_head_size)
+
+            def forward(self, q, k, v):
+                context_layer = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=None, dropout_p=0.2
+                )
+                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+                new_context_layer_shape = context_layer.size()[:-2] + (
+                    self.all_head_size,
+                )
+                context_layer = context_layer.view(new_context_layer_shape)
+                return self.dense(context_layer)
+
+        mod = M().to(torch.bfloat16).eval()
+
+        q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        inputs = (
+            q,
+            k,
+            v,
+        )
+        compiler_mode = torch.compile(mod)
+        from torch.nn.attention import sdpa_kernel, SDPBackend
+
+        context = contextlib.nullcontext if not is_inference else torch.no_grad
+        with config.patch(
+            {"fallback_random": True}
+        ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH):
+            torch.manual_seed(0)
+            eager = mod(*inputs)
+            torch.manual_seed(0)
+            self.assertEqual(compiler_mode(*inputs), eager)
+
     @requires_vectorization
     def test_vec_indirect_load_cse_cache(self):
         # https://github.com/pytorch/pytorch/issues/123502

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 22 additions & 24 deletions
@@ -555,7 +555,9 @@ def aot_dispatch_autograd(
             ),
         )
 
-    with track_graph_compiling(aot_config, "forward"):
+    # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
+    # in the compiler.
+    with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
         # flat_args at this point might still be subclasses-
         # make sure to pass the unwrapped fake tensors into the compiler!
         adjusted_flat_args = joint_inputs[0]
@@ -620,7 +622,7 @@ def aot_dispatch_autograd(
     # NB: It's important to compile backwards ahead of time, as this may
     # add extra guards which we need to apply to the Dynamo cache at
     # forwards
-    with track_graph_compiling(aot_config, "backward"):
+    with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
         placeholder_list = fx_placeholder_vals(bw_module)
 
         forward_saved_for_backwards_strides = None
@@ -672,28 +674,24 @@
 
         compiled_bw_func = None
         if num_symints_saved_for_bw > 0:
-            context = torch._C._DisableAutocast if disable_amp else nullcontext
-            with context():
-                try:
-                    compiled_bw_func = aot_config.bw_compiler(
-                        bw_module, placeholder_list
-                    )
-                except Exception as e:
-                    exc = e
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "eager_compile_backwards_failure",
-                            "encoding": "string",
-                        },
-                        payload_fn=lambda: "\n".join(
-                            traceback.format_exception(exc)
-                        ),
-                    )
-                    log.warning(
-                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
-                        exc_info=True,
-                    )
+            try:
+                compiled_bw_func = aot_config.bw_compiler(
+                    bw_module, placeholder_list
+                )
+            except Exception as e:
+                exc = e
+                trace_structured(
+                    "artifact",
+                    metadata_fn=lambda: {
+                        "name": "eager_compile_backwards_failure",
+                        "encoding": "string",
+                    },
+                    payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
+                )
+                log.warning(
+                    "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
+                    exc_info=True,
+                )
         # Compiled autograd will run the bw_module in the backward pass,
         # so recompilation need happen anyway if the backward pass is ever
         # called.
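As a quick standalone illustration (not part of the commit) of what the private torch._C._DisableAutocast() guard used above does: under an active CPU autocast region a matmul is normally rewritten to bfloat16, but inside the guard it stays in float32, which is why an ambient user-level autocast can no longer leak into the forward/backward compiler calls.

import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast("cpu", dtype=torch.bfloat16):
    print((a @ b).dtype)      # torch.bfloat16: autocast rewrites the matmul
    with torch._C._DisableAutocast():
        print((a @ b).dtype)  # torch.float32: autocast is suppressed here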

0 commit comments
