[associative_scan] scan dim handling in user-facing associative_scan() by bohnstingl · Pull Request #139864 · pytorch/pytorch · GitHub


[associative_scan] scan dim handling in user-facing associative_scan() #139864


Closed
wants to merge 22 commits into from
Changes from 1 commit

Commits (22)
460f753
Ensure that the combine_fn is only called with the proper slice of th…
bohnstingl Oct 24, 2024
1419a79
Fixed shape check
bohnstingl Oct 24, 2024
944649a
WIP: nested associative_scan
bohnstingl Oct 24, 2024
f974cf3
Incorporated first review round
bohnstingl Oct 26, 2024
ab0e515
Implemented better and more unified testing procedures
bohnstingl Oct 26, 2024
59b164b
Rebase to main
bohnstingl Oct 26, 2024
6dc7811
Lintrunner cleanup
bohnstingl Oct 26, 2024
308e89c
WIP: new _run_test interface
bohnstingl Oct 29, 2024
0a902eb
Integrated comments from PR and updated testcases
bohnstingl Oct 30, 2024
022a454
Integrated nested tuple for the vmap used in generic_associative_scan
bohnstingl Oct 30, 2024
8aeef66
Integrated nit changes
bohnstingl Oct 31, 2024
9e01fff
Fixed minor issue with testcase parameters
bohnstingl Oct 31, 2024
90e9ac3
Rebased to associative_scan_70
bohnstingl Oct 31, 2024
ce619ea
Fixed rebasing issues
bohnstingl Oct 31, 2024
85a703c
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Nov 6, 2024
5564064
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jan 18, 2025
56b299d
Fixed merge conflicts
bohnstingl Jan 18, 2025
f82f6b0
Corrections for lintrunner
bohnstingl Jan 18, 2025
c147eec
Created generic dim moving function that can be reused
bohnstingl Jan 19, 2025
845d366
Updated assertions
bohnstingl Jan 19, 2025
1921ab0
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jan 23, 2025
a029df1
Removed wrapper around `shift_source_dim_to_target_dim`
bohnstingl Jan 23, 2025
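The commit messages above reference a generic dim-moving helper (`shift_source_dim_to_target_dim`) that lets the user-facing `associative_scan()` normalize an arbitrary scan dim before dispatching to the underlying op. The sketch below illustrates that general pattern only; the helper name is taken from the commit message, while its signature, the canonical target dim, and the naive sequential reference scan are assumptions for illustration, not the merged implementation.

```python
import torch


def shift_source_dim_to_target_dim(t: torch.Tensor, source: int, target: int) -> torch.Tensor:
    # Hypothetical helper: move one dimension of `t` from `source` to `target`,
    # keeping the relative order of all other dimensions.
    return torch.movedim(t, source, target)


def scan_over_dim(combine_fn, xs: torch.Tensor, dim: int) -> torch.Tensor:
    # Dim-handling pattern: normalize `dim`, shift the scan dimension to a
    # canonical position (0 here), scan there, then shift the result back so
    # the output layout matches the input. The loop is a sequential reference
    # scan, not the parallel implementation used by the real op.
    dim = dim % xs.ndim
    xs_shifted = shift_source_dim_to_target_dim(xs, dim, 0)
    out = xs_shifted.clone()
    for i in range(1, out.size(0)):
        out[i] = combine_fn(out[i - 1], out[i])
    return shift_source_dim_to_target_dim(out, 0, dim)


# Example: an additive scan over dim=2 matches torch.cumsum along that dim.
x = torch.randn(2, 3, 5)
torch.testing.assert_close(scan_over_dim(torch.add, x, dim=2), torch.cumsum(x, dim=2))
```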
Integrated nit changes
bohnstingl committed Oct 31, 2024
commit 8aeef6671cd9fa7fb5b60d83df566732888a56e5
89 changes: 28 additions & 61 deletions test/functorch/test_control_flow.py
@@ -1624,27 +1624,14 @@ def test_scan_complex_pytree(self, reverse, device):

# TODO: Does not work because of the usage of vmap within associative_scan
# The parameterization is commented out for the moment and the test is marked with expected fail
# Fails with: AssertionError: scan is not an OpOverload
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
# @parametrize("combine_mode", ["pointwise", "generic"])
# @parametrize("compile_mode_scan", ["none", "compile", "compile_dynamic_shape"])
# @parametrize("compile_mode_associative_scan", ["none", "eager", "compile", "compile_dynamic_shape"])
# @parametrize("reverse", [False, True])
# @parametrize("reverse_associative_scan", [False, True])
# @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# # Skipping combine_mode=pointwise
# # as the cond operation is classified to be non-pointwise
# @decorateIf(
# unittest.skip,
# lambda params: (
# params["combine_mode"] == "pointwise"
# ),
# )
@unittest.expectedFailure
def test_scan_associative_scan(self):
combine_mode = "generic"
compile_mode_scan = "none"
compile_mode_scan = "compile"
compile_mode_associative_scan = "none"
reverse = True
reverse_associative_scan = True
@@ -2537,6 +2524,11 @@ def _run_test(self, model, model_fake, inputs):
# Return the result of the functions under test for further investigations
return result

def _prepare_fake_kwargs(self, original_kwargs):
kwargs_fake = original_kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
return kwargs_fake

@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("reverse", [False, True])
@@ -2562,8 +2554,7 @@ def test_associative_scan_compile(
"compile_mode": compile_mode,
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
results = self._run_test(
model=AssociativeScanModels.Simple(**kwargs),
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
@@ -2585,9 +2576,7 @@ def test_associative_scan_compile(
"combine_fn": get_scan_combine_fn("add", True),
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"

kwargs_fake = self._prepare_fake_kwargs(kwargs)
result = self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -2633,8 +2622,7 @@ def test_associative_scan_dim(self, combine_mode, compile_mode, reverse, device)
"compile_mode": compile_mode,
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
results = self._run_test(
model=AssociativeScanModels.Simple(**kwargs),
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
@@ -2665,8 +2653,7 @@ def test_associative_scan_dim_shape_failure(self):
"compile_mode": "none",
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.Simple(**kwargs),
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
@@ -2701,8 +2688,7 @@ def test_associative_scan_tuple(self, compile_mode, combine_mode, reverse, devic
"combine_fn": get_scan_combine_fn("tuple_fct", True),
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -2730,8 +2716,7 @@ def combine_fn(x, y):
"combine_fn": combine_fn,
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -2756,8 +2741,7 @@ def test_associative_scan_non_contiguous_tensor(
"combine_fn": get_scan_combine_fn("add", True),
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -2794,8 +2778,7 @@ def test_associative_scan_complex_pytree(
"combine_fn": get_scan_combine_fn("complex_pointwise", True),
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -2820,7 +2803,6 @@ def test_associative_scan_complex_pytree(
def test_associative_scan_downstream_scan_matmul(
self, combine_mode, compile_mode, reverse, device
):
# Chain with matmul
def first_chain_fct(scan_fct, inp, **kwargs):
o = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
return o
@@ -2837,9 +2819,7 @@ def second_chain_fct(scan_fct, inp, **kwargs):
"combine_fn": [first_chain_fct, second_chain_fct],
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"

kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.ChainFn(**kwargs),
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
@@ -2864,7 +2844,6 @@ def second_chain_fct(scan_fct, inp, **kwargs):
def test_associative_scan_downstream_scan_scan(
self, combine_mode, compile_mode, reverse, device
):
# Chain with associative_scan
def first_chain_fct(scan_fct, inp, **kwargs):
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
return o1
@@ -2882,8 +2861,7 @@ def second_chain_fct(scan_fct, inp, **kwargs):
"combine_fn": [first_chain_fct, second_chain_fct],
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.ChainFn(**kwargs),
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
@@ -2911,7 +2889,6 @@ def test_associative_scan_downstream_scan_scan_different_dim(
):
reverse_second = reverse_first if same_direction else not reverse_first

# Chain with associative_scan on different dim
def first_chain_fct(scan_fct, inp, **kwargs):
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
return o1
@@ -2929,8 +2906,7 @@ def second_chain_fct(scan_fct, inp, **kwargs):
"combine_fn": [first_chain_fct, second_chain_fct],
"combine_mode": [combine_mode, combine_mode],
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.ChainFn(**kwargs),
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
@@ -2941,7 +2917,7 @@ def second_chain_fct(scan_fct, inp, **kwargs):
# TODO: Re-enable additional parameters again once this issues has been resolved
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@unittest.skip
@unittest.expectedFailure
def test_associative_scan_nested(self):
combine_mode = "pointwise"
compile_mode = "eager"
@@ -2979,8 +2955,7 @@ def second_nested_fct(x, y):
"combine_fn": first_nested_fct,
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
kwargs_fake["combine_fn"] = first_nested_fct_fake
self._run_test(
model=AssociativeScanModels.NestedFn(**kwargs),
@@ -3029,8 +3004,7 @@ def body_fn(ind, loop_val):
"combine_fn": combine_fn,
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -3067,8 +3041,7 @@ def body_fn(ind, loop_val):
"combine_fn": combine_fn,
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -3094,8 +3067,7 @@ def combine_fn(x, y):
"combine_fn": combine_fn,
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -3129,8 +3101,7 @@ def body(x, y):
"combine_fn": combine_fn,
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -3160,8 +3131,7 @@ def body(x):
"combine_fn": combine_fn,
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -3191,8 +3161,7 @@ def test_associative_scan_non_pointwise_generic(
"combine_fn": get_scan_combine_fn("non_pointwise", True),
"combine_mode": "generic",
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
Expand Down Expand Up @@ -3233,8 +3202,7 @@ def test_associative_scan_binary_operator(
"combine_fn": get_scan_combine_fn("s5_operator", True),
"combine_mode": combine_mode,
}
kwargs_fake = kwargs.copy()
kwargs_fake["compile_mode"] = "fake"
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
@@ -3260,10 +3228,9 @@ def test_associative_scan_sparse_tensor(self):

@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
def test_associative_scan_combine_fn_wrong_meta_in_combine_fn(self, device):
B, N, C, H, W = 3, 3, 2, 3, 3
x = torch.randn(B, N, C, H, W, device=device)
x = torch.randn(B, N, C, H, W, device=torch.device("cuda"))

def fct_wrong_dtype(x, y):
return (x + y).to(torch.int64)
3 changes: 2 additions & 1 deletion test/inductor/test_control_flow.py
@@ -829,7 +829,8 @@ def test_associative_scan_CUDA_flip(self, combine_mode, backend, device):
def fct(x: torch.Tensor, y: torch.Tensor):
return x + y

for n in range(10):
# for n in range(10):
for n in [9]:
x = torch.arange(n, device=device)
torch.compiler.reset()
associative_scan1 = torch.compile(
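For context on what "scan dim handling in the user-facing associative_scan()" exercises, here is a minimal usage sketch. It assumes the prototype import path `torch._higher_order_ops.associative_scan` and the argument names (`dim`, `combine_mode`) that the tests in this PR pass through; the merged API may differ in details, and the cumsum comparison is only an illustration for an additive combine_fn.

```python
import torch
from torch._higher_order_ops.associative_scan import associative_scan


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # The combine_fn must be associative; it receives slices taken along `dim`.
    return x + y


x = torch.randn(3, 10, 2)

# Scan along a non-leading dimension; for addition this is an inclusive
# cumulative sum along that dimension, so it should match torch.cumsum.
out = associative_scan(add, x, dim=1, combine_mode="generic")
torch.testing.assert_close(out, torch.cumsum(x, dim=1))
```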