[associative_scan] Autograd separated by bohnstingl · Pull Request #139939 · pytorch/pytorch

[associative_scan] Autograd separated #139939


Open

bohnstingl wants to merge 34 commits into main

Changes from 1 commit (of 34 commits)
9c49a36
WIP: Associative_scan Autograd
bohnstingl Nov 6, 2024
100c598
Working implementation of Autograd
bohnstingl Nov 6, 2024
0e7c8d5
Added partial gradient tests
bohnstingl Nov 7, 2024
67d62fb
Separated out the partial gradient functionality to a separate PR
bohnstingl Nov 7, 2024
9e5e1f9
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jan 18, 2025
6b41565
Working version uncleaned
bohnstingl Mar 8, 2025
d68b31b
Almost all tests pass
bohnstingl Mar 8, 2025
9de0caf
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Mar 8, 2025
653eab0
First working implementation of simplified autograd for combine_mode=…
bohnstingl Mar 9, 2025
ac1a12b
Merge branch 'associative_scan_74' of github.com:bohnstingl/pytorch i…
bohnstingl Mar 24, 2025
d51b16c
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Mar 24, 2025
dbfd544
WIP: No additional_input support yet
bohnstingl Mar 24, 2025
64a8b29
Updates and cosmetic fixes
bohnstingl Mar 25, 2025
08b7251
Fixed problem with python<3.11
bohnstingl Mar 25, 2025
f378827
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Mar 26, 2025
b12777e
Fixed issue with adding tuple and list
bohnstingl Mar 26, 2025
6335022
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Mar 27, 2025
1c9fa70
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Mar 27, 2025
6d8353b
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Apr 2, 2025
3f81639
Rework to improve readability and unify shared function with scan
bohnstingl Apr 2, 2025
62770d5
Removed irrelevant testcases, improved readability, extended document…
bohnstingl Apr 3, 2025
6761714
Restructured documentation and synced with associative_scan
bohnstingl Apr 8, 2025
b73c553
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Apr 8, 2025
ea9568d
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Apr 15, 2025
0786e6f
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Apr 24, 2025
c628dd8
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl May 8, 2025
e7d63b6
Skipping autograd test for associative_scan with lifted arguments
bohnstingl May 8, 2025
95d0e42
Reworked documentation
bohnstingl May 8, 2025
3707c0d
Factored failing test out into separate test
bohnstingl May 9, 2025
a565834
Rework of documentation
bohnstingl May 16, 2025
122d70b
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jul 22, 2025
d89bca6
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jul 26, 2025
c1f1267
Removed former failing testcase for associative_scan
bohnstingl Jul 27, 2025
a59ad4c
Consolidated utility functions between scan.py and associative_scan.p…
bohnstingl Jul 27, 2025
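
The commit history above builds up autograd support for associative_scan. For context, a minimal sketch of the behavior this PR targets, assuming the prototype torch._higher_order_ops API (the exact import path and eager-mode behavior may vary by build, so treat this as illustrative, not the PR's implementation):

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def add(a, b):
    # Any associative combine_fn works; addition yields an inclusive prefix sum.
    return a + b

x = torch.arange(4, dtype=torch.float32, requires_grad=True)
y = associative_scan(add, x, dim=0, combine_mode="generic")  # [0., 1., 3., 6.]

# What this PR adds: backpropagation through the scan.
y.sum().backward()
print(x.grad)  # element i contributes to 4 - i prefix sums -> [4., 3., 2., 1.]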
Factored failing test out into separate test
bohnstingl committed May 9, 2025
commit 3707c0df1bdc33a0f4d9fea2d2d74f11bcdec9a0
70 changes: 70 additions & 0 deletions test/functorch/test_control_flow.py
@@ -3485,6 +3485,18 @@ def _prepare_fake_kwargs(self, original_kwargs):
            )
        ),
    )
    # Skipping this combination as there is a CPP compilation failure that
    # may be unrelated to associative_scan itself. There is a dedicated test for
    # this case below.
    @decorateIf(
        unittest.skip,
        lambda params: (
            params["compile_mode"] == "compile_dynamic_shape"
            and params["combine_mode"] == "generic"
            and params["device"] == torch.device("cpu")
            and params["autograd"]
        ),
    )
    def test_associative_scan_compile(
        self, combine_mode, reverse, compile_mode, device, autograd
    ):
@@ -3535,6 +3547,64 @@ def test_associative_scan_compile

        self.assertEqual(result, results_torch)

    @unittest.skipIf(not SM70OrLater, "triton")
    @requires_cuda
    @unittest.expectedFailure
    @parametrize("reverse", [False, True])
    @parametrize("compile_mode", ["compile_dynamic_shape"])
    @parametrize("combine_mode", ["generic"])
    @parametrize("device", [torch.device("cpu")])
    @parametrize("autograd", [True])
    def test_associative_scan_compile_fail(
bohnstingl (Collaborator, Author) commented:
I separated out these two specific cases, which lead to C++ compilation failures that I don't think are necessarily related to associative_scan, as they appear only here. I marked those tests as expected failures; fixing them should happen in a follow-up PR. For reference, the issue we observe with those tests is:

ERROR: test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True (__main__.AssociativeScanTests.test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3154, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 552, in instantiated_test
    test(self, **param_kwargs)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3594, in test_associative_scan_compile_fail
    result = self._run_test(
             ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3456, in _run_test
    self._check_autograd(result, result_exp, autograd_param)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3444, in _check_autograd
    grads = torch.autograd.grad(result_flatten, grad_param, grad_init)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 503, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2179, in backward
    return impl_fn()
           ^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2165, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2257, in _backward_impl
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/backends/common.py", line 73, in _wrapped_bw_compiler
    disable(
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 2234, in bw_compiler
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 710, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 880, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 864, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1487, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1374, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2238, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2248, in _compile_to_module
    mod = self._compile_to_module_lines(wrapper_code)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2312, in _compile_to_module_lines
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3022, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_boh/cd/ccdmedyasojek2hupoamg66aetwrqjyasivhvip62n63grppshs5.py", line 416, in <module>
    async_compile.wait(globals())
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 481, in wait
    self._wait_futures(scope)
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 501, in _wait_futures
    kernel = result.result()
             ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3524, in result
    return self.result_fn()
           ^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2505, in future
    result = get_result()
             ^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2313, in load_fn
    future.result()
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 864, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1487, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/compile_fx.py", line 1374, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2238, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2248, in _compile_to_module
    mod = self._compile_to_module_lines(wrapper_code)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/graph.py", line 2312, in _compile_to_module_lines
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3022, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_boh/cd/ccdmedyasojek2hupoamg66aetwrqjyasivhvip62n63grppshs5.py", line 416, in <module>
    async_compile.wait(globals())
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 481, in wait
    self._wait_futures(scope)
  File "/data_malta3_ssd/pytorch/torch/_inductor/async_compile.py", line 501, in _wait_futures
    kernel = result.result()
             ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 3524, in result
    return self.result_fn()
           ^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2505, in future
    result = get_result()
             ^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2313, in load_fn
    future.result()
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/data_malta3_ssd/miniforge3/envs/pt23/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_inductor/codecache.py", line 2342, in _worker_compile_cpp
    cpp_builder.build()
  File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 1687, in build
    run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
  File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 358, in run_compile_cmd
    _run_compile_cmd(cmd_line, cwd)
  File "/data_malta3_ssd/pytorch/torch/_inductor/cpp_builder.py", line 353, in _run_compile_cmd
    raise exc.CppCompileError(cmd, output) from e
torch._inductor.exc.InductorError: CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_AVX2 -shared -fPIC -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fexcess-precision=fast -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -fno-tree-loop-vectorize -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/data_malta3_ssd/miniforge3/envs/pt23/include/python3.11 -I/data_malta3_ssd/pytorch/torch/include -I/data_malta3_ssd/pytorch/torch/include/torch/csrc/api/include -mavx2 -mfma -mf16c -D_GLIBCXX_USE_CXX11_ABI=1 -o /tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.so -ltorch -ltorch_cpu -ltorch_python -lgomp -L/opt/ssd/miniforge3/envs/pt23/lib -L/data_malta3_ssd/pytorch/torch/lib

Output:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:52:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
   52 |                         auto tmp19 = float(tmp7 + tmp18);
      |                                      ^~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:96:43: warning: self-comparison always evaluates to false [-Wtautological-compare]
   96 |                         auto tmp34 = tmp3 < tmp3;
      |                                      ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:122:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
  122 |                         auto tmp43 = float(tmp33 + tmp42);
      |                                      ^~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:166:43: warning: self-comparison always evaluates to false [-Wtautological-compare]
  166 |                         auto tmp58 = tmp3 < tmp3;
      |                                      ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:192:38: error: invalid cast from type ‘at::vec::CPU_CAPABILITY::Vectorized<float>’ to type ‘float’
  192 |                         auto tmp67 = float(tmp57 + tmp66);
      |                                      ^~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:254:47: warning: self-comparison always evaluates to false [-Wtautological-compare]
  254 |                             auto tmp25 = tmp2 < tmp2;
      |                                          ~~~~ ^ ~~~~
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp: In lambda function:
/tmp/torchinductor_boh/i7/ci7asxgns7q3tcdqd7ptdfdnsn65pdggp5yb5mofl4qhemvc7lph.cpp:280:47: warning: self-comparison always evaluates to false [-Wtautological-compare]
  280 |                             auto tmp43 = tmp2 < tmp2;
      |                                          ~~~~ ^ ~~~~


To execute this test, run the following from the base repo dir:
    python test/functorch/test_control_flow.py AssociativeScanTests.test_associative_scan_compile_fail_reverse_True_compile_mode_compile_dynamic_shape_combine_mode_generic_cpu_autograd_True

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
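
A minimal sketch of the failing configuration (CPU, dynamic shapes, combine_mode="generic", autograd), assuming the prototype torch._higher_order_ops API; the _run_test harness from this file is elided, so this is illustrative rather than the exact repro:

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def add(a, b):
    return a + b

def fn(x):
    # Matches the failing parametrization: generic combine mode, reverse scan.
    return associative_scan(add, x, dim=0, reverse=True, combine_mode="generic")

x = torch.randn(3, 10, 2, requires_grad=True)  # CPU tensor, as in the test
compiled = torch.compile(fn, dynamic=True)  # dynamic shapes, as in compile_dynamic_shape
y = compiled(x)
# The CppCompileError above surfaces while Inductor compiles the backward graph:
y.sum().backward()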

        self, combine_mode, reverse, compile_mode, device, autograd
    ):
        x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
        kwargs = {
            "dim": 0,
            "reverse": reverse,
            "compile_mode": compile_mode,
            "combine_mode": combine_mode,
        }
        kwargs_fake = self._prepare_fake_kwargs(kwargs)
        results = self._run_test(
            model=AssociativeScanModels.Simple(**kwargs),
            model_fake=AssociativeScanModels.Simple(**kwargs_fake),
            inputs=x,
            autograd_param=None if not autograd else (x,),
        )

        if not reverse:
            results_torch = []
            for op_pt in [torch.cumsum, torch.cumprod]:
                results_torch.append(op_pt(x, 0))
            self.assertEqual(results, results_torch)

        # Jax Examples
        x = torch.arange(
            0, 4, device=device, dtype=torch.float32, requires_grad=autograd
        )
        kwargs = {
            "dim": 0,
            "reverse": reverse,
            "compile_mode": compile_mode,
            "combine_fn": get_scan_combine_fn("add", True),
            "combine_mode": combine_mode,
        }
        kwargs_fake = self._prepare_fake_kwargs(kwargs)
        result = self._run_test(
            model=AssociativeScanModels.CombineFn(**kwargs),
            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
            inputs=x,
            autograd_param=None if not autograd else (x,),
        )

        if not reverse:
            results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.float32)
        else:
            results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.float32)

        self.assertEqual(result, results_torch)

    @unittest.skipIf(not SM70OrLater, "triton")
    @requires_cuda
    @parametrize("reverse", [False, True])