8000 [ca][dynamo] always run eager checkpoint region's recomputation in eager · pytorch/pytorch@8ffc21a · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ffc21a

Browse files
committed
[ca][dynamo] always run eager checkpoint region's recomputation in eager
ghstack-source-id: 95cbcc3 Pull Request resolved: #153300
1 parent a7c7106 commit 8ffc21a

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

test/dynamo_expected_failures/TestAutograd.test_access_saved_tensor_twice_without_recomputation_works

Whitespace-only changes.

test/dynamo_expected_failures/TestAutograd.test_checkpoint_detects_non_determinism

Whitespace-only changes.

test/inductor/test_compiled_autograd.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4209,10 +4209,13 @@ def wrap_test_class(orig_cls):
42094209
):
42104210
dct[name] = unittest.expectedFailure
42114211
elif name.startswith("test_"):
4212+
backend = lookup_backend(name)
4213+
if not HAS_CUDA and backend == "inductor":
4214+
continue
42124215
ctxs = [
42134216
compiled_autograd._enable(
42144217
make_compiler_fn(
4215-
backend=lookup_backend(name),
4218+
backend=backend,
42164219
fullgraph=name not in known_graph_breaks_tests,
42174220
)
42184221
),
@@ -4305,6 +4308,8 @@ def wrap_test_class(orig_cls):
43054308
"test_full_backward_hook_double_backward", # _pack_with_none
43064309
"test_grad_mode_restored_reentrant", # assertTrue
43074310
"test_multi_grad_any_hooks", # register_multi_grad_hook
4311+
"test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks
4312+
"test_graph_save_on_cpu", # dynamo disabled
43084313
}
43094314

43104315
test_contexts = {
@@ -4370,19 +4375,7 @@ def wrap_test_class(orig_cls):
43704375
"test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None
43714376
"test_custom_function_non_tensor_inputs_outputs", # gradient batching rule not implemented for aten::sym_size.int
43724377
"test_setitem", # CopySlices accuracy error
4373-
"test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565
4374-
"test_checkpoint_detects_non_determinism", # different error
4375-
"test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute
4376-
"test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute
43774378
"test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193
4378-
"test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times
4379-
"test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised
4380-
"test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute
4381-
"test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115
4382-
"test_checkpointing", # takes very very long
4383-
"test_checkpointing_without_reentrant_input_requires_grad_False", # takes very very long
4384-
"test_checkpointing_without_reentrant_input_requires_grad_True", # takes very very long
4385-
"test_checkpointing_without_reentrant_memory_savings", # takes very very long
43864379
"test_dtensor_different_gradient_placement", # Dynamo failed to run FX node with fake tensors
43874380
"test_dtensor_noncontiguous_output", # Dynamo failed to run FX node with fake tensors
43884381
"test_dtensor_partial_placement_graph_output", # Dynamo failed to run FX node with fake tensors

torch/utils/checkpoint.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def backward(ctx, *args):
328328
def noop_context_fn():
329329
return contextlib.nullcontext(), contextlib.nullcontext()
330330

331+
# Note: [torch.compile and checkpoint]
331332
# TorchDynamo does not step inside utils.checkpoint function. The flow
332333
# looks likes this
333334
# 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
@@ -1491,6 +1492,8 @@ def _checkpoint_without_reentrant_generator(
14911492
had_device_in_fwd = True
14921493
fwd_devices, fwd_device_states = get_device_states(*args)
14931494

1495+
# See Note: [compiled autograd and checkpoint unpack hook]
1496+
@torch._disable_dynamo
14941497
def recompute_fn(*inputs):
14951498
kwargs, *args = inputs
14961499
# This will be called later during recomputation. This wrapping enables
@@ -1541,3 +1544,17 @@ def recompute_fn(*inputs):
15411544
)
15421545

15431546
return
1547+
1548+
# Note: [compiled autograd and checkpoint unpack hook]
1549+
# When tracing via compiled autograd, this hook will be visible to the
1550+
# compiler if the forward of this checkpointed region ran in eager.
1551+
# If the forward had ran under compile, it would have been wrapped in a
1552+
# higher order op. See Note: [torch.compile and checkpoint].
1553+
#
1554+
# Since we run the recomputation hook under a enable_grad context,
1555+
# AOTDispatch will trace a joint graph for this hook, and may
1556+
# save different activations than in eager. This conflicts with the
1557+
# strict activation count checks in `frame.check_recomputed_tensors_match`.
1558+
# So, we disable this hook to force it to recompute eager checkpointed regions
1559+
# in eager. This could be removed if we can disable the partitioner for this
1560+
# graph segment.

0 commit comments

Comments
 (0)
0