@@ -567,21 +567,11 @@ def forward(
         ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
         ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

-        # TODO: we need to materialize the combine_fn because dynamo is unable to
-        # trace through the function when torch.autograd.grad is called under torch.compile.
-        combine_fn_gm = materialize_as_graph(
-            combine_fn,
-            (*init, *first_slice_copy_with_grad(xs), *additional_inputs),
-            ctx._fw_include_key_set,
-            ctx._fw_exclude_key_set,
-            force_enable_grad=True,
-        )
-
         # 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
         # The wrapper of the forward graph returns the carries from all iterations,
         # not just from the last iteration. These are required in the backward path.
         def combine_fn_with_carry_checkpoint(*args):
-            carry, y = _extract_carry_and_out(combine_fn_gm(*args), num_leaves_init)
+            carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
             return [
                 *carry,
                 # We additionally checkpoint all the intermediate carry outputs for backward.
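Reviewer note: the wrapper above returns every intermediate carry, not just the last one, because the backward pass needs the carry that entered each iteration. A minimal, self-contained sketch of that idea in plain Python (``toy_scan_with_carry_checkpoint`` is an illustrative name, not part of this PR or the scan_op machinery):

```python
import torch

def toy_scan_with_carry_checkpoint(combine_fn, init, xs):
    # Scan combine_fn over xs while checkpointing *every* carry.
    # The backward pass needs the carry that entered iteration t to
    # recompute that iteration's gradients without replaying the whole scan.
    carry, carries, ys = init, [], []
    for x in xs:
        carries.append(carry)  # checkpoint the carry entering this iteration
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, carries, ys

# Example: cumulative sum expressed as a scan.
last_carry, carries, ys = toy_scan_with_carry_checkpoint(
    lambda c, x: (c + x, c + x), torch.tensor(0.0), torch.arange(4.0)
)
# carries == [0., 0., 1., 3.]; last_carry == 6.
```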
@@ -669,56 +659,6 @@ def initialize_g_additional_inputs(
         )
         ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)

-        def construct_args_single_step_bw():
-            # This function constructs the arguments for a single step of the backward scan.
-            # In other words, it creates the arguments for ``ctx._combine_fn_bw``.
-            # The ``ctx._combine_fn_bw`` expects the primals followed by the tangents, thus:
-
-            # The first arguments are the primals, i.e., the forward part of the bw_fn graph.
-            # The first argument relates to the init of the forward,
-            # i.e., fw_init
-
-            # The second argument relates to the xs of the forward.
-            # Because the arguments are for a single step only,
-            # only the first slice of the xs is used.
-            # Note: It is important to preserve the requires_grad flag of xs,
-            # and thus we use the wrapper function ``first_slice_copy_with_grad``.
-            fw_xs_slice = first_slice_copy_with_grad(fw_xs)
-
-            # The third argument relates to the additional inputs of the forward,
-            # i.e., additional_inputs
-
-            # The subsequent arguments are the tangents, i.e., the gradients of the bw_fn.
-            # The fourth argument relates to the gradients of the carries.
-            # Because the arguments are for a single step only,
-            # only the first slice of the carries is used.
-            sliced_carries = [first_slice_copy(c) for c in fw_carries]
-
-            # The last argument relates to the gradients of the ys.
-            # Because the arguments are for a single step only,
-            # only the first slice of the ys is used.
-            sliced_ys = [first_slice_copy(o) for o in fw_ys]
-
-            return (
-                *fw_init,
-                *fw_xs_slice,
-                *additional_inputs,
-                *sliced_carries,
-                *sliced_ys,
-            )
-
-        args_single_step_bw = construct_args_single_step_bw()
-
-        # TODO: we need to materialize the bw graph because dynamo is unable to
-        # trace through the joint function when torch.autograd.grad is called under torch.compile.
-        ctx._combine_fn_bw_gm = materialize_as_graph(
-            ctx._combine_fn_bw,
-            args_single_step_bw,
-            ctx._fw_include_key_set,
-            ctx._fw_exclude_key_set,
-            force_enable_grad=True,
-        )
-
         # 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
         def combine_fn_bw_grad_accumulation(*args):
             # Dissect args and re-order them for the ``ctx._combine_fn_bw``
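For context on the ``create_bw_fn`` call above: the removed comments describe a calling convention of primals followed by tangents. A hedged stand-in that realizes the same convention for a single step via ``torch.autograd.grad`` (illustrative only; ``make_single_step_bw`` is not the actual helper):

```python
import torch

def make_single_step_bw(combine_fn):
    # Stand-in for the convention described above: the returned function
    # takes the primals (carry, x) followed by the tangents (g_carry, g_y)
    # and returns gradients with respect to the primals.
    def bw(carry, x, g_carry, g_y):
        carry = carry.detach().requires_grad_(True)
        x = x.detach().requires_grad_(True)
        new_carry, y = combine_fn(carry, x)
        return torch.autograd.grad((new_carry, y), (carry, x), (g_carry, g_y))
    return bw

step_bw = make_single_step_bw(lambda c, x: (c * x, c + x))
g_c, g_x = step_bw(
    torch.tensor(2.0), torch.tensor(3.0), torch.tensor(1.0), torch.tensor(1.0)
)
# g_c == x + 1 == 4., g_x == c + 1 == 3.
```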
@@ -739,7 +679,7 @@ def combine_fn_bw_grad_accumulation(*args):
             combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)

             g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
-                ctx._combine_fn_bw_gm(*combine_fn_bw_args),
+                ctx._combine_fn_bw(*combine_fn_bw_args),
                 [num_leaves_init, num_leaves_xs, num_additional_inputs],
             )
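``split_into_chunks`` is presumably a flat-list slicing helper; a minimal sketch of that assumed behavior:

```python
def split_into_chunks(flat, sizes):
    # Assumed semantics: slice a flat sequence into consecutive groups
    # whose lengths are given by ``sizes``.
    out, start = [], 0
    for size in sizes:
        out.append(flat[start : start + size])
        start += size
    return out

# Mirrors the call above with [num_leaves_init, num_leaves_xs, num_additional_inputs]:
g_c_t, g_xs_t, g_adds_t = split_into_chunks(list(range(6)), [1, 3, 2])
# -> [0], [1, 2, 3], [4, 5]
```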
@@ -757,6 +697,69 @@ def combine_fn_bw_grad_accumulation(*args):
             # The ``g_xs_t`` is encoded as the output of the backward scan operator
             return [*new_g_additional_inputs, *g_c_t, *g_xs_t]

+        # Materialize the ``combine_fn_bw_grad_accumulation``
+        def construct_args_single_step_bw():
+            # This function constructs the arguments for a single step of the backward scan.
+            # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``.
+            # The order of the arguments returned is identical to the order in which the
+            # backward scan operation provides them.
+
+            # The following arguments are used for the backward part of the joint graph.
+            # The first argument relates to the gradient accumulation of the additional inputs.
+            # Because only tensor elements of additional inputs can have requires_grad=True,
+            # the values for non-tensor elements of additional inputs are None.
+            masked_additional_inputs = [
+                a.clone() if add_inp_tm else None
+                for add_inp_tm, a in zip(
+                    additional_inputs_tensor_mask, additional_inputs
+                )
+            ]
+
+            # The second argument relates to the gradients of the carries.
+            # Because the arguments are for a single step only,
+            # only the first slice of the carries is used.
+            sliced_carries = [first_slice_copy(c) for c in fw_carries]
+
+            # The third argument relates to the gradients of the ys.
+            # Because the arguments are for a single step only,
+            # only the first slice of the ys is used.
+            sliced_ys = [first_slice_copy(o) for o in fw_ys]
+
+            # The following arguments are used for the forward part of the joint graph.
+            # The fourth argument relates to the init of the forward,
+            # i.e., fw_init
+
+            # The fifth argument relates to the xs of the forward.
+            # Because the arguments are for a single step only,
+            # only the first slice of the xs is used.
+            # Note: It is important to preserve the requires_grad flag of xs,
+            # and thus we use the wrapper function ``first_slice_copy_with_grad``.
+            fw_xs_slice = first_slice_copy_with_grad(fw_xs)
+
+            # The last argument relates to the additional inputs of the forward,
+            # i.e., additional_inputs
+
+            return (
+                *masked_additional_inputs,
+                *sliced_carries,
+                *sliced_ys,
+                *fw_init,
+                *fw_xs_slice,
+                *additional_inputs,
+            )
+
+        args_single_step_bw = construct_args_single_step_bw()
+
+        # TODO: we need to materialize the bw graph because dynamo is unable to
+        # trace through the joint function when torch.autograd.grad is called under torch.compile.
+        combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
+            combine_fn_bw_grad_accumulation,
+            args_single_step_bw,
+            ctx._fw_include_key_set,
+            ctx._fw_exclude_key_set,
+            force_enable_grad=True,
+        )
+
         # Decompose the flat_grads into g_c_T, g_ys
         g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
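The single-step argument construction above leans on ``first_slice_copy`` and ``first_slice_copy_with_grad``. A sketch of their assumed behavior (the real helpers may differ in detail):

```python
import torch

def first_slice_copy(t):
    # Assumed: copy the first slice along the scan (leading) dimension.
    return t.select(0, 0).clone()

def first_slice_copy_with_grad(ts):
    # Assumed: like first_slice_copy, but re-attach requires_grad so the
    # materialized graph records the same gradient edges as the full xs.
    return [
        first_slice_copy(t).detach().requires_grad_(t.requires_grad) for t in ts
    ]

xs = [torch.randn(5, 3, requires_grad=True)]
(slice0,) = first_slice_copy_with_grad(xs)  # shape (3,), requires_grad=True
```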
@@ -784,7 +787,7 @@ def combine_fn_bw_grad_accumulation(*args):
         # initial_g_additional_inputs and the last carry as the ``bwd_init``, and the
         # gradients of the outputs (g_ys) as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
         gradients = scan_op(
-            combine_fn_bw_grad_accumulation,
+            combine_fn_bw_grad_accumulation_gm,
             bwd_init,
             bwd_xs,
             additional_inputs,
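Why the backward pass is itself a ``scan_op`` call: carry gradients flow from the last iteration back to the first, which is just a forward scan over the time-reversed sequence. A toy illustration, with a hand-written step backward for ``(c, x) -> (c + x, c + x)`` (none of these names come from the PR):

```python
import torch

def toy_scan_backward(step_bw, g_last_carry, g_ys, carries, xs):
    # Carry gradients propagate from iteration T-1 down to iteration 0,
    # so the backward of a scan is a scan over the reversed inputs.
    g_carry, g_xs = g_last_carry, []
    for carry, x, g_y in zip(reversed(carries), reversed(xs), reversed(g_ys)):
        g_carry, g_x = step_bw(carry, x, g_carry, g_y)
        g_xs.append(g_x)
    return g_carry, list(reversed(g_xs))

# Backward of combine_fn(c, x) = (c + x, c + x): both gradients are g_c + g_y.
step_bw = lambda c, x, g_c, g_y: (g_c + g_y, g_c + g_y)
g_init, g_xs = toy_scan_backward(
    step_bw,
    torch.tensor(1.0),                            # gradient of the last carry
    [torch.tensor(0.0)] * 3,                      # gradients of the ys
    [torch.tensor(v) for v in (0.0, 1.0, 3.0)],   # checkpointed carries
    [torch.tensor(v) for v in (1.0, 2.0, 3.0)],   # xs
)
# g_init == 1., g_xs == [1., 1., 1.]  (as expected for a cumulative sum)
```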
@@ -803,6 +806,19 @@ def combine_fn_bw_grad_accumulation(*args):

 @scan_op.py_impl(DispatchKey.Autograd)
 def scan_autograd(combine_fn, init, xs, additional_inputs):
+    if not any(
+        el.requires_grad
+        for el in (tuple(init) + tuple(xs) + additional_inputs)
+        if isinstance(el, torch.Tensor)
+    ):
+        with torch._C._AutoDispatchBelowAutograd():
+            return scan_op(
+                combine_fn,
+                init,
+                xs,
+                additional_inputs,
+            )
+
     num_leaves_init = len(init)
     num_leaves_xs = len(xs)
     num_additional_inputs = len(additional_inputs)
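The early return added to ``scan_autograd`` skips the autograd wrapper entirely when no tensor input requires gradients. The same guard in isolation:

```python
import torch

def needs_autograd(*flat_args):
    # Mirror of the added check: take the autograd path only when at least
    # one tensor argument actually requires gradients.
    return any(
        el.requires_grad for el in flat_args if isinstance(el, torch.Tensor)
    )

print(needs_autograd(torch.zeros(2), "not a tensor"))             # False
print(needs_autograd(torch.zeros(2, requires_grad=True), 3))      # True
```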