Added bypass for autograd DispatchKey if no gradient for either the i… · pytorch/pytorch@ebe8c22 · GitHub

Commit ebe8c22
Added bypass for autograd DispatchKey if no gradient for either the init, the xs or the additional_inputs is required
1 parent 6d67d3e commit ebe8c22

File tree

2 files changed: 81 additions & 67 deletions

  test/functorch/test_control_flow.py
  torch/_higher_order_ops/scan.py

test/functorch/test_control_flow.py

Lines changed: 2 additions & 4 deletions
@@ -3277,17 +3277,15 @@ def f(fct, init, xs):
 def forward(self, fct_1, init_1, xs_1):
     permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
     flip = torch.ops.aten.flip.default(permute, [0]); permute = None
-    select_copy = torch.ops.aten.select_copy.int(flip, 0, 0); select_copy = None
     sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
     sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
     sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
     sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None
     scan_combine_graph_0 = self.scan_combine_graph_0
     scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], (sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4)); scan_combine_graph_0 = init_1 = flip = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
     getitem = scan[0]
-    getitem_1 = scan[1]; getitem_1 = None
-    getitem_2 = scan[2]; scan = None
-    flip_1 = torch.ops.aten.flip.default(getitem_2, [0]); getitem_2 = None
+    getitem_1 = scan[1]; scan = None
+    flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None
     return (getitem, flip_1)""", # noqa: B950
 )
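For orientation, a toy sketch of the scan semantics the expected graph above exercises (illustrative only, not the PyTorch implementation; toy_scan and the example combine function are hypothetical): the combine function consumes the running carry and one slice of xs and returns the next carry together with a per-step output, and the op returns the final carry plus the stacked per-step outputs.

import torch

def toy_scan(combine_fn, init, xs):
    # Apply combine_fn over the leading dimension of xs, threading the carry through.
    carry = init
    ys = []
    for x in xs.unbind(0):
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)

init = torch.zeros(3)
xs = torch.arange(12.0).reshape(4, 3)
last_carry, ys = toy_scan(lambda c, x: (c + x, c + x), init, xs)
print(last_carry)   # final cumulative sum over the leading dimension
print(ys)           # cumulative sums for every step, stacked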

torch/_higher_order_ops/scan.py

Lines changed: 79 additions & 63 deletions
@@ -567,21 +567,11 @@ def forward(
         ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
         ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

-        # TODO: we need to materialize the combine_fn because dynamo is unable to
-        # trace through the function when torch.compile torch.autograd.grad.
-        combine_fn_gm = materialize_as_graph(
-            combine_fn,
-            (*init, *first_slice_copy_with_grad(xs), *additional_inputs),
-            ctx._fw_include_key_set,
-            ctx._fw_exclude_key_set,
-            force_enable_grad=True,
-        )
-
         # 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
         # The wrapper of the forward graph returns carries from all iterations,
         # not just from the last iteration. These are required in the backward path
         def combine_fn_with_carry_checkpoint(*args):
-            carry, y = _extract_carry_and_out(combine_fn_gm(*args), num_leaves_init)
+            carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
             return [
                 *carry,
                 # We additionally checkpoint all the intemediate carry outputs for backward.
@@ -669,56 +659,6 @@ def initialize_g_additional_inputs(
         )
         ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)

-        def construct_args_single_step_bw():
-            # This function constructs the arguments for a single step of the backward scan.
-            # In other words, it creates the arguments for ``ctx._combine_fn_bw``.
-            # The ``ctx._combine_fn_bw`` expects primals followed by the tangents, thus
-
-            # The first arguments are primals, i.e., the forward part of the bw_fn graph
-            # The first argument relates to the init for the forward.
-            # I.e., fw_init
-
-            # The second argument relates to the xs for the forward.
-            # Because the arguments are for a single step only,
-            # only the first slice of the xs is used.
-            # Note: It is important to preserve the requires_grad flag of xs
-            # and thus we use the wrapper function ``first_slice_copy_with_grad``
-            fw_xs_slice = first_slice_copy_with_grad(fw_xs)
-
-            # The third argument relates to the additional inputs for the forward.
-            # I.e., additional_inputs
-
-            # The subsequent arguments are the tangents, i.e., the gradients of the bw_fn
-            # The fourth argument relates to the gradients of the carries.
-            # Because the arguments are for a single step only,
-            # only the first slice of the carries is used.
-            sliced_carries = [first_slice_copy(c) for c in fw_carries]
-
-            # The last argument relates to the gradients of the ys.
-            # Because the arguments are for a single step only,
-            # only the first slice of the ys is used.
-            sliced_ys = [first_slice_copy(o) for o in fw_ys]
-
-            return (
-                *fw_init,
-                *fw_xs_slice,
-                *additional_inputs,
-                *sliced_carries,
-                *sliced_ys,
-            )
-
-        args_single_step_bw = construct_args_single_step_bw()
-
-        # TODO: we need to materialize the bw graphs because dynamo is unable to
-        # trace through the joint function when torch.compile torch.autograd.grad.
-        ctx._combine_fn_bw_gm = materialize_as_graph(
-            ctx._combine_fn_bw,
-            args_single_step_bw,
-            ctx._fw_include_key_set,
-            ctx._fw_exclude_key_set,
-            force_enable_grad=True,
-        )
-
         # 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
         def combine_fn_bw_grad_accumulation(*args):
             # Dissect args and re-order them for the ``ctx._combine_fn_bw``
@@ -739,7 +679,7 @@ def combine_fn_bw_grad_accumulation(*args):
             combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)

             g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
-                ctx._combine_fn_bw_gm(*combine_fn_bw_args),
+                ctx._combine_fn_bw(*combine_fn_bw_args),
                 [num_leaves_init, num_leaves_xs, num_additional_inputs],
             )

@@ -757,6 +697,69 @@ def combine_fn_bw_grad_accumulation(*args):
             # The ``g_xs_t`` is encoded as the output of the backward scan operator
             return [*new_g_additional_inputs, *g_c_t, *g_xs_t]

+        # Materialize the ``combine_fn_bw_grad_accumulation``
+        def construct_args_single_step_bw():
+            # This function constructs the arguments for a single step of the backward scan.
+            # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
+            # The order of the arguments returned is identical to the order the backward scan
+            # operations provides
+
+            # The following arguments are used for the backward part of the joint graph
+            # The first argument relates to the gradient accumulation of the additional inputs.
+            # Because only tensor elements of additional inputs can have requires_grad=True,
+            # the values for non-tensor elements of additional inputs are None
+            masked_additional_inputs = [
+                a.clone() if add_inp_tm else None
+                for add_inp_tm, a in zip(
+                    additional_inputs_tensor_mask, additional_inputs
+                )
+            ]
+
+            # The second argument relates to the gradients of the carries.
+            # Because the arguments are for a single step only,
+            # only the first slice of the carries is used.
+            sliced_carries = [first_slice_copy(c) for c in fw_carries]
+
+            # The third argument relates to the gradients of the ys.
+            # Because the arguments are for a single step only,
+            # only the first slice of the ys is used.
+            sliced_ys = [first_slice_copy(o) for o in fw_ys]
+
+            # The following arguments are used for the forward part of the joint graph
+            # The fourth argument relates to the init for the forward.
+            # I.e., fw_init
+
+            # The fifth argument relates to the xs for the forward.
+            # Because the arguments are for a single step only,
+            # only the first slice of the xs is used.
+            # Note: It is important to preserve the requires_grad flag of xs
+            # and thus we use the wrapper function ``first_slice_copy_with_grad``
+            fw_xs_slice = first_slice_copy_with_grad(fw_xs)
+
+            # The last argument relates to the additional inputs for the forward.
+            # I.e., additional_inputs
+
+            return (
+                *masked_additional_inputs,
+                *sliced_carries,
+                *sliced_ys,
+                *fw_init,
+                *fw_xs_slice,
+                *additional_inputs,
+            )
+
+        args_single_step_bw = construct_args_single_step_bw()
+
+        # TODO: we need to materialize the bw graphs because dynamo is unable to
+        # trace through the joint function when torch.compile torch.autograd.grad.
+        combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
+            combine_fn_bw_grad_accumulation,
+            args_single_step_bw,
+            ctx._fw_include_key_set,
+            ctx._fw_exclude_key_set,
+            force_enable_grad=True,
+        )
+
         # Decompose the flat_grads into g_c_T, g_ys
         g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])

@@ -784,7 +787,7 @@ def combine_fn_bw_grad_accumulation(*args):
         # initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
         # gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
         gradients = scan_op(
-            combine_fn_bw_grad_accumulation,
+            combine_fn_bw_grad_accumulation_gm,
             bwd_init,
             bwd_xs,
             additional_inputs,
@@ -803,6 +806,19 @@ def combine_fn_bw_grad_accumulation(*args):

 @scan_op.py_impl(DispatchKey.Autograd)
 def scan_autograd(combine_fn, init, xs, additional_inputs):
+    if not any(
+        el.requires_grad
+        for el in (tuple(init) + tuple(xs) + additional_inputs)
+        if isinstance(el, torch.Tensor)
+    ):
+        with torch._C._AutoDispatchBelowAutograd():
+            return scan_op(
+                combine_fn,
+                init,
+                xs,
+                additional_inputs,
+            )
+
     num_leaves_init = len(init)
     num_leaves_xs = len(xs)
     num_additional_inputs = len(additional_inputs)
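To make the new fast path concrete, here is a minimal, self-contained sketch (illustrative only, not the PyTorch implementation; the helper name needs_autograd is hypothetical). It reproduces the gating condition added to scan_autograd and shows that ops run under torch._C._AutoDispatchBelowAutograd are not recorded by autograd, which is why the bypass is safe when no input requires a gradient.

import torch

def needs_autograd(init, xs, additional_inputs):
    # Same condition as the bypass: only tensor leaves are inspected.
    return any(
        el.requires_grad
        for el in (tuple(init) + tuple(xs) + tuple(additional_inputs))
        if isinstance(el, torch.Tensor)
    )

init = [torch.zeros(2, 3)]                 # requires_grad defaults to False
xs = [torch.randn(5, 2, 3)]
additional_inputs = (torch.ones(3), 42)    # non-tensor entries are skipped by the check

print(needs_autograd(init, xs, additional_inputs))   # False -> the bypass would be taken

# Under the guard used by the bypass, ops dispatch below the Autograd key,
# so nothing is recorded for autograd even for tensors that require grad.
t = torch.randn(3, requires_grad=True)
with torch._C._AutoDispatchBelowAutograd():
    out = t * 2
print(out.grad_fn, out.requires_grad)                # expected: None False

xs[0].requires_grad_(True)
print(needs_autograd(init, xs, additional_inputs))   # True -> full autograd path would run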

0 commit comments