Rework of documentation · pytorch/pytorch@a565834 · GitHub

Commit a565834

Rework of documentation

1 parent 3707c0d commit a565834

File tree

1 file changed (+77, -63)

torch/_higher_order_ops/associative_scan.py

Lines changed: 77 additions & 63 deletions
@@ -454,40 +454,42 @@ class AssociativeScanAutogradOp(torch.autograd.Function):
     r"""
     Example::
         xs = torch.arange(1, 5) = [1, 2, 3, 4]
-        ys = torch.cumprod(xs) = [1, 2, 6, 24]

         def combine_fn(a: torch.Tensor, b: torch.Tensor):
             return a * b

-    The ``combine_fn_bw``, computing the gradients for a and b of ``combine_fn`` is computed as:
-        def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
-            return g_y * b, g_y * a
-
-    The first output of ``combine_fn_bw`` is the instantaneous gradient for the previous output g_y_t
-    and the second output of ``combine_fn_bw`` is the instantaneous gradient for the input g_x_t.
+        ys = associative_scan(combine_fn, xs),
+    which can be unpacked as:
+        ys_0 = xs_0 = 1
+        ys_1 = combine_fn(ys_0, xs_1) = combine_fn(1, 2) = 2
+        ...
+        ys_T = combine_fn(ys_(T-1), xs_T) = combine_fn(6, 4) = 24
+        ys = [1, 2, 6, 24]

-    Note: In a real usecase of associative_scan, there may be additional_inputs that participate in the
-    forward as well as in the backward of the scan operator. For the sake of readability those inputs
-    have been omitted in the following example, but are included in the subsequent detailed description below.
+    The function ``combine_fn_bw`` returns the gradients of a and b from ``combine_fn``.
+    For the ``combine_fn`` above, this results in:
+        def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_ys_t: torch.Tensor):
+            return g_ys_t * b, g_ys_t * a

-    The forward output of associative_scan is computed as:
-        ys = associative_scan(combine_fn, xs).
+    where g_ys are the upstream gradients in torch.autograd.Function.
+    In particular, g_ys is the vector of all intermediate upstream gradients
+    of the outputs [g_ys_0, g_ys_1, ..., g_ys_T].

-    For example, this computation can be unpacked as:
-        ys_0 = xs_0
-        ys_1 = combine_fn(ys_0, xs_1)
-        ...
-        ys_T = combine_fn(ys_(T-1), xs_T)
+    The first output of ``combine_fn_bw`` is the gradient g_y_t,
+    which is the gradient for the forward output at step t-1, i.e., a.
+    The second output of ``combine_fn_bw`` is the gradient g_x_t,
+    which is the gradient for the forward input at step t, i.e., b.

-    Note: In a real usecase of associative_scan this operation is parallelized from O(T) to O(log(T)).
+    Note: In a real usecase of ``associative_scan``, there may be additional_inputs that participate in the
+    forward as well as in the backward of the scan operator. For the sake of readability those inputs
+    have been omitted in the following example, but are included in the subsequent detailed description below.

-    Given the ys, the gradients for xs can be computed as follows:
-    We receive the upstream gradients in torch.autograd.Function, i.e., we get g_ys,
-    where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T]
+    Note: In a real usecase of ``associative_scan`` the operation is parallelized from O(T) to O(log(T)).

-    We can then utilize the ``combine_fn_bw`` to compute the instantaneous gradients g_x_t and g_y_t
+    Given the outputs ys, the gradients for xs can be computed as follows:
+    We first utilize the ``combine_fn_bw`` to compute the instantaneous gradients g_x_t and g_y_t
     at every step as:
-        g_y_t, g_x_t = combine_fn_bw(ys_(t-1), xs_t, 1.) ,
+        g_y_t, g_x_t = combine_fn_bw(ys_(t-1), xs_t, 1.),
     where instead of using the elements of g_ys_t, we use 1s. This is required to get the instantaneous
     gradients at every step t and we incorporate the upstream gradients g_ys at a later time.

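As a quick sanity check of the unpacked forward recursion above, here is a plain-Python sketch that loops over the steps instead of calling the actual associative_scan operator (illustrative only, not part of the diff):

    import torch

    def combine_fn(a, b):
        return a * b

    xs = torch.arange(1, 5)                       # [1, 2, 3, 4]
    ys = [xs[0]]                                  # ys_0 = xs_0
    for t in range(1, len(xs)):
        ys.append(combine_fn(ys[-1], xs[t]))      # ys_t = combine_fn(ys_(t-1), xs_t)
    ys = torch.stack(ys)
    print(ys)                                     # tensor([ 1,  2,  6, 24]), same as torch.cumprod(xs, dim=0)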
@@ -501,32 +503,33 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
         g_y = [1, 2, 3, 4]
         g_x = [1, 1, 2, 6]

-    With these instantaneous gradients, one can compute the gradients of the inputs xs (g_xs) naively as:
+    With these instantaneous gradients, we can compute the gradients of the inputs xs (g_xs) naively as:
         g_xs_t = (\sum_{i=T}^t g_ys_i . (\prod_{k=i}^{k>t} g_y_k)) . g_x_t (1)

     In particular,
-        g_xs_T = g_ys_T . g_x_T
+        g_xs_T = g_ys_T . 1 . g_x_T
         g_xs_(T-1) = g_ys_T . g_y_T . g_x_(T-1) + g_ys_(T-1) . g_x_(T-1)
         g_xs_(T-2) = g_ys_T . g_y_T . g_y_(T-1) . g_x_(T-2) + g_ys_(T-1) . g_y_(T-1) . g_x_(T-2) + g_ys_(T-2) . g_x_(T-2)
         ...

-    Which for the example above results in the final input gradients:
-        g_xs_3 = 6
-        g_xs_2 = 10
-        g_xs_1 = 16
-        g_xs_0 = 33
-
-    This recursive way of computing may not be the most efficient one and an alternative approach would be
-    to rewrite the recursion with the help of cumulative products and matrix multiplications.
-    In particular, when looking at equation (1) above, one can observe three key aspects:
-    1.) The number of terms (products) in the sum is increasing by one as one progresses to earlier steps i.
-    2.) The first products for step j<i are the same, but there is one additional product added
-    which contains one element g_y_j, added to the end of the product.
-    3.) For the input gradient g_xs_i, the instantaneous input g_x_i is multiplied at the end
-
-    These three observations can be exploited to formulate a ``grid form`` to compute the gradients more
-    efficiently. See ``grid form`` outlined in https://justintchiu.com/blog/pscan_diff/.
-    In particular, the sum and product in (1) can be rewritten
+    Which for the example above results in the final input gradients (assuming g_ys_0 = g_ys_1 = ... = g_ys_T = 1):
+        g_xs_3 = 1 . 1 . 6 = 6                                                  (2)
+        g_xs_2 = 1 . 4 . 2 + 1 . 2 = 10                                         (3)
+        g_xs_1 = 1 . 4 . 3 . 1 + 1 . 3 . 1 + 1 . 1 = 16                         (4)
+        g_xs_0 = 1 . 4 . 3 . 2 . 1 + 1 . 3 . 2 . 1 + 1 . 2 . 1 + 1 . 1 = 33     (5)
+
+    This "recursive" way of computing the gradients outlined above may not be the most efficient.
+    An alternative approach would be to rewrite the sum and product of equation (1)
+    with the help of cumulative products and matrix multiplications.
+    In particular, when looking at equation (1), one can observe three key aspects:
+    1.) The number of terms in the sum increases by one as one progresses from t back to an earlier step t-1.
+    2.) The first products for step t are the same as for step t-1,
+    but there is one additional term added to the end of the product.
+    3.) For the input gradient g_xs_t, the instantaneous input g_x_t is multiplied at the end.
+
+    These three observations can be exploited to formulate a "grid form" to compute the gradients more
+    efficiently. See also the "grid form" outlined in https://justintchiu.com/blog/pscan_diff/.
+    In particular, the sum and product in equation (1) can be rewritten
     using a matrix form, combined with a cumulative product on the rows. The resulting vector can then be
     elementwise multiplied with the instantaneous input gradients to obtain the final input gradients g_xs.

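The numbers in (2)-(5) can be cross-checked directly with autograd on the cumprod example, assuming unit upstream gradients; this sketch is independent of the scan implementation in this file:

    import torch

    xs = torch.arange(1., 5., requires_grad=True)   # [1., 2., 3., 4.]
    ys = torch.cumprod(xs, dim=0)                   # [1., 2., 6., 24.]
    ys.backward(torch.ones_like(ys))                # g_ys_0 = ... = g_ys_T = 1
    print(xs.grad)                                  # tensor([33., 16., 10.,  6.])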
@@ -542,7 +545,12 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
                 [0, 0, 1, 4],
                 [0, 0, 0, 1]]

-    Note that these are precisely the terms of the product in (1).
+    Note: The rows of the y_mat correspond exactly to the terms of the product in equation (1).
+    For example, the last row contains the coefficients found in g_xs_3 (see (2)),
+    the second to last row contains the coefficients found in g_xs_2 (see (3)).
+    ...
+    The first row contains the coefficients found in g_xs_0 (see (5)).
+    Note: The order of y_mat is reversed compared to (2)-(5), in order to avoid unnecessary flip operations.

     We then scale the y_mat with the upstream gradient g_ys

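For the running example, the grid form can be sketched explicitly as follows; g_y, g_x, and g_ys are taken from the docstring values above, and the nested loop that fills y_mat is illustrative only, not how the implementation builds the matrix:

    import torch

    g_y = torch.tensor([1., 2., 3., 4.])    # instantaneous output gradients
    g_x = torch.tensor([1., 1., 2., 6.])    # instantaneous input gradients
    g_ys = torch.ones(4)                    # upstream gradients
    T = g_y.numel()

    # Upper-triangular grid of cumulative products of g_y:
    # y_mat[r, c] = g_y[r+1] * ... * g_y[c] for c > r, 1 on the diagonal, 0 below it.
    y_mat = torch.ones(T, T)
    for r in range(T):
        for c in range(r + 1, T):
            y_mat[r, c] = y_mat[r, c - 1] * g_y[c]
    y_mat = torch.triu(y_mat)

    g_xs = (y_mat @ g_ys) * g_x
    print(g_xs)                             # tensor([33., 16., 10.,  6.]), i.e. the values in (5), (4), (3), (2)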
@@ -571,7 +579,7 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):

     The forward of associative_scan can be computed with the following steps:
     1.) Compute the forward output of the associative_scan
-        ys = associative_scan_op(combine_fn, xs, additional_inputs)
+        ys = associative_scan(combine_fn, xs, additional_inputs)

     The backward of scan can be computed as:
     2.) Prepare the backward graph
@@ -601,7 +609,8 @@ def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
                 [0, 1    , g_y_2 , g_y_3 . g_y_2 ],
                 [0, 0    , 1     , g_y_3         ],
                 [0, 0    , 0     , 1             ]],
-    For better readability, however, we split the calculation into several substeps. We start by
+    For better readability, however, we split the calculation into several substeps and utilize masks for 1s and 0s.
+    We start with:

     5.1 Repeat the elements of g_y to form the square matrix
         y_mat = [[1, g_y_1, g_y_2, g_y_3],
@@ -713,20 +722,21 @@ def backward(ctx, *g_ys):
         # but we need it here in order to compute the correcte gradients
         xs_slices = first_slice_copy_with_grad(itertools.chain(xs, xs))

+        # Construct the operands from the forward, fw_operands
+        # and the operands for a single event t of the forward, fw_operands_slice
+        fw_operands = (*xs, *additional_inputs)
+        fw_operands_slice = (*xs_slices, *additional_inputs)
+
         # 2.) Prepare the backward graph
-        ctx._combine_fn_bw = create_bw_fn(
-            ctx._combine_fn,
-            (*xs_slices, *additional_inputs),
-        )
+        combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands_slice)

-        # 3.) Materialize the ``ctx._combine_fn_bw``
+        # 3.) Materialize the ``combine_fn_bw``
         # TODO: we need to materialize the bw graphs because dynamo is unable to
         # trace through the joint function when torch.compile torch.autograd.grad.
         combine_fn_bw_gm = materialize_as_graph(
-            ctx._combine_fn_bw,
+            combine_fn_bw,
             (
-                *xs_slices,
-                *additional_inputs,
+                *fw_operands_slice,
                 *[first_slice_copy(o) for o in outs],
             ),
             ctx._fw_include_key_set,
@@ -740,12 +750,14 @@ def backward(ctx, *g_ys):
         mapped_combine_fn_bw_gm = torch.vmap(combine_fn_bw_gm, 0, 0)

         # 4.) Compute the instantaneous gradients at every step ``t``
-        # Use a ones_like tensor in order not to scale the g_y_t and g_x_t
+        # Use a ones_like tensor so that g_y_t and g_x_t are not yet scaled
+        # with the upstream gradients.
+        # Note: All g_x_t and g_y_t are computed in parallel, thus the results are g_x and g_y.
         dummy_upstream_grad = (torch.ones_like(x) for x in g_ys)
         grads = mapped_combine_fn_bw_gm(
-            *(o.roll(1, dim) for o in outs), *xs, *dummy_upstream_grad
+            *(o.roll(1, dim) for o in outs), *fw_operands, *dummy_upstream_grad
         )
-        g_y_t, g_x_t = split_into_chunks(grads, [num_xs, num_xs])
+        g_y, g_x = split_into_chunks(grads, [num_xs, num_xs])

         def compute_grad_y_mat(g_y: torch.Tensor) -> torch.Tensor:
             # Prepare a ones and a zeros helper mask in order to easily compute the y_mat
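To make step 4.) concrete for the running example, here is a small numeric sketch that applies a plain-Python combine_fn_bw with torch.vmap and dummy upstream gradients of 1; the actual code vmaps the materialized graph combine_fn_bw_gm over the rolled full outputs, whereas this sketch only covers steps t >= 1:

    import torch

    def combine_fn_bw(a, b, g):
        # backward of combine_fn(a, b) = a * b for an upstream gradient g
        return g * b, g * a

    ys_prev = torch.tensor([1., 2., 6.])    # ys_(t-1) for t = 1..3
    xs_t = torch.tensor([2., 3., 4.])       # xs_t     for t = 1..3
    ones = torch.ones_like(xs_t)            # dummy upstream gradients

    g_y_t, g_x_t = torch.vmap(combine_fn_bw, 0, 0)(ys_prev, xs_t, ones)
    # g_y_t = [2., 3., 4.] and g_x_t = [1., 2., 6.],
    # matching the tails of g_y = [1, 2, 3, 4] and g_x = [1, 1, 2, 6] in the docstring.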
@@ -804,17 +816,19 @@ def compute_grad(g_x, g_y, g_ys):

             return g_xs

-        # Stack all elements of the gradients along the first dimension.
-        # This is useful as later the gradients of those elements can be computed in parallel.
-        g_x_stacked = torch.stack(g_x_t)
-        g_y_stacked = torch.stack(g_y_t)
-        g_ys_stacked = torch.stack(g_ys)
+        # Stack all leaves of the gradients along the first dimension.
+        # This is useful as later the gradients of those leaves can be computed in parallel.
+        g_x_stacked_leaves = torch.stack(g_x)
+        g_y_stacked_leaves = torch.stack(g_y)
+        g_ys_stacked_leaves = torch.stack(g_ys)

-        # The compute_grad function is parallelized across all individual elements of xs
+        # The compute_grad function is parallelized across all individual leaves of xs
         # as these gradients can be computed independently from each other
         compute_grad_mapped = torch.vmap(compute_grad, 0, 0)

-        g_xs = compute_grad_mapped(g_x_stacked, g_y_stacked, g_ys_stacked)
+        g_xs = compute_grad_mapped(
+            g_x_stacked_leaves, g_y_stacked_leaves, g_ys_stacked_leaves
+        )

         # TODO: Currently the gradients for the additional_inputs are not computed properly
         return *[None] * 3, *g_xs, *[None] * num_additional_inputs
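As a rough illustration of the stack-then-vmap pattern in the last step, here is a sketch with made-up leaf tensors and a hypothetical per_leaf_grad function standing in for the real compute_grad:

    import torch

    def per_leaf_grad(g_x, g_y, g_ys):
        # placeholder for the real compute_grad; it simply scales the
        # instantaneous input gradients with the upstream gradients
        return g_x * g_ys

    # two hypothetical leaves of xs, each with T = 4 scan steps
    g_x_leaves = [torch.tensor([1., 1., 2., 6.]), torch.tensor([1., 1., 2., 6.])]
    g_y_leaves = [torch.tensor([1., 2., 3., 4.]), torch.tensor([1., 2., 3., 4.])]
    g_ys_leaves = [torch.ones(4), torch.ones(4)]

    # stacking the leaves lets torch.vmap compute every leaf's gradient in one batched call
    g_xs = torch.vmap(per_leaf_grad, 0, 0)(
        torch.stack(g_x_leaves), torch.stack(g_y_leaves), torch.stack(g_ys_leaves)
    )
    print(g_xs.shape)   # torch.Size([2, 4])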

0 commit comments
