[associative_scan] Autograd separated by bohnstingl · Pull Request #139939 · pytorch/pytorch · GitHub

[associative_scan] Autograd separated #139939


Open
bohnstingl wants to merge 30 commits into main
Conversation

@bohnstingl (Collaborator) commented Nov 6, 2024

@bohnstingl bohnstingl requested a review from zou3519 as a code owner November 6, 2024 23:39
pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139939

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures

As of commit a565834 with merge base dd7d231:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl (Collaborator, Author) commented:

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Nov 6, 2024
@bdhirsh bdhirsh added the triaged label Nov 8, 2024
@zou3519 zou3519 requested review from ydwu4 and removed request for zou3519 November 11, 2024 17:59
@bohnstingl bohnstingl changed the title Improvements for associative_scan - Autograd separated [associative_scan] Autograd separated Nov 19, 2024
@bhack (Contributor) commented Dec 14, 2024

Any review on this?

@WeihanLikk commented:
Thanks for your implementation! I have a question regarding the shape check for xs:

assert x.shape == shape, "All xs tensors must have the same shape"

Why does it require the tensors to have exactly the same shape? In the JAX implementation, only the first dimension is required to match:

num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):

@bohnstingl (Collaborator, Author) commented:

Hi @WeihanLikk

Thank you for looking into this. I was slow working on this over the holidays but will pick up steam again. You are right, I don't think this is strictly required; only the scanned dimension needs to be identical across all xs. I will take a look.
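For concreteness, a minimal sketch of the relaxed check being discussed (editorial illustration; the names and the final check in the PR may differ): only the scanned dimension has to agree across the xs leaves.

import torch

# Hypothetical pytree of scan inputs: leaves agree only on the scanned dim.
xs = {
    "a": torch.randn(16, 4),
    "b": torch.randn(16, 7, 3),
}
dim = 0  # the scanned dimension

# Relaxed check, analogous to the JAX snippet quoted above:
leaves = list(xs.values())
num_elems = leaves[0].shape[dim]
assert all(leaf.shape[dim] == num_elems for leaf in leaves[1:]), (
    "All xs tensors must have the same size along the scanned dimension"
)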

y_T = f(y_{T-1}, x_T)

The gradients of y_T with respect to the vector x are computed as:
dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T


I'm not understanding this expression:

dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T

Is there some typo in here?

@bohnstingl (Collaborator, Author) replied:

Well, I guess this should be the gradient of the element y_T with respect to the vector of inputs, i.e., with respect to every element x_1, x_2, ..., x_T of the vector. This gives the individual elements [dy_T / dx_1, dy_T / dx_2, ..., dy_T / dx_T], and to get the final gradient these elements are summed.
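To make the per-element partials concrete, here is a small autograd check for the cumprod case (editorial sketch, not from the PR); it prints the vector [dy_T/dx_1, ..., dy_T/dx_T]:

import torch

xs = torch.arange(1.0, 5.0).requires_grad_()  # [1., 2., 3., 4.]
ys = torch.cumprod(xs, dim=0)                 # y_T = ys[-1] = 24
(g,) = torch.autograd.grad(ys[-1], xs)
print(g)  # tensor([24., 12.,  8.,  6.]) = [dy_T/dx_1, ..., dy_T/dx_T]

If all x_t additionally depended on one shared parameter, the total derivative with respect to that parameter would indeed be the sum of these entries, which appears to be what the summation above refers to.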

@bohnstingl (Collaborator, Author) commented Mar 9, 2025

@garrett361 I've implemented a first version of the backward approach we discussed offline. The algorithm per se works, but there are still some things to sort out, in particular lifted-argument and partial-gradient support.

EDIT: Of course, there is still room to clean up the code and adjust the documentation.

@garrett361 commented:
lifted argument and partial gradient support

What does lifted argument mean?

Partial gradient = when only some inputs require_grad?

@bohnstingl (Collaborator, Author) replied:

What does lifted argument mean?

In some cases, variables and other properties from the combine_fn are lifted as additional inputs. For example, these could be external variables, or symbolic shapes of tensors. In the following case:

H = torch.rand(2, device=device)
def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
    return x * H + y * 2

H would become a lifted variable. I know how to handle those, but I think @ydwu4 is currently working on simplifying the autograd architecture for higher-order operators, such as associative_scan, which would simplify this handling. (A sketch of where such a free variable ends up when traced is below.)
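For illustration, a minimal sketch (assuming make_fx tracing; this is not the PR's pipeline) showing where the free variable ends up in a traced graph:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

H = torch.rand(2)

def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
    return x * H + y * 2

# H is not a formal parameter of fct_freevars1, so the tracer must capture
# it: make_fx embeds it as a tensor constant, while dynamo/HOP tracing lifts
# it into an explicit extra graph input that the backward must also handle.
gm = make_fx(fct_freevars1)(torch.rand(2), torch.rand(2))
print(gm.graph)  # look for _tensor_constant0 standing in for H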

Partial gradient = when only some inputs require_grad?

Yes, that is correct. The same applies here: I know how to handle it, but I wanted to wait for the autograd rework.

@windsornguyen commented:
Currently using associative scan for a big research project related to linear attention. (or perhaps I should say, logarithmic attention 😉)

Is there an expected timeline for autograd support to be available? Really excited about this PR!!


import torch
import torch._prims_common as utils
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph
@bohnstingl (Collaborator, Author) commented:

I think that whatever PR lands first, either this one or #146285, should move these functions from torch._higher_order_ops.cond to torch._higher_order_ops.utils?

@bohnstingl
Copy link
Collaborator Author

@ydwu4 Taking your feedback from #146285 into consideration, I further cleaned up this PR and also significantly expanded the documentation of the code to make it more readable. Please let me know what you think.

We then proceed to compute the gradients for the xs (g_xs) by computing, for every step:
the instantaneous gradient g_x_t, as well as
the gradient for the previous intermediate result g_h_t, using
g_h_t, g_x_t = ctx._combine_fn_bw(ys_(t-1), xs_t)
Contributor commented:

I get lost at why _combine_fn_bw doesn't need the gradient input g_ys_t?

Is g_h_t just ys_t? We should also define what g_h_t is in the unpacked forward.

g_h_t, g_x_t = ctx._combine_fn_bw(ys_(t-1), xs_t)

For example:
g_h_0, g_x_0 = [1, 1]
Contributor commented:

I'm confused how we get these two; I thought we have xs, ys (saved by the forward), and g_ys (backward inputs)?

Comment on lines 494 to 526
We can then compute the 'gradient transition matrix' h_mat using the instantaneous gradient components g_h_t as
h_mat = [[1, g_h_1, g_h_2 . g_h_1, g_h_3 . g_h_2 . g_h_1],
[0, 1 , g_h_2 , g_h_3 . g_h_2 ],
[0, 0 , 1 , g_h_3 ],
[0, 0 , 0 , 1 ]],
which for the example above results in:
h_mat = [[1, 2, 6, 24],
[0, 1, 3, 12],
[0, 0, 1, 4],
[0, 0, 0, 1]]

We then scale the h_mat with the upstream gradient g_ys

scaled_h_mat = h_mat * g_ys
Assuming all 1s for the upstream gradients this would result in:
scaled_h_mat = [[1, 2, 6, 24],
[0, 1, 3, 12],
[0, 0, 1, 4],
[0, 0, 0, 1]]

Sum the h_mat row-wise
summed_h_mat = scaled_h_mat.sum(1) # Row-wise summation
which would be
summed_h_mat = [33, 16, 5, 1]

and multiply with the instantaneous gradients g_x.
g_xs = summed_h_mat * g_x

g_xs = [33, 16, 5, 1] * [1, 1, 2, 6]
g_xs = [33, 16, 10, 6]
With this procedure we end up with the gradients for the xs -> g_xs.
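(Editorial aside: a runnable sketch of the grid-form computation above for the cumprod example, assuming upstream gradients of all ones; this reproduces g_xs = [33, 16, 10, 6] and checks it against autograd.)

import torch

xs = torch.arange(1.0, 5.0)                  # [1., 2., 3., 4.]
ys = torch.cumprod(xs, dim=0)                # [1., 2., 6., 24.]

# Instantaneous gradients of combine_fn(a, b) = a * b:
#   d(a*b)/da = b -> g_h_t = xs_t (g_h_0 is never used)
#   d(a*b)/db = a -> g_x_t = ys_(t-1), with g_x_0 = 1
g_h = xs
g_x = torch.cat([torch.ones(1), ys[:-1]])    # [1., 1., 2., 6.]

T = xs.numel()
h_mat = torch.ones(T, T)
for r in range(T):
    for c in range(r + 1, T):
        h_mat[r, c] = g_h[r + 1 : c + 1].prod()  # g_h_(r+1) * ... * g_h_c
h_mat = torch.triu(h_mat)                    # zeros below the diagonal

g_ys = torch.ones(T)                         # assumed upstream gradients
g_xs = (h_mat * g_ys).sum(1) * g_x           # row-wise sum, scaled by g_x
print(g_xs)                                  # tensor([33., 16., 10.,  6.])

# Cross-check against autograd on the equivalent cumprod:
xs_ref = xs.clone().requires_grad_()
(g_ref,) = torch.autograd.grad(torch.cumprod(xs_ref, dim=0).sum(), xs_ref)
assert torch.allclose(g_xs, g_ref)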

Contributor commented:

To me this only explains how, but not why; it's not obvious to me what we're doing here. Can you explain the overall idea behind why this works?

@bohnstingl (Collaborator, Author) commented:

@ydwu4 I merged the latest main and also reworked the documentation quite a bit. In particular, I now start by explaining the naive gradient implementation with all its steps, and from there describe the grid form and detail its steps.

Can you take another look to see whether it makes more sense to you now?

@bohnstingl bohnstingl requested a review from ydwu4 May 8, 2025 21:48
@parametrize("combine_mode", ["generic"])
@parametrize("device", [torch.device("cpu")])
@parametrize("autograd", [True])
def test_associative_scan_compile_fail(
@bohnstingl (Collaborator, Author) commented:

I separated out these two specific cases, which lead to C++ compilation failures that I don't think are necessarily related to associative_scan, as they appear only here. I marked those tests as expected-fail; we should address them in a follow-up PR. For reference, the issue we observe with those tests is:

ERROR: test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True (__main__.AssociativeScanTests.test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3154, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 552, in instantiated_test
    test(self, **param_kwargs)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3594, in test_associative_scan_compile_fail
    result = self._run_test(
             ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3456, in _run_test
    self._check_autograd(result, result_exp, autograd_param)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3444, in _check_autograd
    grads = torch.autograd.grad(result_flatten, grad_param, grad_init)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 503, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2179, in backward
    return impl_fn()
           ^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2165, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2257, in _backward_impl
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/backends/common.py", line 73, in _wrapped_bw_compiler
    disable(
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 2234, in bw_compiler
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 710, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 880, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 864, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1487, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1374, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2238, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2248, in _compile_to_module
    mod = self._compile_to_module_lines(wrapper_code)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2312, in _compile_to_module_lines
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3022, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_boh/cd/ccdmedyasojek2hupoamg66aetwrqjyasivhvip62n63grppshs5.py", line 416, in <module>
    async_compile.wait(globals())
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 481, in wait
    self._wait_futures(scope)
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 501, in _wait_futures
    kernel = result.result()
             ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3524, in result
    return self.result_fn()
           ^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2505, in future
    result = get_result()
             ^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2313, in load_fn
    future.result()
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 864, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1487, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1374, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2238, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2248, in _compile_to_module
    mod = self._compile_to_module_lines(wrapper_code)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2312, in _compile_to_module_lines
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3022, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_boh/cd/ccdmedyasojek2hupoamg66aetwrqjyasivhvip62n63grppshs5.py", line 416, in <module>
    async_compile.wait(globals())
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 481, in wait
    self._wait_futures(scope)
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 501, in _wait_futures
    kernel = result.result()
             ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3524, in result
    return self.result_fn()
           ^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2505, in future
    result = get_result()
             ^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2313, in load_fn
    future.result()
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2342, in _worker_compile_cpp
    cpp_builder.build()
  File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 1687, in build
    run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
  File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 358, in run_compile_cmd
    _run_compile_cmd(cmd_line, cwd)
  File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 353, in _run_compile_cmd
    raise exc.CppCompileError(cmd, output) from e
torch._inductor.exc.InductorError: CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_AVX2 -shared -fPIC -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fexcess-precision=fast -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -fno-tree-loop-vectorize -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/data_malta3_ssd/miniforge3/envs/pt23/include/python3.11 -I/data_malta3_ssd/pytorch/torch/include -I/data_malta3_ssd/pytorch/torch/include/torch/csrc/api/include -mavx2 -mfma -mf16c -D_GLIBCXX_USE_CXX11_ABI=1 -o /tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.so -ltorch -ltorch_cpu -ltorch_python -lgomp -L/opt/ssd/miniforge3/envs/pt23/lib -L/data_malta3_ssd/pytorch/torch/lib

Output:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:52:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
   52 |                         auto tmp19 = float(tmp7 + tmp18);
      |                                      ^~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:96:43: warning: self-comparison always evaluates to false [-Wtautological-compare]
   96 |                         auto tmp34 = tmp3 < tmp3;
      |                                      ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:122:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
  122 |                         auto tmp43 = float(tmp33 + tmp42);
      |                                      ^~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:166:43: warning: self-comparison always evaluates to false [-Wtautological-compare]
  166 |                         auto tmp58 = tmp3 < tmp3;
      |                                      ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:192:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
  192 |                         auto tmp67 = float(tmp57 + tmp66);
      |                                      ^~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:254:47: warning: self-comparison always evaluates to false [-Wtautological-compare]
  254 |                             auto tmp25 = tmp2 < tmp2;
      |                                          ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:280:47: warning: self-comparison always evaluates to false [-Wtautological-compare]
  280 |                             auto tmp43 = tmp2 < tmp2;
      |                                          ~~~~ ^ ~~~~


To execute this test, run the following from the base repo dir:
    python test/functorch/test_control_flow.py AssociativeScanTests.test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------

xs = torch.arange(1, 5) = [1, 2, 3, 4]
ys = torch.cumprod(xs) = [1, 2, 6, 24]

def combine_fn(a: torch.Tensor, b: torch.Tensor):
Contributor commented:

There is a jump in logic between ys = torch.cumprod(xs) = [1, 2, 6, 24] and presenting the combine_fn and combine_fn_bw: I assume you want to write a simple associative_scan that does the same thing as torch.cumprod(xs), but I didn't see the connection explicitly written down.

One idea: maybe put the following here:

        The forward output of associative_scan is computed as:
        ys = associative_scan(combine_fn, xs).
        For example, this computation can be unpacked as:
        ys_0 = xs_0
        ys_1 = combine_fn(ys_0, xs_1)
        ...
        ys_T = combine_fn(ys_(T-1), xs_T)
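(Editorial aside: that unpacked recursion corresponds to a naive sequential reference, sketched below; associative_scan computes the same ys, just not sequentially. The helper name is hypothetical, not part of the PR.)

import torch

def naive_scan_reference(combine_fn, xs: torch.Tensor) -> torch.Tensor:
    # Sequential reference for the recursion above:
    # ys_0 = xs_0, ys_t = combine_fn(ys_(t-1), xs_t)
    ys = [xs[0]]
    for t in range(1, xs.shape[0]):
        ys.append(combine_fn(ys[-1], xs[t]))
    return torch.stack(ys)

print(naive_scan_reference(torch.mul, torch.arange(1.0, 5.0)))
# tensor([ 1.,  2.,  6., 24.])  -- matches torch.cumprod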

I recommend doing a serious proofreading pass: make sure the statements are well organized and closely connected, and that the notation is well defined before it's used.

@bohnstingl (Collaborator, Author) replied:

Right, I have locally fixed the docs for the cumprod example:

def combine_fn(a: torch.Tensor, b: torch.Tensor):
    return a * b

ys = associative_scan(combine_fn, xs)
ys = [1, 2, 6, 24]

The function combine_fn_bw returns the gradients of a and b from combine_fn. It can be computed as:
def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
    return g_y * b, g_y * a

The first output of combine_fn_bw is the instantaneous gradient for the previous output g_y_t
and the second output of combine_fn_bw is the instantaneous gradient for the input g_x_t.

Note: In a real use case of associative_scan, there may be additional_inputs that participate in the forward as well as the backward of the scan operator. For the sake of readability, those inputs have been omitted in this example, but they are included in the detailed description below.

For the example above:
ys = associative_scan(combine_fn, xs),
the computation can be unpacked as:
ys_0 = xs_0
ys_1 = combine_fn(ys_0, xs_1)
...
ys_T = combine_fn(ys_(T-1), xs_T)
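(Editorial aside: the combine_fn/combine_fn_bw pair above can be sanity-checked against autograd; a minimal sketch, with g_y standing in for an upstream gradient:)

import torch

def combine_fn(a: torch.Tensor, b: torch.Tensor):
    return a * b

def combine_fn_bw(a: torch.Tensor, b: torch.Tensor, g_y: torch.Tensor):
    return g_y * b, g_y * a

a = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(5.0, requires_grad=True)
g_y = torch.tensor(2.0)

# autograd's gradients for a single combine step ...
g_a, g_b = torch.autograd.grad(combine_fn(a, b), (a, b), grad_outputs=g_y)
# ... must match the hand-written backward:
g_a_manual, g_b_manual = combine_fn_bw(a, b, g_y)
assert torch.allclose(g_a, g_a_manual) and torch.allclose(g_b, g_b_manual)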

I will also do a thorough check that the notation in the docs matches the code. I already tried that once and at first glance couldn't spot anything obvious. Do you have specific concerns? Otherwise, I will check it again.

@bohnstingl (Collaborator, Author) commented:

@ydwu4 I gave it another go and updated the documentation, trying to make the examples clear and double-checking that the variable names in the code match the docs. WDYT?

@bohnstingl bohnstingl requested a review from ydwu4 May 16, 2025 23:18
Labels
module: dynamo, module: inductor, open source, topic: not user facing, triaged
Projects
None yet
Development
8 participants