[ca][dynamo] always run eager checkpoint region's recomputation in eager · pytorch/pytorch@759c9b4 · GitHub

Commit 759c9b4

[ca][dynamo] always run eager checkpoint region's recomputation in eager
ghstack-source-id: 304e31b
Pull Request resolved: #153300
1 parent de60f27 commit 759c9b4

3 files changed: +23 −13 lines changed

test/dynamo_expected_failures/TestAutograd.test_access_saved_tensor_twice_without_recomputation_works

Whitespace-only changes.

test/inductor/test_compiled_autograd.py

Lines changed: 6 additions & 13 deletions
@@ -4209,10 +4209,13 @@ def wrap_test_class(orig_cls):
         ):
             dct[name] = unittest.expectedFailure
         elif name.startswith("test_"):
+            backend = lookup_backend(name)
+            if not HAS_CUDA and backend == "inductor":
+                continue
             ctxs = [
                 compiled_autograd._enable(
                     make_compiler_fn(
-                        backend=lookup_backend(name),
+                        backend=backend,
                         fullgraph=name not in known_graph_breaks_tests,
                     )
                 ),
@@ -4305,6 +4308,8 @@ def wrap_test_class(orig_cls):
     "test_full_backward_hook_double_backward",  # _pack_with_none
     "test_grad_mode_restored_reentrant",  # assertTrue
     "test_multi_grad_any_hooks",  # register_multi_grad_hook
+    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks
+    "test_graph_save_on_cpu",  # dynamo disabled
 }
 
 test_contexts = {
@@ -4370,19 +4375,7 @@ def wrap_test_class(orig_cls):
     "test_to_sparse_backward",  # Out of bounds: frame_state_entry.stride[i] is None
     "test_custom_function_non_tensor_inputs_outputs",  # gradient batching rule not implemented for aten::sym_size.int
     "test_setitem",  # CopySlices accuracy error
-    "test_save_on_cpu_and_checkpoint",  # https://github.com/pytorch/pytorch/issues/147565
-    "test_checkpoint_detects_non_determinism",  # different error
-    "test_checkpointing_non_reentrant_autocast_cpu",  # saved != recompute
-    "test_checkpointing_non_reentrant_autocast_gpu",  # saved != recompute
     "test_checkpointing_without_reentrant_saved_object_identity",  # same as https://github.com/pytorch/pytorch/issues/136193
-    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks multiple times
-    "test_saved_variable_saved_original_inplace_detach",  # RuntimeError not raised
-    "test_access_saved_tensor_twice_without_recomputation_works",  # saved != recompute
-    "test_checkpointing_without_reentrant_dataparallel",  # https://github.com/pytorch/pytorch/issues/127115
-    "test_checkpointing",  # takes very very long
-    "test_checkpointing_without_reentrant_input_requires_grad_False",  # takes very very long
-    "test_checkpointing_without_reentrant_input_requires_grad_True",  # takes very very long
-    "test_checkpointing_without_reentrant_memory_savings",  # takes very very long
     "test_dtensor_different_gradient_placement",  # Dynamo failed to run FX node with fake tensors
     "test_dtensor_noncontiguous_output",  # Dynamo failed to run FX node with fake tensors
     "test_dtensor_partial_placement_graph_output",  # Dynamo failed to run FX node with fake tensors

torch/utils/checkpoint.py

Lines changed: 17 additions & 0 deletions
@@ -328,6 +328,7 @@ def backward(ctx, *args):
 def noop_context_fn():
     return contextlib.nullcontext(), contextlib.nullcontext()
 
+# Note: [torch.compile and checkpoint]
 # TorchDynamo does not step inside utils.checkpoint function. The flow
 # looks like this
 # 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
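
The note labeled here describes the path where the forward itself runs under torch.compile: Dynamo does not trace into utils.checkpoint and instead wraps the checkpointed region in a higher order op. A minimal sketch of that path (illustrative only, not part of this commit; gn, fn, and x are made-up names):

import torch
from torch.utils.checkpoint import checkpoint


def gn(x):
    # checkpointed region: activations are dropped in forward and
    # recomputed during backward
    return torch.sigmoid(torch.matmul(x, x))


@torch.compile
def fn(x):
    # Dynamo does not step into utils.checkpoint; it wraps the region
    # in a higher order op (see the note above)
    return checkpoint(gn, x, use_reentrant=False)


x = torch.randn(4, 4, requires_grad=True)
fn(x).sum().backward()

In that path the recomputation is captured and compiled together with the backward. The case this commit targets is the opposite one, where the forward of the checkpointed region ran in eager.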
@@ -1106,6 +1107,8 @@ def pack_hook(x):
                 frame.x_metadatas.append(frame.metadata_fn(x))
             return holder
 
+        # See Note: [compiled autograd and checkpoint unpack hook]
+        @torch._disable_dynamo
         def unpack_hook(holder):
             gid = torch._C._current_graph_task_id()
             if gid == -1:
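
The decorator added here, torch._disable_dynamo, keeps the decorated frame out of Dynamo tracing so it always executes eagerly. A small sketch of its effect (illustrative, not from this commit; eager_only and f are made-up names):

import torch


@torch._disable_dynamo
def eager_only(x):
    # this frame is skipped by Dynamo and always runs in eager
    return x * 2


@torch.compile
def f(x):
    # calling a dynamo-disabled function from compiled code runs it
    # outside the compiled graph
    return eager_only(x) + 1


print(f(torch.randn(3)))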
@@ -1541,3 +1544,17 @@ def recompute_fn(*inputs):
         )
 
     return
+
+# Note: [compiled autograd and checkpoint unpack hook]
+# When tracing via compiled autograd, this hook will be visible to the
+# compiler if the forward of this checkpointed region ran in eager.
+# If the forward had run under compile, it would have been wrapped in a
+# higher order op. See Note: [torch.compile and checkpoint].
+#
+# Since we run the recomputation hook under an enable_grad context,
+# AOTDispatch will trace a joint graph for this hook, and may
+# save different activations than in eager. This conflicts with the
+# strict activation count checks in `frame.check_recomputed_tensors_match`.
+# So, we disable this hook to force it to recompute eager checkpointed regions
+# in eager. This could be removed if we can disable the partitioner for this
+# graph segment.
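
The scenario this note describes, sketched end to end (illustrative only, not part of this commit; it uses the private compiled autograd context manager compiled_autograd._enable, as in the test file above, and region/out are made-up names):

import torch
from torch._dynamo import compiled_autograd
from torch.utils.checkpoint import checkpoint


def region(x):
    return torch.cos(torch.sin(x))


x = torch.randn(8, requires_grad=True)
# forward runs in eager, so no higher order op is involved and the
# pack/unpack hooks above are installed
out = checkpoint(region, x, use_reentrant=False)

# backward is traced by compiled autograd; because unpack_hook is wrapped
# in torch._disable_dynamo, the recomputation of `region` still runs in
# eager, so frame.check_recomputed_tensors_match sees the same activations
with compiled_autograd._enable(torch.compile):
    out.sum().backward()

If the unpack hook stayed visible to the compiler, AOTDispatch could save a different set of activations than eager recomputation does, which is exactly the mismatch the strict count check rejects.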
