[associative_scan] scan dim handling in user-facing associative_scan() by bohnstingl · Pull Request #139864 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[associative_scan] scan dim handling in user-facing associative_scan() #139864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
460f753
Ensure that the combine_fn is only called with the proper slice of th…
bohnstingl Oct 24, 2024
1419a79
Fixed shape check
bohnstingl Oct 24, 2024
944649a
WIP: nested associative_scan
bohnstingl Oct 24, 2024
f974cf3
Incorporated first review round
bohnstingl Oct 26, 2024
ab0e515
Implemented better and more unified testing procedures
bohnstingl Oct 26, 2024
59b164b
Rebase to main
bohnstingl Oct 26, 2024
6dc7811
Lintrunner cleanup
bohnstingl Oct 26, 2024
308e89c
WIP: new _run_test interface
bohnstingl Oct 29, 2024
0a902eb
Integrated comments from PR and updated testcases
bohnstingl Oct 30, 2024
022a454
Integrated nested tuple for the vmap used in generic_associative_scan
bohnstingl Oct 30, 2024
8aeef66
Integrated nit changes
bohnstingl Oct 31, 2024
9e01fff
Fixed minor issue with testcase parameters
bohnstingl Oct 31, 2024
90e9ac3
Rebased to associative_scan_70
bohnstingl Oct 31, 2024
ce619ea
Fixed rebasing issues
bohnstingl Oct 31, 2024
85a703c
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Nov 6, 2024
5564064
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jan 18, 2025
56b299d
Fixed merge conflicts
bohnstingl Jan 18, 2025
f82f6b0
Corrections for lintrunner
bohnstingl Jan 18, 2025
c147eec
Created generic dim moving function that can be reused
bohnstingl Jan 19, 2025
845d366
Updated assertions
bohnstingl Jan 19, 2025
1921ab0
Merge branch 'main' of github.com:pytorch/pytorch into associative_sc…
bohnstingl Jan 23, 2025
a029df1
Removed wrapper around `shift_source_dim_to_target_dim`
bohnstingl Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions torch/_dynamo/variables/higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,10 +1232,10 @@ def call_function(

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

def arg_extractor(combine_fn, xs, dim):
return combine_fn, xs, dim
def arg_extractor(combine_fn, xs):
return combine_fn, xs

combine_fn, xs, dim = arg_extractor(*args, **kwargs)
combine_fn, xs = arg_extractor(*args, **kwargs)

if xs.python_type() != list:
unimplemented(
Expand All @@ -1248,7 +1248,7 @@ def arg_extractor(combine_fn, xs, dim):
# the sub_args shape will be (4, ).
with discard_graph_changes(tx):
sub_args = [
_make_inlined(tx, first_slice_copy)(leaf, dim)
_make_inlined(tx, first_slice_copy)(leaf)
for leaf in itertools.chain(xs.items, xs.items)
]
(
Expand Down Expand Up @@ -1277,7 +1277,7 @@ def arg_extractor(combine_fn, xs, dim):

xs_proxy = xs.as_proxy()
check_meta_consistency_vt(
[_make_inlined(tx, first_slice_copy)(t, dim) for t in xs.items],
[_make_inlined(tx, first_slice_copy)(t) for t in xs.items],
combine_result.unpack_var_sequence(tx),
"initial_xs",
"combine_fn_output",
Expand All @@ -1291,7 +1291,6 @@ def arg_extractor(combine_fn, xs, dim):
p_args = (
make_attr(tx, combine_fn_name),
xs_proxy,
dim.as_proxy(),
)

with tx.fake_mode:
Expand Down
49 changes: 27 additions & 22 deletions torch/_higher_order_ops/associative_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
return combined_leaves


def _interleave(a, b, dim):
def _interleave(a, b, dim=0):
# https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors
if b_trunc := (a.shape[dim] == b.shape[dim] + 1):
pad = (
Expand Down Expand Up @@ -74,8 +74,8 @@ class AssociativeScanOp(HigherOrderOperator):
def __init__(self):
super().__init__("associative_scan")

def __call__(self, combine_fn, xs, dim):
return super().__call__(combine_fn, xs, dim)
def __call__(self, combine_fn, xs):
return super().__call__(combine_fn, xs)


associative_scan_op = AssociativeScanOp()
Expand Down Expand Up @@ -165,11 +165,13 @@ def add(x: torch.Tensor, y: torch.Tensor):
leaves = [torch.flip(elem, [dim]) for elem in leaves]

ndim = leaves[0].ndim
dim = utils.canonicalize_dim(ndim, dim)
orig_scan_dim = utils.canonicalize_dim(ndim, dim)
# leaves = [torch.movedim(elem, dim, 0) for elem in leaves]
leaves = [torch.movedim(elem, dim, 0) for elem in leaves]

# Call the combine_fn with only a slice along the scan dim
# and check whether the output leaves have the same slice dimensions
sliced_leaves = [first_slice_copy(leaf, dim) for leaf in leaves]
sliced_leaves = [first_slice_copy(leaf) for leaf in leaves]

out = combine_fn(
pytree.tree_unflatten(sliced_leaves, spec),
Expand Down Expand Up @@ -214,26 +216,29 @@ def add(x: torch.Tensor, y: torch.Tensor):
combine_fn=torch.vmap(
combine_fn,
in_dims=(
pytree.tree_unflatten([dim] * len(leaves), spec),
pytree.tree_unflatten([dim] * len(leaves), spec),
pytree.tree_unflatten([0] * len(leaves), spec),
pytree.tree_unflatten([0] * len(leaves), spec),
),
out_dims=dim,
out_dims=0,
),
spec=spec,
num_leaves=len(leaves),
)
result_flat = generic_associative_scan(combine_fn, leaves, dim)
result_flat = generic_associative_scan(combine_fn, leaves)
else:
combine_fn = functools.partial(
wrap_combine_fn_flat,
combine_fn=combine_fn,
spec=spec,
num_leaves=len(leaves),
)
result_flat = associative_scan_op(combine_fn, leaves, dim)
result_flat = associative_scan_op(combine_fn, leaves)

if reverse:
result_flat = [torch.flip(elem, [dim]) for elem in result_flat]
result_flat = [torch.flip(elem, [0]) for elem in result_flat]

# result_flat = [torch.movedim(elem, 0, orig_scan_dim) for elem in result_flat]
result_flat = [torch.movedim(elem, 0, orig_scan_dim) for elem in result_flat]

return pytree.tree_unflatten(result_flat, spec)

Expand Down Expand Up @@ -335,10 +340,10 @@ def _scan(elems):


def trace_associative_scan(
proxy_mode, func_overload, combine_fn: Callable, xs: list[torch.Tensor], dim: int
proxy_mode, func_overload, combine_fn: Callable, xs: list[torch.Tensor]
):
with disable_proxy_modes_tracing():
sample_xs = [first_slice_copy(x, dim) for x in itertools.chain(xs, xs)]
sample_xs = [first_slice_copy(x) for x in itertools.chain(xs, xs)]
combine_graph = reenter_make_fx(combine_fn)(*sample_xs)

outputs = None
Expand Down Expand Up @@ -369,7 +374,7 @@ def trace_associative_scan(

proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)

args = (combine_graph, xs, dim)
args = (combine_graph, xs)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="associative_scan"
Expand All @@ -382,8 +387,8 @@ def trace_associative_scan(


@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, xs, dim):
return generic_associative_scan(combine_fn, xs, dim)
def associative_scan_op_dense(combine_fn, xs):
return generic_associative_scan(combine_fn, xs)


associative_scan_op.py_impl(DispatchKey.Autograd)(
Expand All @@ -392,28 +397,28 @@ def associative_scan_op_dense(combine_fn, xs, dim):


@associative_scan_op.py_impl(ProxyTorchDispatchMode)
def associative_scan_proxy_mode(mode, combine_fn, xs, dim):
return trace_associative_scan(mode, associative_scan_op, combine_fn, xs, dim)
def associative_scan_proxy_mode(mode, combine_fn, xs):
return trace_associative_scan(mode, associative_scan_op, combine_fn, xs)


@associative_scan_op.py_impl(FakeTensorMode)
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, dim):
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs):
with mode:
return [x.clone() for x in xs]


@associative_scan_op.py_functionalize_impl
def associative_scan_functionalize(ctx, combine_fn, xs, dim):
def associative_scan_functionalize(ctx, combine_fn, xs):
unwrapped_xs = ctx.unwrap_tensors(xs)
with ctx.redispatch_to_next():
functional_combine_fn = ctx.functionalize(
_maybe_run_with_interpreter(combine_fn)
)
ret = associative_scan_op(functional_combine_fn, unwrapped_xs, dim)
ret = associative_scan_op(functional_combine_fn, unwrapped_xs)
return ctx.wrap_tensors(ret)


def _fake_associative_scan(combine_fn, xs, dim, reverse=False): # noqa: F811
def _fake_associative_scan(combine_fn, xs, dim, reverse=False):
inp_leaves, spec = pytree.tree_flatten(xs)
result_flat: list[Any] = []
num_leaves = len(inp_leaves)
Expand Down
4 changes: 2 additions & 2 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6722,7 +6722,7 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands):


@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, xs, dim: int):
def associative_scan(combine_fn: ir.Subgraph, xs):
from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

subgraph_inputs = [
Expand All @@ -6737,7 +6737,7 @@ def wrapped_combine_fn(lhs, rhs):
*pytree.tree_leaves(rhs),
)

kwargs = _make_scan_inner(xs[0], axis=dim, dtype=None)
kwargs = _make_scan_inner(xs[0], axis=0, dtype=None)
kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)
result = ir.Scan.create(
Expand Down
Loading