[scan] Autograd with partial gradient support by bohnstingl · Pull Request #146285 · pytorch/pytorch · GitHub

[scan] Autograd with partial gradient support #146285

Closed
wants to merge 54 commits into from
Changes from 1 commit
Commits
54 commits
f6586dd
WIP: Integration of scan autograd
bohnstingl Feb 1, 2025
bda8cb9
Introduced autograd for scan with partial gradients support
bohnstingl Feb 3, 2025
34be923
Fixed CI test issue with graph
bohnstingl Feb 6, 2025
5eadb1e
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Feb 7, 2025
adfa593
Integrated code review comments
bohnstingl Feb 18, 2025
7b41bde
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Feb 19, 2025
c2465cf
Fixed type annotation
bohnstingl Feb 20, 2025
5bc1d27
Removed reverse flag from backend and implemented reverse with torch.…
bohnstingl Feb 25, 2025
2786a68
Fix of graph in testcase
bohnstingl Feb 25, 2025
7f48b4f
Merge branch 'scan_flip_reverse' of github.com:bohnstingl/pytorch int…
bohnstingl Feb 26, 2025
6717f9b
Integrated new reverse handling into scan autograd
bohnstingl Feb 26, 2025
a60d2a8
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 7, 2025
0343056
Fixed issues with testcases and with combine_fn now returning pytrees
bohnstingl Mar 7, 2025
0dda954
[cond] don't trace fw and bw graph in autograd key
ydwu4 Mar 10, 2025
c11cac4
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 17, 2025
ebb5b57
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 18, 2025
46608f8
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 18, 2025
26e4a53
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 19, 2025
813c883
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 19, 2025
9e4f3c3
Merge branch 'gh/ydwu4/222/head' of github.com:pytorch/pytorch into s…
bohnstingl Mar 19, 2025
114c982
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 20, 2025
fdbfe11
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 20, 2025
7796ea9
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 20, 2025
6082a60
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 20, 2025
d33705f
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 20, 2025
7cee41c
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 20, 2025
23170b1
WIP: scan autograd
bohnstingl Mar 21, 2025
0ac5a15
Merge branch 'gh/ydwu4/222/head' of github.com:pytorch/pytorch into s…
bohnstingl Mar 21, 2025
93f5bab
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 21, 2025
ad92afa
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 21, 2025
5947348
Merge branch 'gh/ydwu4/222/head' of github.com:pytorch/pytorch into s…
bohnstingl Mar 22, 2025
60fc08f
Working function of scan autograd with new infrastructure
bohnstingl Mar 24, 2025
75fe416
Removed unnecessary code and other cleanups
bohnstingl Mar 24, 2025
26c94e9
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 24, 2025
357dde5
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 26, 2025
b018397
Fixed tracing issue causing missing fake meta values
bohnstingl Mar 27, 2025
8a288d3
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 27, 2025
63414f9
Fixed CI test
bohnstingl Mar 28, 2025
66f0b37
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 28, 2025
b50e7b9
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 31, 2025
0db488f
Fixed testcase
bohnstingl Mar 31, 2025
9df0724
Integrated review feedback; before code cleanup
bohnstingl Apr 1, 2025
ffacadd
After cleanup and lintrunner
bohnstingl Apr 2, 2025
4ce5f06
Fixes
bohnstingl Apr 2, 2025
88084f5
Further cleanup and restructuring
bohnstingl Apr 2, 2025
473db76
WIP: integration of code review
bohnstingl Apr 3, 2025
e6ca8b0
Added extensive documentation and lintrunner corrections
bohnstingl Apr 3, 2025
d27d50a
Revert change to cond.py
bohnstingl Apr 3, 2025
0ee4d01
Removed unrelated testcases and fixed lint issues
bohnstingl Apr 3, 2025
66c0ae1
Integrated review changes for documentation and test cases
bohnstingl Apr 4, 2025
c976b5e
Added some more comments
bohnstingl Apr 8, 2025
afa5d96
Introduced materialize_as_graph for combine_fn and for combine_bw_fn
bohnstingl Apr 9, 2025
6d67d3e
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Apr 10, 2025
ebe8c22
Added bypass for autograd DispatchKey if no gradient for either the i…
bohnstingl Apr 10, 2025

Added bypass for autograd DispatchKey if no gradient for either the init, the xs or the additional_inputs is required
bohnstingl committed Apr 10, 2025
commit ebe8c227439d8b5f69efffa8b853cca235bc12f6
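
For orientation before reading the diff: the scan higher-order op folds ``combine_fn`` over the leading dimension of ``xs``, threading a carry and stacking per-step outputs, using the flat calling convention (carries first, then per-step outputs) that the code below relies on. The eager reference loop that follows is an illustrative sketch only; ``scan_reference`` and ``add_combine`` are hypothetical names, not code from this PR.

import torch

def scan_reference(combine_fn, init, xs, additional_inputs=()):
    # combine_fn takes (*carry, *x_t, *additional_inputs) and returns a flat
    # sequence (*next_carry, *y_t), mirroring _extract_carry_and_out in scan.py.
    carry = list(init)
    ys_per_step = []
    for t in range(xs[0].shape[0]):
        x_t = [x[t] for x in xs]
        out = list(combine_fn(*carry, *x_t, *additional_inputs))
        carry, y_t = out[: len(init)], out[len(init):]
        ys_per_step.append(y_t)
    stacked_ys = [torch.stack([step[i] for step in ys_per_step])
                  for i in range(len(ys_per_step[0]))]
    return [*carry, *stacked_ys]

def add_combine(c, x, w):
    new_c = c + w * x
    return new_c, new_c  # (next carry, per-step output)

init = [torch.zeros(3)]
xs = [torch.randn(5, 3)]
final_carry, ys = scan_reference(add_combine, init, xs, (torch.ones(3),))
print(final_carry.shape, ys.shape)  # torch.Size([3]) torch.Size([5, 3])
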
6 changes: 2 additions & 4 deletions test/functorch/test_control_flow.py
@@ -3277,17 +3277,15 @@ def f(fct, init, xs):
def forward(self, fct_1, init_1, xs_1):
permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
flip = torch.ops.aten.flip.default(permute, [0]); permute = None
- select_copy = torch.ops.aten.select_copy.int(flip, 0, 0); select_copy = None
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None
scan_combine_graph_0 = self.scan_combine_graph_0
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], (sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4)); scan_combine_graph_0 = init_1 = flip = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
getitem = scan[0]
- getitem_1 = scan[1]; getitem_1 = None
- getitem_2 = scan[2]; scan = None
- flip_1 = torch.ops.aten.flip.default(getitem_2, [0]); getitem_2 = None
+ getitem_1 = scan[1]; scan = None
+ flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None
return (getitem, flip_1)""", # noqa: B950
)

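Aside on the traced graph above: the flip of ``xs`` before the scan and the flip of the stacked output afterwards are how ``reverse`` is handled now that the backend flag is gone (see the "Removed reverse flag from backend and implemented reverse with torch.…" commit). A self-contained sketch of that identity for a cumulative-sum combine; the variable names are illustrative:

import torch

x = torch.randn(5, 3)

# Reverse (right-to-left) cumulative sum expressed as flip -> forward scan -> flip.
rev_cumsum = torch.flip(torch.cumsum(torch.flip(x, [0]), dim=0), [0])

# The same quantity computed directly, for comparison.
expected = torch.stack([x[t:].sum(dim=0) for t in range(x.shape[0])])
assert torch.allclose(rev_cumsum, expected)
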
142 changes: 79 additions & 63 deletions torch/_higher_order_ops/scan.py
@@ -567,21 +567,11 @@ def forward(
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

- # TODO: we need to materialize the combine_fn because dynamo is unable to
- # trace through the function when torch.compile torch.autograd.grad.
- combine_fn_gm = materialize_as_graph(
- combine_fn,
- (*init, *first_slice_copy_with_grad(xs), *additional_inputs),
- ctx._fw_include_key_set,
- ctx._fw_exclude_key_set,
- force_enable_grad=True,
- )
-
# 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
# The wrapper of the forward graph returns carries from all iterations,
# not just from the last iteration. These are required in the backward path
def combine_fn_with_carry_checkpoint(*args):
- carry, y = _extract_carry_and_out(combine_fn_gm(*args), num_leaves_init)
+ carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
return [
*carry,
# We additionally checkpoint all the intemediate carry outputs for backward.
@@ -669,56 +659,6 @@ def initialize_g_additional_inputs(
)
ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)

- def construct_args_single_step_bw():
- # This function constructs the arguments for a single step of the backward scan.
- # In other words, it creates the arguments for ``ctx._combine_fn_bw``.
- # The ``ctx._combine_fn_bw`` expects primals followed by the tangents, thus
-
- # The first arguments are primals, i.e., the forward part of the bw_fn graph
- # The first argument relates to the init for the forward.
- # I.e., fw_init
-
- # The second argument relates to the xs for the forward.
- # Because the arguments are for a single step only,
- # only the first slice of the xs is used.
- # Note: It is important to preserve the requires_grad flag of xs
- # and thus we use the wrapper function ``first_slice_copy_with_grad``
- fw_xs_slice = first_slice_copy_with_grad(fw_xs)
-
- # The third argument relates to the additional inputs for the forward.
- # I.e., additional_inputs
-
- # The subsequent arguments are the tangents, i.e., the gradients of the bw_fn
- # The fourth argument relates to the gradients of the carries.
- # Because the arguments are for a single step only,
- # only the first slice of the carries is used.
- sliced_carries = [first_slice_copy(c) for c in fw_carries]
-
- # The last argument relates to the gradients of the ys.
- # Because the arguments are for a single step only,
- # only the first slice of the ys is used.
- sliced_ys = [first_slice_copy(o) for o in fw_ys]
-
- return (
- *fw_init,
- *fw_xs_slice,
- *additional_inputs,
- *sliced_carries,
- *sliced_ys,
- )
-
- args_single_step_bw = construct_args_single_step_bw()
-
- # TODO: we need to materialize the bw graphs because dynamo is unable to
- # trace through the joint function when torch.compile torch.autograd.grad.
- ctx._combine_fn_bw_gm = materialize_as_graph(
- ctx._combine_fn_bw,
- args_single_step_bw,
- ctx._fw_include_key_set,
- ctx._fw_exclude_key_set,
- force_enable_grad=True,
- )
-
# 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
def combine_fn_bw_grad_accumulation(*args):
# Dissect args and re-order them for the ``ctx._combine_fn_bw``
@@ -739,7 +679,7 @@ def combine_fn_bw_grad_accumulation(*args):
combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)

g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
- ctx._combine_fn_bw_gm(*combine_fn_bw_args),
+ ctx._combine_fn_bw(*combine_fn_bw_args),
[num_leaves_init, num_leaves_xs, num_additional_inputs],
)

@@ -757,6 +697,69 @@ def combine_fn_bw_grad_accumulation(*args):
# The ``g_xs_t`` is encoded as the output of the backward scan operator
return [*new_g_additional_inputs, *g_c_t, *g_xs_t]

+ # Materialize the ``combine_fn_bw_grad_accumulation``
+ def construct_args_single_step_bw():
+ # This function constructs the arguments for a single step of the backward scan.
+ # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
+ # The order of the arguments returned is identical to the order the backward scan
+ # operations provides
+
+ # The following arguments are used for the backward part of the joint graph
+ # The first argument relates to the gradient accumulation of the additional inputs.
+ # Because only tensor elements of additional inputs can have requires_grad=True,
+ # the values for non-tensor elements of additional inputs are None
+ masked_additional_inputs = [
+ a.clone() if add_inp_tm else None
+ for add_inp_tm, a in zip(
+ additional_inputs_tensor_mask, additional_inputs
+ )
+ ]
+
+ # The second argument relates to the gradients of the carries.
+ # Because the arguments are for a single step only,
+ # only the first slice of the carries is used.
+ sliced_carries = [first_slice_copy(c) for c in fw_carries]
+
+ # The third argument relates to the gradients of the ys.
+ # Because the arguments are for a single step only,
+ # only the first slice of the ys is used.
+ sliced_ys = [first_slice_copy(o) for o in fw_ys]
+
+ # The following arguments are used for the forward part of the joint graph
+ # The fourth argument relates to the init for the forward.
+ # I.e., fw_init
+
+ # The fifth argument relates to the xs for the forward.
+ # Because the arguments are for a single step only,
+ # only the first slice of the xs is used.
+ # Note: It is important to preserve the requires_grad flag of xs
+ # and thus we use the wrapper function ``first_slice_copy_with_grad``
+ fw_xs_slice = first_slice_copy_with_grad(fw_xs)
+
+ # The last argument relates to the additional inputs for the forward.
+ # I.e., additional_inputs
+
+ return (
+ *masked_additional_inputs,
+ *sliced_carries,
+ *sliced_ys,
+ *fw_init,
+ *fw_xs_slice,
+ *additional_inputs,
+ )
+
+ args_single_step_bw = construct_args_single_step_bw()
+
+ # TODO: we need to materialize the bw graphs because dynamo is unable to
+ # trace through the joint function when torch.compile torch.autograd.grad.
+ combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
+ combine_fn_bw_grad_accumulation,
+ args_single_step_bw,
+ ctx._fw_include_key_set,
+ ctx._fw_exclude_key_set,
+ force_enable_grad=True,
+ )
+
# Decompose the flat_grads into g_c_T, g_ys
g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])

@@ -784,7 +787,7 @@ def combine_fn_bw_grad_accumulation(*args):
# initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
# gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
gradients = scan_op(
- combine_fn_bw_grad_accumulation,
+ combine_fn_bw_grad_accumulation_gm,
bwd_init,
bwd_xs,
additional_inputs,
@@ -803,6 +806,19 @@

@scan_op.py_impl(DispatchKey.Autograd)
def scan_autograd(combine_fn, init, xs, additional_inputs):
+ if not any(
+ el.requires_grad
+ for el in (tuple(init) + tuple(xs) + additional_inputs)
+ if isinstance(el, torch.Tensor)
+ ):
+ with torch._C._AutoDispatchBelowAutograd():
+ return scan_op(
+ combine_fn,
+ init,
+ xs,
+ additional_inputs,
+ )
+
num_leaves_init = len(init)
num_leaves_xs = len(xs)
num_additional_inputs = len(additional_inputs)
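
Taken together, the backward pass above is itself a scan run in reverse over the checkpointed carries and the input slices. At each step, the per-step backward of ``combine_fn`` turns the incoming carry gradient and output gradient into a gradient for the previous carry, for the current ``xs`` slice, and (accumulated across steps) for the ``additional_inputs``; inputs that do not require grad simply receive no gradient, and if nothing requires grad at all, the new guard above skips this machinery entirely. The eager sketch below is illustrative only, with hypothetical names and a combine function without extra outputs; the real implementation materializes the joint graph via ``materialize_as_graph`` and runs it through ``scan_op``, as the diff shows.

import torch

def add_combine(c, x, w):
    return c + w * x  # next carry only; per-step outputs omitted to keep the sketch short

# Forward: keep every intermediate carry, which is what
# ``combine_fn_with_carry_checkpoint`` provides to the backward pass.
init = torch.zeros(3)
xs = torch.randn(5, 3, requires_grad=True)   # partial gradients: only xs requires grad
w = torch.ones(3)                            # additional input without gradients
carries = [init]
for t in range(xs.shape[0]):
    carries.append(add_combine(carries[-1], xs[t], w))

# Backward: scan in reverse, feeding the carry gradient of step t into step t - 1
# and collecting one gradient slice per step for xs. Nothing is accumulated for w
# because it does not require grad.
g_carry = torch.ones(3)                      # gradient of the final carry
g_xs_slices = []
for t in reversed(range(xs.shape[0])):
    c_prev = carries[t].detach().requires_grad_(True)
    x_t = xs[t].detach().requires_grad_(True)
    step_out = add_combine(c_prev, x_t, w)
    g_c_prev, g_x_t = torch.autograd.grad(step_out, (c_prev, x_t), grad_outputs=g_carry)
    g_xs_slices.insert(0, g_x_t)
    g_carry = g_c_prev                       # carry gradient for the previous step

g_xs = torch.stack(g_xs_slices)

# Cross-check against plain autograd on the unrolled forward loop.
(g_xs_ref,) = torch.autograd.grad(carries[-1], xs, grad_outputs=torch.ones(3))
assert torch.allclose(g_xs, g_xs_ref)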