[associative_scan] Autograd separated #139939
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139939
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
❌ 4 New Failures: As of commit a565834 with merge base dd7d231, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Any review on this?
Thanks for your implementation! I have a question regarding the shape check
assert x.shape == shape, "All xs tensors must have the same shape"
Why does it require the tensors to have exactly the same shape? In the JAX implementation, only the scanned dimension (axis) is required to match:
num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
Hi @WeihanLikk Thank you for looking into this. I was slow in working on this over the holidays but will pick up steam again. You are right, I don't think that this is necessarily required. Just the scanned dimension needs to be identical for all xs.
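A hypothetical sketch of what such a relaxed check might look like (the helper name _check_scan_dim and its signature are made up for illustration; this is not code from the PR):

import torch

def _check_scan_dim(xs_leaves: list[torch.Tensor], dim: int) -> None:
    # Only the size along the scanned dimension has to agree across leaves,
    # mirroring the JAX check quoted above.
    scan_len = xs_leaves[0].shape[dim]
    if not all(leaf.shape[dim] == scan_len for leaf in xs_leaves[1:]):
        raise ValueError(
            f"All xs tensors must have size {scan_len} along the scanned dim {dim}"
        )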
y_T = f(y_{T-1}, x_T)

The gradients of y_T with respect to the vector x are computed as:
dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T
I'm not understanding this expression:
dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T
Is there some typo in here?
Well, I guess this should be the gradient of the element y_T with respect to the vector of inputs, i.e., with respect to every element x_1, x_2, ..., x_T of the vector. This would give individual elements like [dy_T / dx_1, dy_T / dx_2, ... dy_T / dx_T] and to get the final gradient these elements are summed.
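For a concrete illustration (using the cumprod example that appears later in this thread), the per-element partials of y_T can be checked directly with eager autograd:

import torch

xs = torch.arange(1.0, 5.0, requires_grad=True)  # [1., 2., 3., 4.]
ys = torch.cumprod(xs, dim=0)                    # [1., 2., 6., 24.]

# Gradient of the last output y_T with respect to every element x_t.
(g,) = torch.autograd.grad(ys[-1], xs)
print(g)  # tensor([24., 12.,  8.,  6.]) = [dy_T/dx_1, dy_T/dx_2, dy_T/dx_3, dy_T/dx_4]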
@garrett361 I've implemented a first version of the backward approach we discussed offline. The algorithm per se works, but there are still some things to sort out, in particular lifted arguments and partial gradient support. EDIT: Of course there is still further room to clean up the code and to adjust the documentation.
What does lifted argument mean? Partial gradient = when only some inputs require gradients?
In some cases, variables and other properties from the outer scope are captured by the combine_fn and are lifted into the graph as additional arguments.
Yes, that is correct. Same applies here as well. I know how to handle it, but I wanted to wait for the autograd rework.
Currently using associative scan for a big research project related to linear attention (or perhaps I should say, logarithmic attention 😉). Is there an expected timeline for autograd support to be available? Really excited about this PR!!
import torch
import torch._prims_common as utils
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph
I think that whatever PR lands first, either this one or #146285, should move these functions from torch._higher_order_ops.cond to torch._higher_order_ops.utils?
We then proceed to compute the gradients for the xs (g_xs) by computing for every step:
the instantaneous gradient g_x_t, as well as
the gradient for the previous intermediate result g_h_t, using
g_h_t, g_x_t = ctx._combine_fn_bw(ys_(t-1), xs_t)
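One hypothetical way to read this recipe sequentially (not the PR's vectorized implementation; the numbers come from the cumprod example used below) is a backward loop in which the upstream gradients g_ys enter separately at each step:

# Instantaneous gradients for combine_fn(a, b) = a * b on xs = [1, 2, 3, 4]:
g_h = [1.0, 2.0, 3.0, 4.0]   # d combine_fn / d ys_(t-1) at step t (step 0: 1)
g_x = [1.0, 1.0, 2.0, 6.0]   # d combine_fn / d xs_t at step t (step 0: 1)
g_ys = [1.0, 1.0, 1.0, 1.0]  # upstream gradients

T = len(g_h)
g_xs = [0.0] * T
acc = 0.0                    # gradient flowing into ys_t from all later steps
for t in reversed(range(T)):
    acc += g_ys[t]           # upstream gradient entering at step t
    g_xs[t] = acc * g_x[t]   # chain rule into xs_t
    acc *= g_h[t]            # chain rule into ys_(t-1)

print(g_xs)                  # [33.0, 16.0, 10.0, 6.0]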
I get lost at why _combine_fn_bw doesn't need the gradient input g_ys_t?
Is g_h_t just ys_t? We should also define what g_h_t is in the unpacked forward.
g_h_t, g_x_t = ctx._combine_fn_bw(ys_(t-1), xs_t)

For example:
g_h_0, g_x_0 = [1, 1]
I'm confused how we get these two. I thought we have xs, ys (saved by forward), and g_ys (backward inputs)?
We can then compute the 'gradient transition matrix' h_mat using the instantaneous gradient components g_h_t as
h_mat = [[1, g_h_1, g_h_2 . g_h_1, g_h_3 . g_h_2 . g_h_1],
         [0, 1    , g_h_2        , g_h_3 . g_h_2        ],
         [0, 0    , 1            , g_h_3                ],
         [0, 0    , 0            , 1                    ]],
which for the example above results in:
h_mat = [[1, 2, 6, 24],
         [0, 1, 3, 12],
         [0, 0, 1, 4],
         [0, 0, 0, 1]]

We then scale the h_mat with the upstream gradient g_ys

scaled_h_mat = h_mat * g_ys

Assuming all 1s for the upstream gradients this would result in:
scaled_h_mat = [[1, 2, 6, 24],
                [0, 1, 3, 12],
                [0, 0, 1, 4],
                [0, 0, 0, 1]]

Sum the scaled h_mat row-wise
summed_h_mat = scaled_h_mat.sum(1)  # Row-wise summation
which would be
summed_h_mat = [33, 16, 5, 1]

and multiply with the instantaneous gradients g_x.
g_xs = summed_h_mat * g_x

g_xs = [33, 16, 5, 1] * [1, 1, 2, 6]
g_xs = [33, 16, 10, 6]
With this procedure we end up with the gradients for the xs -> g_xs.
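The whole walkthrough above can be reproduced with a few lines of dense-matrix code. The following is only an illustrative sketch of the grid form for the cumprod example (it is not the implementation used in this PR) and cross-checks the result against eager autograd:

import torch

# Cumprod example from the walkthrough: xs = [1, 2, 3, 4], ys = [1, 2, 6, 24].
xs = torch.arange(1.0, 5.0)
ys = torch.cumprod(xs, dim=0)
g_ys = torch.ones_like(ys)                # upstream gradients, all ones

# Instantaneous gradients of combine_fn(a, b) = a * b at (ys_(t-1), xs_t):
# g_h_t = d(a*b)/da = xs_t and g_x_t = d(a*b)/db = ys_(t-1); step 0 uses [1, 1].
g_h = torch.cat([torch.ones(1), xs[1:]])  # [1, 2, 3, 4]
g_x = torch.cat([torch.ones(1), ys[:-1]]) # [1, 1, 2, 6]

# Gradient transition matrix: h_mat[i, j] = g_h[i+1] * ... * g_h[j] for j >= i.
T = xs.numel()
h_mat = torch.zeros(T, T)
for i in range(T):
    for j in range(i, T):
        h_mat[i, j] = torch.prod(g_h[i + 1 : j + 1])

# Scale with the upstream gradients, sum row-wise, multiply with g_x.
g_xs = (h_mat * g_ys).sum(1) * g_x        # tensor([33., 16., 10., 6.])

# Cross-check against eager autograd on the same cumprod.
xs_ref = xs.clone().requires_grad_(True)
(ref,) = torch.autograd.grad(torch.cumprod(xs_ref, dim=0).sum(), xs_ref)
assert torch.allclose(g_xs, ref)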
To me this only explains how but doesn't explain why, it's not obvious to me what we're doing here. Can you explain the overall idea behind why this works?
@ydwu4 I merged the latest main and also reworked the documentation quite a bit. In particular, I now start by explaining the naive gradient implementation with all its steps, and from there describe the grid form and detail its steps. Can you take another look to see whether it makes more sense to you now?
@parametrize("combine_mode", ["generic"]) | ||
@parametrize("device", [torch.device("cpu")]) | ||
@parametrize("autograd", [True]) | ||
def test_associative_scan_compile_fail( |
I separated off these two specific cases which lead to CPP compilation failures, which I don't think are necessarily related to the associative_scan, as they appear only here. I marked those tests as expected fail and we should address them in a follow-up PR. For reference, the issue we observe with those tests is:
ERROR: test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True (__main__.AssociativeScanTests.test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3154, in wrapper
method(*args, **kwargs)
File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 552, in instantiated_test
test(self, **param_kwargs)
File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3594, in test_associative_scan_compile_fail
result = self._run_test(
^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3456, in _run_test
self._check_autograd(result, result_exp, autograd_param)
File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3444, in _check_autograd
grads = torch.autograd.grad(result_flatten, grad_param, grad_init)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 503, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2179, in backward
return impl_fn()
^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2165, in impl_fn
out = CompiledFunction._backward_impl(ctx, all_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2257, in _backward_impl
CompiledFunction.compiled_bw = aot_config.bw_compiler(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_functorch/aot_autograd.py", line 483, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/backends/common.py", line 73, in _wrapped_bw_compiler
disable(
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 2234, in bw_compiler
return inner_compile(
^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 710, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 880, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 864, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1487, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1374, in codegen_and_compile
compiled_module = graph.compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2238, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2248, in _compile_to_module
mod = self._compile_to_module_lines(wrapper_code)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2312, in _compile_to_module_lines
mod = PyCodeCache.load_by_key_path(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3022, in load_by_key_path
mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_boh/cd/ccdmedyasojek2hupoamg66aetwrqjyasivhvip62n63grppshs5.py", line 416, in <module>
async_compile.wait(globals())
File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 481, in wait
self._wait_futures(scope)
File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 501, in _wait_futures
kernel = result.result()
^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3524, in result
return self.result_fn()
^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2505, in future
result = get_result()
^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2313, in load_fn
future.result()
File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 449, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 864, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1487, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1374, in codegen_and_compile
compiled_module = graph.compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2238, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2248, in _compile_to_module
mod = self._compile_to_module_lines(wrapper_code)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2312, in _compile_to_module_lines
mod = PyCodeCache.load_by_key_path(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3022, in load_by_key_path
mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_boh/cd/ccdmedyasojek2hupoamg66aetwrqjyasivhvip62n63grppshs5.py", line 416, in <module>
async_compile.wait(globals())
File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 481, in wait
self._wait_futures(scope)
File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 501, in _wait_futures
kernel = result.result()
^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3524, in result
return self.result_fn()
^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2505, in future
result = get_result()
^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2313, in load_fn
future.result()
File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2342, in _worker_compile_cpp
cpp_builder.build()
File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 1687, in build
run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 358, in run_compile_cmd
_run_compile_cmd(cmd_line, cwd)
File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 353, in _run_compile_cmd
raise exc.CppCompileError(cmd, output) from e
torch._inductor.exc.InductorError: CppCompileError: C++ compile error
Command:
g++ /tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_AVX2 -shared -fPIC -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fexcess-precision=fast -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -fno-tree-loop-vectorize -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/data_malta3_ssd/miniforge3/envs/pt23/include/python3.11 -I/data_malta3_ssd/pytorch/torch/include -I/data_malta3_ssd/pytorch/torch/include/torch/csrc/api/include -mavx2 -mfma -mf16c -D_GLIBCXX_USE_CXX11_ABI=1 -o /tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.so -ltorch -ltorch_cpu -ltorch_python -lgomp -L/opt/ssd/miniforge3/envs/pt23/lib -L/data_malta3_ssd/pytorch/torch/lib
Output:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:52:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
52 | auto tmp19 = float(tmp7 + tmp18);
| ^~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:96:43: warning: self-comparison always evaluates to false [-Wtautological-compare]
96 | auto tmp34 = tmp3 < tmp3;
| ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:122:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
122 | auto tmp43 = float(tmp33 + tmp42);
| ^~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:166:43: warning: self-comparison always evaluates to false [-Wtautological-compare]
166 | auto tmp58 = tmp3 < tmp3;
| ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:192:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
192 | auto tmp67 = float(tmp57 + tmp66);
| ^~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:254:47: warning: self-comparison always evaluates to false [-Wtautological-compare]
254 | auto tmp25 = tmp2 < tmp2;
| ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:280:47: warning: self-comparison always evaluates to false [-Wtautological-compare]
280 | auto tmp43 = tmp2 < tmp2;
| ~~~~ ^ ~~~~
To execute this test, run the following from the base repo dir:
python test/functorch/test_control_flow.py AssociativeScanTests.test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
xs = torch.arange(1, 5) = [1, 2, 3, 4]
ys = torch.cumprod(xs) = [1, 2, 6, 24]

def combine_fn(a: torch.Tensor, b: torch.Tensor):
There is a jump in logic between ys = torch.cumprod(xs) = [1, 2, 6, 24] and presenting the combine_fn and combine_fn_bw: I assume you want to write a simple associative_scan that does the same thing as torch.cumprod(xs), but I didn't see the connection explicitly written down.
One idea: maybe put the following here:
The forward output of associative_scan is computed as:
ys = associative_scan(combine_fn, xs).
For example, this computation can be unpacked as:
ys_0 = xs_0
ys_1 = combine_fn(ys_0, xs_1)
...
ys_T = combine_fn(ys_(T-1), xs_T)
I recommend doing a serious proofreading and making sure the statements are well organized and closely connected. Make sure the notations are well defined before they're used.
Right, I have locally fixed the docu for the cumprod example:
def combine_fn(a: torch.Tensor, b: torch.Tensor):
return a * b
ys = associative_scan(combine_fn, xs)
ys = [1, 2, 6, 24]
The function combine_fn_bw returns the gradients of a and b from combine_fn. It can be computed as:
def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
return g_y * b, g_y * a
The first output of combine_fn_bw is the instantaneous gradient for the previous output g_y_t
and the second output of combine_fn_bw is the instantaneous gradient for the input g_x_t.
Note: In a real use case of associative_scan, there may be additional_inputs that participate in the
forward as well as in the backward of the scan operator. For the sake of readability those inputs
have been omitted in the following example, but are included in the subsequent detailed description below.
For the example above:
ys = associative_scan(combine_fn, xs),
the computation can be unpacked as:
ys_0 = xs_0
ys_1 = combine_fn(ys_0, xs_1)
...
ys_T = combine_fn(ys_(T-1), xs_T)
I will also do a thorough check that the notation in the docu matches the code. But I already tried to do that and at a first glance I couldn't really spot anything obvious. Do you have some specific concerns? Otherwise, I will check it again.
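For what it's worth, the relationship between combine_fn and combine_fn_bw described above can be sanity-checked per step against eager autograd (an illustrative sketch; the concrete values of a, b and g_y are arbitrary):

import torch

def combine_fn(a: torch.Tensor, b: torch.Tensor):
    return a * b

def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
    return g_y * b, g_y * a

# One step of the scan: a plays the role of ys_(t-1), b of xs_t, g_y of g_ys_t.
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
g_y = torch.tensor(5.0)

y = combine_fn(a, b)
g_a_ref, g_b_ref = torch.autograd.grad(y, (a, b), g_y)
g_a, g_b = combine_fn_bw(a, b, g_y)
assert torch.allclose(g_a, g_a_ref) and torch.allclose(g_b, g_b_ref)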
@ydwu4 I gave it another go and updated the documentation, trying to make the examples clear and also trying to double check that the variable names of the code match the docu. WDYT?
This PR implements the Autograd feature of the associative_scan.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @ydwu4