[ca][dynamo] always run eager checkpoint region's recomputation in eager · pytorch/pytorch@21539bd · GitHub

Commit 21539bd

[ca][dynamo] always run eager checkpoint region's recomputation in eager
ghstack-source-id: b35d358
Pull Request resolved: #153300

1 parent b03f6b9 · commit 21539bd
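
For context, the user-visible scenario this commit targets can be sketched as below. This is an illustrative example, not code from the patch: `region` and `run_backward` are invented names, and the compiled-autograd opt-in shown (torch._dynamo.config.compiled_autograd plus a torch.compile'd step that calls backward, as in the compiled autograd tutorial) is an assumption whose exact spelling varies across PyTorch versions. The checkpointed forward runs in eager, so its backward recomputation is driven by the saved-tensor unpack hook, which this commit keeps out of the compiled autograd trace.

import torch
import torch._dynamo
from torch.utils.checkpoint import checkpoint

def region(x):
    # Checkpointed computation; its forward runs in eager below.
    return torch.sin(x).cos()

x = torch.randn(8, requires_grad=True)

# Eager forward: the region is NOT wrapped in Dynamo's checkpoint higher-order
# op, so its saved-tensor pack/unpack hooks remain visible to the autograd engine.
loss = checkpoint(region, x, use_reentrant=False).sum()

# Assumption: opt into compiled autograd via the dynamo config flag used in the
# compiled autograd tutorial (the exact opt-in API differs between versions).
torch._dynamo.config.compiled_autograd = True

@torch.compile
def run_backward(l):
    # The traced backward sees the unpack hook; with this commit, Dynamo is
    # disabled on that hook, so the recomputation of `region` still runs in eager.
    l.backward()

run_backward(loss)
print(x.grad.shape)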

File tree (2 files changed, +18 -12 lines)

test/inductor/test_compiled_autograd.py
torch/utils/checkpoint.py

test/inductor/test_compiled_autograd.py

Lines changed: 1 addition & 12 deletions
@@ -4282,6 +4282,7 @@ def wrap_test_class(orig_cls):
     "test_full_backward_hook_double_backward",  # _pack_with_none
     "test_grad_mode_restored_reentrant",  # assertTrue
     "test_multi_grad_any_hooks",  # register_multi_grad_hook
+    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks
 }
 
 test_contexts = {
@@ -4345,19 +4346,7 @@ def wrap_test_class(orig_cls):
     "test_return_duplicate",  # gradient batching rule not implemented for aten::sym_size.int
     "test_return_duplicate_inplace",  # gradient batching rule not implemented for aten::sym_size.int
     "test_setitem",  # CopySlices accuracy error
-    "test_save_on_cpu_and_checkpoint",  # https://github.com/pytorch/pytorch/issues/147565
-    "test_checkpoint_detects_non_determinism",  # different error
-    "test_checkpointing_non_reentrant_autocast_cpu",  # saved != recompute
-    "test_checkpointing_non_reentrant_autocast_gpu",  # saved != recompute
     "test_checkpointing_without_reentrant_saved_object_identity",  # same as https://github.com/pytorch/pytorch/issues/136193
-    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks multiple times
-    "test_saved_variable_saved_original_inplace_detach",  # RuntimeError not raised
-    "test_access_saved_tensor_twice_without_recomputation_works",  # saved != recompute
-    "test_checkpointing_without_reentrant_dataparallel",  # https://github.com/pytorch/pytorch/issues/127115
-    "test_checkpointing",  # takes very very long
-    "test_checkpointing_without_reentrant_input_requires_grad_False",  # takes very very long
-    "test_checkpointing_without_reentrant_input_requires_grad_True",  # takes very very long
-    "test_checkpointing_without_reentrant_memory_savings",  # takes very very long
     "test_dtensor_different_gradient_placement",  # Dynamo failed to run FX node with fake tensors
     "test_dtensor_noncontiguous_output",  # Dynamo failed to run FX node with fake tensors
     "test_dtensor_partial_placement_graph_output",  # Dynamo failed to run FX node with fake tensors

torch/utils/checkpoint.py

Lines changed: 17 additions & 0 deletions
@@ -328,6 +328,7 @@ def backward(ctx, *args):
 def noop_context_fn():
     return contextlib.nullcontext(), contextlib.nullcontext()
 
+# Note: [torch.compile and checkpoint]
 # TorchDynamo does not step inside utils.checkpoint function. The flow
 # looks like this
 # 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
@@ -1106,6 +1107,8 @@ def pack_hook(x):
             frame.x_metadatas.append(frame.metadata_fn(x))
         return holder
 
+    # See Note: [compiled autograd and checkpoint unpack hook]
+    @torch._disable_dynamo
     def unpack_hook(holder):
         gid = torch._C._current_graph_task_id()
         if gid == -1:
@@ -1541,3 +1544,17 @@ def recompute_fn(*inputs):
         )
 
     return
+
+# Note: [compiled autograd and checkpoint unpack hook]
+# When tracing via compiled autograd, this hook will be visible to the
+# compiler if the forward of this checkpointed region ran in eager.
+# If the forward had run under compile, it would have been wrapped in a
+# higher-order op. See Note: [torch.compile and checkpoint].
+#
+# Since we run the recomputation hook under an enable_grad context,
+# AOTDispatch will trace a joint graph for this hook, and may
+# save different activations than in eager. This conflicts with the
+# strict activation count checks in `frame.check_recomputed_tensors_match`.
+# So, we disable Dynamo on this hook to force eager checkpointed regions
+# to recompute in eager. This could be removed if we could disable the
+# partitioner for this graph segment.
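
To make the note's mechanism concrete, here is a minimal, hypothetical sketch of the property the patch relies on (the function names are invented): a function decorated with torch._disable_dynamo is never traced by Dynamo and always runs eagerly, even when reached from compiled code, which is how the decorated unpack_hook keeps the checkpoint recomputation in eager.

import torch

@torch._disable_dynamo
def recompute_in_eager(x):
    # Dynamo never traces into this body; it always executes eagerly,
    # mirroring how the decorated unpack_hook recomputes checkpointed
    # regions outside of the compiled autograd graph.
    return torch.sin(x) * 2

@torch.compile
def fn(x):
    # Dynamo graph-breaks around the disabled function and calls it eagerly.
    return recompute_in_eager(x) + 1

out = fn(torch.randn(4))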

0 commit comments
