[scan] Autograd with partial gradient support by bohnstingl · Pull Request #146285 · pytorch/pytorch · GitHub
[scan] Autograd with partial gradient support #146285


Closed · wants to merge 54 commits

Commits (54)
f6586dd
WIP: Integration of scan autograd
bohnstingl Feb 1, 2025
bda8cb9
Introduced autograd for scan with partial gradients support
bohnstingl Feb 3, 2025
34be923
Fixed CI test issue with graph
bohnstingl Feb 6, 2025
5eadb1e
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Feb 7, 2025
adfa593
Integrated code review comments
bohnstingl Feb 18, 2025
7b41bde
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Feb 19, 2025
c2465cf
Fixed type annotation
bohnstingl Feb 20, 2025
5bc1d27
Removed reverse flag from backend and implemented reverse with torch.…
bohnstingl Feb 25, 2025
2786a68
Fix of graph in testcase
bohnstingl Feb 25, 2025
7f48b4f
Merge branch 'scan_flip_reverse' of github.com:bohnstingl/pytorch int…
bohnstingl Feb 26, 2025
6717f9b
Integrated new reverse handling into scan autograd
bohnstingl Feb 26, 2025
a60d2a8
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 7, 2025
0343056
Fixed issues with testcases and with combine_fn now returning pytrees
bohnstingl Mar 7, 2025
0dda954
[cond] don't trace fw and bw graph in autograd key
ydwu4 Mar 10, 2025
c11cac4
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 17, 2025
ebb5b57
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 18, 2025
46608f8
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 18, 2025
26e4a53
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 19, 2025
813c883
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 19, 2025
9e4f3c3
Merge branch 'gh/ydwu4/222/head' of github.com:pytorch/pytorch into s…
bohnstingl Mar 19, 2025
114c982
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 20, 2025
fdbfe11
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 20, 2025
7796ea9
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 20, 2025
6082a60
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 20, 2025
d33705f
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 20, 2025
7cee41c
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 20, 2025
23170b1
WIP: scan autograd
bohnstingl Mar 21, 2025
0ac5a15
Merge branch 'gh/ydwu4/222/head' of github.com:pytorch/pytorch into s…
bohnstingl Mar 21, 2025
93f5bab
Update base for Update on "[cond] don't trace fw and bw graph in auto…
ydwu4 Mar 21, 2025
ad92afa
Update on "[cond] don't trace fw and bw graph in autograd key"
ydwu4 Mar 21, 2025
5947348
Merge branch 'gh/ydwu4/222/head' of github.com:pytorch/pytorch into s…
bohnstingl Mar 22, 2025
60fc08f
Working function of scan autograd with new infrastructure
bohnstingl Mar 24, 2025
75fe416
Removed unnecessary code and other cleanups
bohnstingl Mar 24, 2025
26c94e9
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 24, 2025
357dde5
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 26, 2025
b018397
Fixed tracing issue causing missing fake meta values
bohnstingl Mar 27, 2025
8a288d3
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 27, 2025
63414f9
Fixed CI test
bohnstingl Mar 28, 2025
66f0b37
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 28, 2025
b50e7b9
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Mar 31, 2025
0db488f
Fixed testcase
bohnstingl Mar 31, 2025
9df0724
Integrated review feedback; before code cleanup
bohnstingl Apr 1, 2025
ffacadd
After cleanup and lintrunner
bohnstingl Apr 2, 2025
4ce5f06
Fixes
bohnstingl Apr 2, 2025
88084f5
Further cleanup and restructuring
bohnstingl Apr 2, 2025
473db76
WIP: integration of code review
bohnstingl Apr 3, 2025
e6ca8b0
Added extensive documentation and lintrunner corrections
bohnstingl Apr 3, 2025
d27d50a
Revert change to cond.py
bohnstingl Apr 3, 2025
0ee4d01
Removed unrelated testcases and fixed lint issues
bohnstingl Apr 3, 2025
66c0ae1
Integrated review changes for documentation and test cases
bohnstingl Apr 4, 2025
c976b5e
Added some more comments
bohnstingl Apr 8, 2025
afa5d96
Introduced materialize_as_graph for combine_fn and for combine_bw_fn
bohnstingl Apr 9, 2025
6d67d3e
Merge branch 'main' of github.com:pytorch/pytorch into scan_autograd22
bohnstingl Apr 10, 2025
ebe8c22
Added bypass for autograd DispatchKey if no gradient for either the i…
bohnstingl Apr 10, 2025
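
Taken together, the commits above do two things: they give the `scan` higher-order op an autograd rule that tolerates operands without gradients, and they remove the positional `reverse` flag from the backend, re-expressing reverse scans as `torch.flip` on the scanned inputs before the op and on the stacked outputs after it. A minimal sketch of that flip-based reverse semantics in plain eager PyTorch (reference code only, not the HOP itself):

```python
import torch

def combine_fn(carry: torch.Tensor, x: torch.Tensor):
    # Running sum: returns the next carry and the per-step output.
    next_carry = carry + x
    return next_carry, next_carry

def reverse_scan_via_flip(init: torch.Tensor, xs: torch.Tensor):
    # Reverse scan expressed as flip -> forward scan -> flip, which is the
    # scheme this PR adopts in place of a backend-level `reverse` flag.
    xs = torch.flip(xs, [0])
    carry, ys = init, []
    for t in range(xs.shape[0]):
        carry, y = combine_fn(carry, xs[t])
        ys.append(y)
    return carry, torch.flip(torch.stack(ys), [0])

carry, ys = reverse_scan_via_flip(torch.zeros(3), torch.randn(5, 3))
```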
19 changes: 11 additions & 8 deletions test/functorch/test_control_flow.py
@@ -3486,18 +3486,20 @@ def f(fct, init, xs):
"""\
def forward(self, fct_1, init_1, xs_1):
permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
select_copy = torch.ops.aten.select_copy.int(permute, 0, 0)
flip = torch.ops.aten.flip.default(permute, [0]); permute = None
select_copy = torch.ops.aten.select_copy.int(flip, 0, 0)
add = torch.ops.aten.add.Tensor(init_1, select_copy); add = None
add_1 = torch.ops.aten.add.Tensor(init_1, select_copy); select_copy = add_1 = None
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None
scan_combine_graph_0 = self.scan_combine_graph_0
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [permute], True, [sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4]); scan_combine_graph_0 = init_1 = permute = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], [sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4]); scan_combine_graph_0 = init_1 = flip = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
getitem = scan[0]
getitem_1 = scan[1]; scan = None
return (getitem, getitem_1)""", # noqa: B950
flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None
return (getitem, flip_1)""", # noqa: B950
)

# Check graph
@@ -3516,10 +3518,11 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
v = l_init_ + select_copy; v = None
x = l_init_ + select_copy; select_copy = x = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [elem], True, []); scan_combine_fn_0 = l_init_ = elem = None
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [elem_1], []); scan_combine_fn_0 = l_init_ = elem_1 = None
getitem = scan[0]
getitem_1 = scan[1]; scan = None
return (getitem, getitem_1)""", # noqa: B950
elem_2 = scan[1]; scan = None
flip_1 = torch.flip(elem_2, [0]); elem_2 = None
return (getitem, flip_1)""", # noqa: B950
)

@skipIfNoDynamoSupport
@@ -7484,7 +7487,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
r_4 = r_3.add(l_add_closure_0_cell_contents_1_0_); r_3 = None
r_5 = r_4.sum(); r_4 = r_5 = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [r], False, [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = r = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [r], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = r = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
getitem = scan[0]
getitem_1 = scan[1]; scan = None
return (getitem, getitem_1)""", # noqa: B950
@@ -7505,7 +7508,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
ret = matmul_1 + l_add_closure_0_cell_contents_1_0_; matmul_1 = None
x = ret.sum(); ret = x = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [elem], False, [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = elem = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [elem], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = elem = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
getitem = scan[0]
getitem_1 = scan[1]; scan = None
return (getitem, getitem_1)""", # noqa: B950
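The expected-graph updates in this test file all follow the same pattern: the literal `True`/`False` reverse argument disappears from every `torch.ops.higher_order.scan(...)` call, and reverse scans instead materialize `aten.flip` nodes around the op. The graphs are produced by tracing code along these lines; a sketch that assumes the private frontend `torch._higher_order_ops.scan.scan`, whose user-facing `reverse=` keyword survives this PR even though the backend flag does not:

```python
import torch
from torch._higher_order_ops.scan import scan

def add(carry: torch.Tensor, x: torch.Tensor):
    # Combine function: returns (next_carry, per-step output).
    next_carry = carry + x
    return next_carry, next_carry

@torch.compile(backend="eager")
def f(init, xs):
    return scan(add, init, xs, reverse=True)

carry, ys = f(torch.zeros(2), torch.arange(10.0).reshape(5, 2))
```

Compiling with `backend="eager"` mirrors how the tests capture the Dynamo graph without involving Inductor.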
9 changes: 3 additions & 6 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -1429,12 +1429,10 @@ def call_function(

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

def arg_extractor(combine_fn, init, xs, reverse, additional_inputs):
return combine_fn, init, xs, reverse, additional_inputs
def arg_extractor(combine_fn, init, xs, additional_inputs):
return combine_fn, init, xs, additional_inputs

combine_fn, init, xs, reverse, additional_inputs = arg_extractor(
*args, **kwargs
)
combine_fn, init, xs, additional_inputs = arg_extractor(*args, **kwargs)
assert isinstance(additional_inputs, variables.BaseListVariable)

if xs.python_type() != list:
@@ -1542,7 +1540,6 @@ def _check_phs_position_match(
init_proxy,
xs_proxy,
# dim.as_proxy(),
reverse.as_proxy(),
additional_inputs_proxy,
)

46 changes: 21 additions & 25 deletions torch/_higher_order_ops/scan.py
@@ -325,6 +325,9 @@ def add(x: torch.Tensor, y: torch.Tensor):
for elem in leaves_xs_orig:
leaves_xs.append(torch.movedim(elem, dim, 0))

if reverse:
leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]

out = combine_fn(
pytree.tree_unflatten(leaves_init, spec_init),
pytree.tree_unflatten([first_slice_copy(elem) for elem in leaves_xs], spec_xs),
@@ -393,10 +396,8 @@ def _check_new_carry_match_init(leaves_init, leaves_carry):
num_inp_leaves=len(leaves_xs),
)

def run_flattened_scan(combine_fn, leaves_init, leaves_xs, reverse):
return scan_op(
combine_fn, leaves_init, leaves_xs, reverse, additional_inputs=[]
)
def run_flattened_scan(combine_fn, leaves_init, leaves_xs):
return scan_op(combine_fn, leaves_init, leaves_xs, additional_inputs=[])

if not torch._dynamo.is_compiling():
from torch._dynamo.backends.debugging import (
@@ -415,16 +416,18 @@ def run_flattened_scan(combine_fn, leaves_init, leaves_xs, reverse):
combine_fn,
leaves_init,
leaves_xs,
reverse=reverse,
)
else:
result = run_flattened_scan(combine_fn, leaves_init, leaves_xs, reverse)
result = run_flattened_scan(combine_fn, leaves_init, leaves_xs)

result_carry, result_flat = _extract_carry_and_out(
result,
len(leaves_init),
)

if reverse:
result_flat = [torch.flip(elem, [0]) for elem in result_flat]

return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten(
result_flat, tree_out
)
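
The two hunks above make the `reverse` handling symmetric in the frontend: scanned inputs are flipped along dimension 0 before the forward scan, and the stacked outputs are flipped back afterwards. The identity being relied on can be spot-checked like this (same assumed frontend import as above):

```python
import torch
from torch._higher_order_ops.scan import scan

def add(carry, x):
    next_carry = carry + x
    return next_carry, next_carry

init, xs = torch.zeros(3), torch.randn(7, 3)

carry_rev, ys_rev = scan(add, init, xs, reverse=True)
carry_fwd, ys_fwd = scan(add, init, torch.flip(xs, [0]))

# Reverse scan == flip inputs, scan forward, flip the stacked outputs back.
torch.testing.assert_close(carry_rev, carry_fwd)
torch.testing.assert_close(ys_rev, torch.flip(ys_fwd, [0]))
```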
@@ -434,16 +437,16 @@ class ScanOp(HigherOrderOperator):
def __init__(self):
super().__init__("scan")

def __call__(self, combine_fn, init, xs, reverse, additional_inputs):
def __call__(self, combine_fn, init, xs, additional_inputs):
assert isinstance(additional_inputs, list), "additional_inputs must be a list."
validate_subgraph_args_types(additional_inputs)
return super().__call__(combine_fn, init, xs, reverse, additional_inputs)
return super().__call__(combine_fn, init, xs, additional_inputs)


scan_op = ScanOp()


def generic_scan(operator, init, xs, dim=0, reverse=False, additional_inputs=None):
def generic_scan(operator, init, xs, dim=0, additional_inputs=None):
additional_inputs = additional_inputs if additional_inputs is not None else []

def _scan(init, xs):
@@ -453,10 +456,7 @@ def _scan(init, xs):
return carry, []

num_elems = xs[0].shape[dim]
if reverse:
ind = num_elems - 1
else:
ind = 0
ind = 0

# Compute dummy shapes for the pre-allocation
num_init_leaves = len(init)
@@ -497,7 +497,7 @@ def store_out_in_outs(out, ind):
o.scatter_(0, ind * idx, x.unsqueeze(0))

for i in range(num_elems):
ind = i if not reverse else num_elems - i - 1
ind = i
carry, out = _extract_carry_and_out(
operator(
*carry,
@@ -537,7 +537,6 @@ def trace_scan(
combine_fn: Callable,
init: list[torch.Tensor],
xs: list[torch.Tensor],
reverse: bool,
additional_inputs: list[torch.Tensor],
):
from torch._dynamo.utils import clone_input
@@ -582,7 +581,7 @@

proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)

args = (combine_graph, init, xs, reverse, additional_inputs)
args = (combine_graph, init, xs, additional_inputs)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="scan"
@@ -602,12 +601,10 @@


@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def scan_op_dense(combine_fn, init, xs, reverse, additional_inputs):
def scan_op_dense(combine_fn, init, xs, additional_inputs):
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return generic_scan(
combine_fn, init, xs, reverse=reverse, additional_inputs=additional_inputs
)
return generic_scan(combine_fn, init, xs, additional_inputs=additional_inputs)


class ScanAutogradOp(torch.autograd.Function):
@@ -896,12 +893,12 @@ def scan_autograd(combine_fn, init, xs, reverse, additional_inputs):


@scan_op.py_impl(ProxyTorchDispatchMode)
def scan_proxy_mode(mode, combine_fn, init, xs, reverse, additional_inputs):
return trace_scan(mode, scan_op, combine_fn, init, xs, reverse, additional_inputs)
def scan_proxy_mode(mode, combine_fn, init, xs, additional_inputs):
return trace_scan(mode, scan_op, combine_fn, init, xs, additional_inputs)


@scan_op.py_impl(FakeTensorMode)
def scan_fake_tensor_mode(mode, combine_fn, init, xs, reverse, additional_inputs):
def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs):
with mode:
scan_length = xs[0].shape[0]
carry, outputs = _extract_carry_and_out(
@@ -920,7 +917,7 @@ def scan_fake_tensor_mode(mode, combine_fn, init, xs, reverse, additional_inputs


@scan_op.py_functionalize_impl
def scan_functionalize(ctx, combine_fn, init, xs, reverse, additional_inputs):
def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
unwrapped_xs = ctx.unwrap_tensors(xs)
unwrapped_init = ctx.unwrap_tensors(init)
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
@@ -951,7 +948,6 @@ def scan_functionalize(ctx, combine_fn, init, xs, reverse, additional_inputs):
functional_combine_fn,
unwrapped_init,
unwrapped_xs,
reverse,
unwrapped_additional_inputs,
)
return ctx.wrap_tensors(ret)
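
With the dense, proxy, fake-tensor, and functionalization implementations all updated, the headline feature is a backward pass through `scan` where only some operands require grad. A minimal sketch, assuming this PR's autograd support is active; `xs` deliberately does not require grad, exercising the partial-gradient path (and, when nothing requires grad, the autograd-key bypass added in the last commit):

```python
import torch
from torch._higher_order_ops.scan import scan

def combine(carry, x):
    next_carry = carry * x
    return next_carry, next_carry.sum()

init = torch.randn(4, requires_grad=True)
xs = torch.randn(6, 4)  # requires_grad=False on purpose

carry, ys = scan(combine, init, xs)
(carry.sum() + ys.sum()).backward()

assert init.grad is not None  # gradient flows back through the carries
assert xs.grad is None        # no gradient was requested for xs
```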