@@ -454,40 +454,42 @@ class AssociativeScanAutogradOp(torch.autograd.Function):
r"""
Example::
xs = torch.arange(1, 5) = [1, 2, 3, 4]
- ys = torch.cumprod(xs) = [1, 2, 6, 24]

def combine_fn(a: torch.Tensor, b: torch.Tensor):
    return a * b

- The ``combine_fn_bw``, computing the gradients for a and b of ``combine_fn`` is computed as:
- def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
-     return g_y * b, g_y * a
-
- The first output of ``combine_fn_bw`` is the instantaneous gradient for the previous output g_y_t
- and the second output of ``combine_fn_bw`` is the instantaneous gradient for the input g_x_t.
+ ys = associative_scan(combine_fn, xs),
+ which can be unpacked as:
+ ys_0 = xs_0 = 1
+ ys_1 = combine_fn(ys_0, xs_1) = combine_fn(1, 2) = 2
+ ...
+ ys_T = combine_fn(ys_(T-1), xs_T) = combine_fn(6, 4) = 24
+ ys = [1, 2, 6, 24]
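
For illustration, the unpacking above can be reproduced with a plain Python loop (a minimal sketch in eager PyTorch, independent of the actual ``associative_scan`` implementation)::

    import torch

    def combine_fn(a, b):
        return a * b

    xs = torch.arange(1, 5)                  # [1, 2, 3, 4]
    ys = [xs[0]]                             # ys_0 = xs_0
    for t in range(1, len(xs)):
        ys.append(combine_fn(ys[-1], xs[t]))
    print(torch.stack(ys))                   # tensor([ 1,  2,  6, 24]), same as torch.cumprod(xs, 0)
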
- Note: In a real usecase of associative_scan, there may be additional_inputs that participate in the
- forward as well as in the backward of the scan operator. For the sake of readability those inputs
- have been omitted in the following example, but are included in the subsequent detailed description below.
+ The function ``combine_fn_bw`` returns the gradients of a and b from ``combine_fn``.
+ For the ``combine_fn`` above, this results in:
+ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_ys_t: torch.Tensor):
+     return g_ys_t * b, g_ys_t * a

- The forward output of associative_scan is computed as:
- ys = associative_scan(combine_fn, xs).
+ where g_ys are the upstream gradients in torch.autograd.Function.
+ In particular, g_ys is the vector of all intermediate upstream gradients
+ of the outputs [g_ys_0, g_ys_1, ..., g_ys_T].

- For example, this computation can be unpacked as:
- ys_0 = xs_0
- ys_1 = combine_fn(ys_0, xs_1)
- ...
- ys_T = combine_fn(ys_(T-1), xs_T)
+ The first output of ``combine_fn_bw`` is the gradient g_y_t,
+ which is the gradient for the forward output at step t-1, i.e., a.
+ The second output of ``combine_fn_bw`` is the gradient g_x_t,
+ which is the gradient for the forward input at step t, i.e., b.
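
As a sanity check, the hand-written ``combine_fn_bw`` above matches what ``torch.autograd.grad`` returns for ``combine_fn`` (an illustrative sketch only; in the actual implementation the backward graph is generated via ``create_bw_fn``)::

    import torch

    def combine_fn(a, b):
        return a * b

    def combine_fn_bw(a, b, g_ys_t):
        return g_ys_t * b, g_ys_t * a

    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)
    g = torch.tensor(1.0)
    g_a, g_b = torch.autograd.grad(combine_fn(a, b), (a, b), g)
    print(g_a, g_b)                 # tensor(3.) tensor(2.)
    print(combine_fn_bw(a, b, g))   # same values: 3. and 2.
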
- Note: In a real usecase of associative_scan this operation is parallelized from O(T) to O(log(T)).
+ Note: In a real use case of ``associative_scan``, there may be additional_inputs that participate in the
+ forward as well as in the backward of the scan operator. For the sake of readability those inputs
+ have been omitted in the following example, but are included in the subsequent detailed description below.

- Given the ys, the gradients for xs can be computed as follows:
- We receive the upstream gradients in torch.autograd.Function, i.e., we get g_ys,
- where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T]
+ Note: In a real use case of ``associative_scan``, the operation is parallelized from O(T) to O(log(T)).

- We can then utilize the ``combine_fn_bw`` to compute the instantaneous gradients g_x_t and g_y_t
+ Given the outputs ys, the gradients for xs can be computed as follows:
+ We first utilize the ``combine_fn_bw`` to compute the instantaneous gradients g_x_t and g_y_t
at every step as:
- g_y_t, g_x_t = combine_fn_bw(ys_(t-1), xs_t, 1.) ,
+ g_y_t, g_x_t = combine_fn_bw(ys_(t-1), xs_t, 1.),
where instead of using the elements g_ys_t, we use 1s. This is required to get the instantaneous
gradients at every step t and we incorporate the upstream gradients g_ys at a later time.

@@ -501,32 +503,33 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
g_y = [1, 2, 3, 4]
g_x = [1, 1, 2, 6]
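
For the running example, these instantaneous gradients can be reproduced step by step (a minimal sketch; the boundary values at t = 0 are set to 1 as in the values above, while the actual implementation computes all steps at once via ``torch.vmap`` and a roll of the outputs)::

    import torch

    def combine_fn_bw(a, b, g):
        return g * b, g * a                      # (g_y_t, g_x_t) for combine_fn(a, b) = a * b

    xs = torch.arange(1., 5.)                    # [1., 2., 3., 4.]
    ys = torch.cumprod(xs, 0)                    # [1., 2., 6., 24.]

    g_y, g_x = [torch.tensor(1.)], [torch.tensor(1.)]   # boundary values at t = 0
    for t in range(1, len(xs)):
        g_y_t, g_x_t = combine_fn_bw(ys[t - 1], xs[t], torch.tensor(1.))
        g_y.append(g_y_t)
        g_x.append(g_x_t)
    print(torch.stack(g_y))                      # tensor([1., 2., 3., 4.])
    print(torch.stack(g_x))                      # tensor([1., 1., 2., 6.])
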
- With these instantaneous gradients, one can compute the gradients of the inputs xs (g_xs) naively as:
+ With these instantaneous gradients, we can compute the gradients of the inputs xs (g_xs) naively as:
g_xs_t = (\sum_{i=T}^t g_ys_i . (\prod_{k=i}^{k>t} g_y_k)) . g_x_t (1)

In particular,
- g_xs_T = g_ys_T . g_x_T
+ g_xs_T = g_ys_T . 1 . g_x_T
g_xs_(T-1) = g_ys_T . g_y_T . g_x_(T-1) + g_ys_(T-1) . g_x_(T-1)
g_xs_(T-2) = g_ys_T . g_y_T . g_y_(T-1) . g_x_(T-2) + g_ys_(T-1) . g_y_(T-1) . g_x_(T-2) + g_ys_(T-2) . g_x_(T-2)
...

- Which for the example above results in the final input gradients:
- g_xs_3 = 6
- g_xs_2 = 10
- g_xs_1 = 16
- g_xs_0 = 33
-
- This recursive way of computing may not be the most efficient one and an alternative approach would be
- to rewrite the recursion with the help of cumulative products and matrix multiplications.
- In particular, when looking at equation (1) above, one can observe three key aspects:
- 1.) The number of terms (products) in the sum is increasing by one as one progresses to earlier steps i.
- 2.) The first products for step j<i are the same, but there is one additional product added
- which contains one element g_y_j, added to the end of the product.
- 3.) For the input gradient g_xs_i, the instantaneous input g_x_i is multiplied at the end
-
- These three observations can be exploited to formulate a ``grid form`` to compute the gradients more
- efficiently. See ``grid form`` outlined in https://justintchiu.com/blog/pscan_diff/.
- In particular, the sum and product in (1) can be rewritten
+ Which for the example above results in the final input gradients (assuming g_ys_0 = g_ys_1 = ... = g_ys_T = 1):
+ g_xs_3 = 1 . 1 . 6 = 6 (2)
+ g_xs_2 = 1 . 4 . 2 + 1 . 2 = 10 (3)
+ g_xs_1 = 1 . 4 . 3 . 1 + 1 . 3 . 1 + 1 . 1 = 16 (4)
+ g_xs_0 = 1 . 4 . 3 . 2 . 1 + 1 . 3 . 2 . 1 + 1 . 2 . 1 + 1 . 1 = 33 (5)
+
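
The gradients (2)-(5) can be verified with a direct, unoptimized translation of equation (1) (O(T^2) work; for illustration only)::

    import torch

    g_y = torch.tensor([1., 2., 3., 4.])    # instantaneous output gradients from above
    g_x = torch.tensor([1., 1., 2., 6.])    # instantaneous input gradients from above
    g_ys = torch.ones(4)                    # upstream gradients, all 1 as in (2)-(5)
    T = g_y.numel() - 1

    g_xs = []
    for t in range(T + 1):
        acc = torch.tensor(0.)
        for i in range(t, T + 1):
            # product over k = i, i-1, ..., t+1 of g_y_k (the empty product is 1)
            acc = acc + g_ys[i] * torch.prod(g_y[t + 1 : i + 1])
        g_xs.append(acc * g_x[t])
    print(torch.stack(g_xs))                # tensor([33., 16., 10.,  6.])
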
+ This "recursive" way of computing the gradients outlined above may not be the most efficient.
+ An alternative approach would be to rewrite the sum and product of equation (1)
+ with the help of cumulative products and matrix multiplications.
+ In particular, when looking at equation (1), one can observe three key aspects:
+ 1.) The number of terms in the sum increases by one as one progresses from t back to an earlier step t-1.
+ 2.) The first products for step t are the same as for step t-1,
+     but there is one additional term added to the end of the product.
+ 3.) For the input gradient g_xs_t, the instantaneous input gradient g_x_t is multiplied at the end.
+
+ These three observations can be exploited to formulate a "grid form" to compute the gradients more
+ efficiently. See also the "grid form" outlined in https://justintchiu.com/blog/pscan_diff/.
+ In particular, the sum and product in equation (1) can be rewritten
using a matrix form, combined with a cumulative product on the rows. The resulting vector can then be
elementwise multiplied with the instantaneous input gradients to obtain the final input gradients g_xs.

@@ -542,7 +545,12 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
[0, 0, 1, 4],
[0, 0, 0, 1]]
- Note that these are precisely the terms of the product in (1).
+ Note: the rows of y_mat correspond exactly to the terms of the product in equation (1).
+ For example, the last row contains the coefficients found in g_xs_3 (see (2)),
+ the second to last row contains the coefficients found in g_xs_2 (see (3)).
+ ...
+ The first row contains the coefficients found in g_xs_0 (see (5)).
+ Note: The order of y_mat is reversed compared to (2)-(5), in order to avoid unnecessary flip operations.

We then scale the y_mat with the upstream gradient g_ys
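
To make this concrete for the running example, the scaling and reduction could look as follows (an illustrative sketch; the remaining rows of y_mat are filled in according to the coefficients in (2)-(5), and since all upstream gradients are 1 here, the scaling step is trivial)::

    import torch

    y_mat = torch.tensor([[1., 2., 6., 24.],
                          [0., 1., 3., 12.],
                          [0., 0., 1.,  4.],
                          [0., 0., 0.,  1.]])
    g_x = torch.tensor([1., 1., 2., 6.])
    g_ys = torch.ones(4)           # upstream gradients, all 1 as in (2)-(5)

    scaled = y_mat * g_ys          # scale the entries of y_mat with the upstream gradients
    summed = scaled.sum(dim=1)     # row-wise sum -> tensor([33., 16.,  5.,  1.])
    g_xs = summed * g_x            # elementwise multiply with the instantaneous input gradients
    print(g_xs)                    # tensor([33., 16., 10.,  6.]) -- matches (2)-(5)
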
@@ -571,7 +579,7 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):

The forward of associative_scan can be computed with the following steps:
1.) Compute the forward output of the associative_scan
- ys = associative_scan_op(combine_fn, xs, additional_inputs)
+ ys = associative_scan(combine_fn, xs, additional_inputs)

The backward of scan can be computed as:
2.) Prepare the backward graph
@@ -601,7 +609,8 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
[0, 1, g_y_2, g_y_3 . g_y_2],
[0, 0, 1, g_y_3],
[0, 0, 0, 1]],
- For better readability, however, we split the calculation into several substeps. We start by
+ For better readability, however, we split the calculation into several substeps and utilize masks of 1s and 0s.
+ We start with:

5.1 Repeat the elements of g_y to form the square matrix
y_mat = [[1, g_y_1, g_y_2, g_y_3],
@@ -713,20 +722,21 @@ def backward(ctx, *g_ys):
# but we need it here in order to compute the correct gradients
xs_slices = first_slice_copy_with_grad(itertools.chain(xs, xs))

+ # Construct the operands from the forward, fw_operands,
+ # and the operands for a single step t of the forward, fw_operands_slice
+ fw_operands = (*xs, *additional_inputs)
+ fw_operands_slice = (*xs_slices, *additional_inputs)
+
# 2.) Prepare the backward graph
- ctx._combine_fn_bw = create_bw_fn(
-     ctx._combine_fn,
-     (*xs_slices, *additional_inputs),
- )
+ combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands_slice)

- # 3.) Materialize the ``ctx._combine_fn_bw``
+ # 3.) Materialize the ``combine_fn_bw``
# TODO: we need to materialize the bw graphs because dynamo is unable to
# trace through the joint function when torch.compile is used with torch.autograd.grad.
combine_fn_bw_gm = materialize_as_graph(
-     ctx._combine_fn_bw,
+     combine_fn_bw,
    (
-         *xs_slices,
-         *additional_inputs,
+         *fw_operands_slice,
        *[first_slice_copy(o) for o in outs],
    ),
    ctx._fw_include_key_set,
@@ -740,12 +750,14 @@ def backward(ctx, *g_ys):
mapped_combine_fn_bw_gm = torch.vmap(combine_fn_bw_gm, 0, 0)

# 4.) Compute the instantaneous gradients at every step ``t``
- # Use a ones_like tensor in order not to scale the g_y_t and g_x_t
+ # Use a ones_like tensor in order not to scale the g_y_t and g_x_t
+ # with the upstream gradients yet.
+ # Note: All g_x_t and g_y_t are computed in parallel; the results are g_x and g_y.
dummy_upstream_grad = (torch.ones_like(x) for x in g_ys)
grads = mapped_combine_fn_bw_gm(
-     *(o.roll(1, dim) for o in outs), *xs, *dummy_upstream_grad
+     *(o.roll(1, dim) for o in outs), *fw_operands, *dummy_upstream_grad
)
- g_y_t, g_x_t = split_into_chunks(grads, [num_xs, num_xs])
+ g_y, g_x = split_into_chunks(grads, [num_xs, num_xs])

def compute_grad_y_mat(g_y: torch.Tensor) -> torch.Tensor:
    # Prepare a ones and a zeros helper mask in order to easily compute the y_mat
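
As a standalone illustration of how such helper masks and a row-wise cumulative product can produce the y_mat for the running example (a sketch only; the exact masking in ``compute_grad_y_mat`` may differ)::

    import torch

    g_y = torch.tensor([1., 2., 3., 4.])
    T = g_y.numel()

    # Repeat g_y along the rows (every row equals g_y), cf. substep 5.1.
    y_repeated = g_y.expand(T, T)
    # Ones mask: set the lower triangle (incl. diagonal) to 1 so that the
    # row-wise cumulative product starts at 1 on the diagonal.
    lower = torch.tril(torch.ones(T, T)).bool()
    y_masked = torch.where(lower, torch.ones(T, T), y_repeated)
    # Cumulative product along the rows, then apply a zeros mask to clear
    # the strictly lower triangle.
    y_mat = torch.cumprod(y_masked, dim=1) * torch.triu(torch.ones(T, T))
    print(y_mat)
    # tensor([[ 1.,  2.,  6., 24.],
    #         [ 0.,  1.,  3., 12.],
    #         [ 0.,  0.,  1.,  4.],
    #         [ 0.,  0.,  0.,  1.]])
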
@@ -804,17 +816,19 @@ def compute_grad(g_x, g_y, g_ys):

    return g_xs

- # Stack all elements of the gradients along the first dimension.
- # This is useful as later the gradients of those elements can be computed in parallel.
- g_x_stacked = torch.stack(g_x_t)
- g_y_stacked = torch.stack(g_y_t)
- g_ys_stacked = torch.stack(g_ys)
+ # Stack all leaves of the gradients along the first dimension.
+ # This is useful as later the gradients of those leaves can be computed in parallel.
+ g_x_stacked_leaves = torch.stack(g_x)
+ g_y_stacked_leaves = torch.stack(g_y)
+ g_ys_stacked_leaves = torch.stack(g_ys)

- # The compute_grad function is parallelized across all individual elements of xs
+ # The compute_grad function is parallelized across all individual leaves of xs
# as these gradients can be computed independently from each other
compute_grad_mapped = torch.vmap(compute_grad, 0, 0)

- g_xs = compute_grad_mapped(g_x_stacked, g_y_stacked, g_ys_stacked)
+ g_xs = compute_grad_mapped(
+     g_x_stacked_leaves, g_y_stacked_leaves, g_ys_stacked_leaves
+ )

# TODO: Currently the gradients for the additional_inputs are not computed properly
return *[None] * 3, *g_xs, *[None] * num_additional_inputs
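
As an end-to-end sanity check of the scheme described above on the running example (illustrative only; ``torch.cumprod`` plays the role of the scan with the multiplicative ``combine_fn``)::

    import torch

    xs = torch.arange(1., 5., requires_grad=True)   # [1., 2., 3., 4.]
    ys = torch.cumprod(xs, 0)                       # forward scan: [1., 2., 6., 24.]
    g_ys = torch.ones_like(ys)                      # upstream gradients, all 1
    (g_xs,) = torch.autograd.grad(ys, xs, g_ys)
    print(g_xs)                                     # tensor([33., 16., 10.,  6.])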