Improvements for associative_scan - Autograd by bohnstingl · Pull Request #136966 · pytorch/pytorch · GitHub
Improvements for associative_scan - Autograd #136966


Closed
wants to merge 16 commits into from
Changes from 1 commit
Working version
bohnstingl committed Oct 22, 2024
commit ef000a608da82d7aa2598c9a792727bb8f18ea2b
285 changes: 250 additions & 35 deletions torch/_higher_order_ops/associative_scan.py
@@ -42,6 +42,13 @@ def get_gradient_mask(tensor_list):
def mask_gradient(grads, mask):
return [g for g, m in zip(grads, mask) if m]
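# e.g., mask_gradient([g0, g1, g2], [True, False, True]) -> [g0, g2]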

# say we have a tensor of shape [3, 4, 5, 6]
# shift_source_dim_to_target_dim(t, 0, 3) -> [4, 5, 6, 3]
def shift_source_dim_to_target_dim(t, from_dim: int, to_dim: int):
assert 0 <= to_dim < t.ndim
assert 0 <= from_dim < t.ndim
return torch.movedim(t, from_dim, to_dim)


def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
assert len(args) == 2 * num_leaves
@@ -186,11 +193,11 @@ def add(x: torch.Tensor, y: torch.Tensor):
"Combine_mode must either 'pointwise' or 'generic', but got {combine_mode}"
)

if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(associative_scan, fullgraph=True)(
combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
)
# if not torch._dynamo.is_compiling():
# with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
# return torch.compile(associative_scan, fullgraph=True)(
# combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
# )

leaves, spec = pytree.tree_flatten(xs)

@@ -206,11 +213,17 @@ def add(x: torch.Tensor, y: torch.Tensor):

if reverse:
leaves = [torch.flip(elem, [dim]) for elem in leaves]
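# e.g., a reverse cumulative sum over [1., 2., 3.] flips the input to
# [3., 2., 1.], scans to [3., 5., 6.], and the final flip of the outputs
# (applied after the scan) yields [6., 5., 3.]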

# # Move scan dim to 0 and always perform scan on dim 0
# leaves = [
# shift_source_dim_to_target_dim(elem, int(dim), 0) for elem in leaves
# ]
# dim = 0

shape = leaves[0].shape
ndim = len(shape)
dim = utils.canonicalize_dim(ndim, dim)
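# e.g., for ndim = 4, canonicalize_dim maps dim = -1 to 3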

for x in leaves[1:]:
assert x.shape == shape, "All xs tensors must have the same shape"

@@ -396,6 +409,7 @@ def trace_associative_scan(

@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, xs, dim):
return generic_associative_scan(combine_fn, xs, dim)
raise NotImplementedError("associative_scan is not implemented for eager")


@@ -427,6 +441,170 @@ def forward(
outs = associative_scan_op(fw_graph, xs, dim)
ctx.save_for_backward(*(*xs, *outs))

# #BWD in FWD
# flat_grads = [torch.ones_like(o) for o in outs]
# # flat_grads = [torch.arange(1, scan_length + 1, device=o.device, dtype=o.dtype) for o in outs]

# dim = int(ctx._dim)
# scan_length = ctx._scan_length
# num_xs = ctx._num_xs
# gradient_mask = ctx._gradient_mask
# num_xs_masked = sum(gradient_mask)

# # Extract the inputs to the forward path and outputs from the forward path
# # flat_args = ctx.saved_tensors
# # xs, outs = flat_args[:num_xs], flat_args[num_xs:]

# # Helper variables
# ones = torch.unsqueeze(torch.ones_like(first_slice_copy(xs[0], dim)), dim)
# zeros = torch.unsqueeze(torch.zeros_like(first_slice_copy(xs[0], dim)), dim)
# shifted_outs = [torch.concat([ones, aten.slice(o, dim, 0, -1, 1)], dim) for o in outs]

# # vmap joint graph over scan dimension
# mapped_joint_graph = ctx._mapped_joint_graph

# def expand_grads_with_None(real_grads, mask):
# g_list = []
# true_cnt = 0
# for m in mask:
# if m:
# g_list.append(real_grads[true_cnt])
# true_cnt += 1
# else:
# g_list.append(None)
# return g_list

# # Mask the gradients for the variables that do not require gradients for partial gradient support
# flat_grads = [fg for fg, m in zip(flat_grads, gradient_mask) if m]

# # Function to compute the gradients with respect
# # *) to the inputs (xs) -> grads_xs
# # *) to the previous outputs -> grads_hs
# def compute_grad_hs_xs():

# # Compute the partial grads_x and grads_h by only setting part of the gradients for the joint_graph to 1
# def compute_part_grads(flat_grad_ind):
# flat_grads_init = [torch.ones_like(x) if flat_grad_ind == ind else torch.zeros_like(x) for ind, x in enumerate(xs)]
# grads = mapped_joint_graph(*flat_grads_init, *shifted_outs, *xs)
# return *grads,

# # Compute all the partial gradients
# grad_parts = [torch.unsqueeze(g, 0) for g in compute_part_grads(0)]
# for part_ind in range(1, num_xs):
# grad_parts = [torch.concat([gp, torch.unsqueeze(g, 0)], 0) for gp, g in zip(grad_parts, compute_part_grads(part_ind))]

# return grad_parts

# # Compute the grads_xs and grads_hs by collecting all the partial gradients
# grads_intermediate = compute_grad_hs_xs()
# grads_hs, grads_xs = grads_intermediate[:num_xs_masked], grads_intermediate[num_xs_masked:]

# # In case of the associative_scan, the first output mirrors the first scan element of xs
# # Therefore, the first grads_hs are all zeros and the first grads_xs are ones
# grads_hs = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_hs]
# grads_xs = [torch.concat([torch.stack([torch.ones_like(gp) if ind == 0 else torch.zeros_like(gp) for ind, gp in enumerate(aten.slice(g, dim + 1, 0, 1, 1))], 0), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_xs]

# # Compute the cumprod of the rows of the gradient matrix and fill the remainder with zeros
# def cumprod_and_prepad(fg, val, size):
# return torch.concat([zeros] * max(size - 1 - val.shape[dim + 1], 0) + [ones * fg] + [fg * torch.sum(torch.cumprod(val, dim + 1), 0)], dim)

# # Compute the gradients for a single element of xs
# # The computations are done on dim + 1, because g_x and g_h have all partial gradients on dim 0
# # The partial gradients are combined in the process of this function
# def compute_grad_xs(fg, g_x, g_h):
# g_x = torch.flip(g_x, [dim + 1])
# g_h = torch.flip(g_h, [dim + 1])
# fg = torch.concat([torch.flip(fg, [dim]), ones], dim)

# # Create the matrix whose n-th row holds the n-th output gradient times the
# # cumulative products of g_h starting at position n, prepadded with zeros
# gradient_mat = [cumprod_and_prepad(aten.slice(fg, dim, n, n + 1, 1), aten.slice(g_h, dim + 1, n, -1, 1), scan_length) for n in range(0, scan_length, 1)]
# # print(g_h)
# # print(torch.stack(gradient_mat, 0))
# # print(torch.sum(g_x, 0))
# # print([aten.slice(fg, dim, n, n + 1, 1) for n in range(0, scan_length, 1)])
# grads = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * torch.sum(g_x, 0), 0), [dim])
# return grads

# # Compute the gradients in parallel for all elements of xs
# compute_grad_xs_mapped = torch.vmap(compute_grad_xs, 0, 0)
# grads_old = [torch.squeeze(el, 0) for el in torch.split(compute_grad_xs_mapped(torch.stack(flat_grads, 0), torch.stack(grads_xs, 0), torch.stack(grads_hs, 0)), 1, 0)]

# # grads_old = compute_grad_xs(flat_grads[0], grads_xs[0], grads_hs[0])

# # Expand the gradients with Nones for partial gradient support
# grads_old = expand_grads_with_None(grads_old, gradient_mask)


# # Function to compute the gradients with respect
# # *) to the inputs (xs) -> grads_xs
# # *) to the previous outputs -> grads_hs
# def compute_grad_hs_xs():

# # Compute the partial grads_x and grads_h by only setting part of the gradients for the joint_graph to 1
# def compute_part_grads(flat_grad_ind):
# flat_grads_init = [torch.ones_like(x) if flat_grad_ind == ind else torch.zeros_like(x) for ind, x in enumerate(xs)]
# grads = mapped_joint_graph(*flat_grads_init, *shifted_outs, *xs)
# return *grads,

# # Compute all the partial gradients
# grad_parts = [torch.unsqueeze(g, 0) for g in compute_part_grads(0)]
# for part_ind in range(1, num_xs):
# grad_parts = [torch.concat([gp, torch.unsqueeze(g, 0)], 0) for gp, g in zip(grad_parts, compute_part_grads(part_ind))]

# return grad_parts

# # Compute the grads_xs and grads_hs by collecting all the partial gradients
# grads_intermediate = compute_grad_hs_xs()
# grads_h_parts, grads_x_parts = grads_intermediate[:num_xs_masked], grads_intermediate[num_xs_masked:]

# grads_h_parts = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_h_parts]
# grads_h_parts2 = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, -1, 1)], dim + 1) for g in grads_h_parts]
# grads_x_parts = [torch.concat([torch.stack([torch.ones_like(gp) if ind == 0 else torch.zeros_like(gp) for ind, gp in enumerate(aten.slice(g, dim + 1, 0, 1, 1))], 0), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_x_parts]

# grads_h_parts = [torch.flip(gh_p, [dim + 1]) for gh_p in grads_h_parts]
# grads_x = [torch.flip(torch.sum(gx_p, 0), [dim]) for gx_p in grads_x_parts]
# flat_grads = [torch.concat([torch.flip(fg, [dim]), ones], dim) for fg in flat_grads]

# # Compute the cumprod of the rows of the gradient matrix and fill the remainder with zeros
# def select(val, size):
# return aten.slice(val, dim + 1, size, -1)
# # return aten.slice(val, dim + 1, size + 1, None)

# def prepad(val, size):
# return torch.concat([zeros] * size + [ones] + [val], dim)

# def mul(x, y):
# return x * y

# grads_h_prod_mat = [torch.stack([prepad(torch.sum(torch.cumprod(select(gh_p, i), dim + 1), 0), i) * aten.slice(fg, dim, i, i + 1, 1) for i in range(scan_length)], 0) for gh_p, fg in zip(grads_h_parts, flat_grads)]

# ones_t = torch.ones((scan_length, scan_length), device=outs[0].device)
# shape = grads_h_parts[0].shape
# def expand_to_equal_dims(el):
# while len(el.shape) < (len(shape) + 1):
# el = torch.unsqueeze(el, -1)
# return el
# triu = torch.unsqueeze(torch.triu(ones_t, diagonal=1), 0)
# triu = expand_to_equal_dims(triu)
# tril = torch.unsqueeze(torch.tril(ones_t, diagonal=0), 0)
# tril = expand_to_equal_dims(tril)
# tril2 = torch.unsqueeze(torch.tril(ones_t, diagonal=-1), 0)
# tril2 = expand_to_equal_dims(tril2)

# def create_grads_h_matrix(gh_p):
# h_mat = gh_p.expand(shape[0:1] + shape[1:2] + shape[1:])
# h_mat = h_mat * triu + tril
# return h_mat

# grads_h_prod_mat2 = [torch.sum(torch.cumprod(create_grads_h_matrix(gh_p), dim + 2) - tril2, 0) * aten.slice(fg, dim, 0, -1, 1) for gh_p, fg in zip(grads_h_parts2, flat_grads)]

# grads = [torch.flip(torch.sum(gh_p * gx, 0), [dim]) for gh_p, gx in zip(grads_h_prod_mat, grads_x)]
# grads2 = [torch.flip(torch.sum(gh_p * gx, 0), [dim]) for gh_p, gx in zip(grads_h_prod_mat2, grads_x)]
# grads = expand_grads_with_None(grads, gradient_mask)

# if any(torch.max(torch.abs(g - go)) > 1e-4 for g, go in zip(grads, grads_old)):
# print('Failed')

return *outs,

@staticmethod
@@ -532,7 +710,7 @@ def expand_grads_with_None(real_grads, mask):

# Function to compute the gradients with respect
# *) to the inputs (xs) -> grads_xs
# *) to the previous outputs -> grads_hs
def compute_grad_hs_xs():

# Compute the partial grads_x and grads_h by only setting part of the gradients for the joint_graph to 1
@@ -550,35 +728,45 @@ def compute_part_grads(flat_grad_ind):

# Compute the grads_xs and grads_hs by collecting all the partial gradients
grads_intermediate = compute_grad_hs_xs()
grads_hs, grads_xs = grads_intermediate[:num_xs_masked], grads_intermediate[num_xs_masked:]

# In case of the associative_scan, the first output mirrors the first scan element of xs
# Therefore, the first grads_hs are all zeros and the first grads_xs are only ones
grads_hs = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_hs]
grads_xs = [torch.concat([torch.stack([torch.ones_like(gp) if ind == 0 else torch.zeros_like(gp) for ind, gp in enumerate(aten.slice(g, dim + 1, 0, 1, 1))], 0), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_xs]

# Compute the cumprod of the rows of the gradient matrix and fill the remainder with zeros
def cumprod_and_prepad(fg, val, size):
return torch.concat([zeros] * max(size - 1 - val.shape[dim + 1], 0) + [ones * fg] + [fg * torch.sum(torch.cumprod(val, dim + 1), 0)], dim)

# Compute the gradients for a single element of xs
# The computations are done on dim + 1, because g_x and g_h have all partial gradients on dim 0
# The partial gradients are combined in the process of this function
def compute_grad_xs(fg, g_x, g_h):
g_x = torch.flip(g_x, [dim + 1])
g_h = torch.flip(g_h, [dim + 1])
fg = torch.concat([torch.flip(fg, [dim]), ones], dim)

# Create the matrix whose n-th row holds the n-th output gradient times the
# cumulative products of g_h starting at position n, prepadded with zeros
gradient_mat = [cumprod_and_prepad(aten.slice(fg, dim, n, n + 1, 1), aten.slice(g_h, dim + 1, n, -1, 1), scan_length) for n in range(0, scan_length, 1)]
grads = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * torch.sum(g_x, 0), 0), [dim])
return grads
grads_h_parts, grads_x_parts = grads_intermediate[:num_xs_masked], grads_intermediate[num_xs_masked:]

# Compute the gradients in parallel for all elements of xs
compute_grad_xs_mapped = torch.vmap(compute_grad_xs, 0, 0)
grads = [torch.squeeze(el, 0) for el in torch.split(compute_grad_xs_mapped(torch.stack(flat_grads, 0), torch.stack(grads_xs, 0), torch.stack(grads_hs, 0)), 1, 0)]

# Expand the gradients with Nones for partial gradient support
grads_h_parts = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_h_parts]
# grads_h_parts = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, -1, 1)], dim + 1) for g in grads_h_parts]
grads_x_parts = [torch.concat([torch.stack([torch.ones_like(gp) if ind == 0 else torch.zeros_like(gp) for ind, gp in enumerate(aten.slice(g, dim + 1, 0, 1, 1))], 0), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_x_parts]
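# e.g., for a cumulative sum, y_0 = x_0: there is no previous output, so its
# gradient contribution is 0 and dy_0/dx_0 is 1, which is what the two
# concatenations above enforce at position 0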

grads_h_parts = [torch.flip(gh_p, [dim + 1]) for gh_p in grads_h_parts]
grads_x = [torch.flip(torch.sum(gx_p, 0), [dim]) for gx_p in grads_x_parts]
flat_grads = [torch.concat([torch.flip(fg, [dim]), ones], dim) for fg in flat_grads]

def select(val, size):
return aten.slice(val, dim + 1, size, -1, 1)

def prepad(val, size):
return torch.concat([zeros] * size + [ones] + [val], dim)
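# e.g., prepad(v, 2) yields [0, 0, 1, v_0, v_1, ...] along the scan dim, while
# select(val, i) keeps the slice [i, -1) of val along dim + 1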

grads_h_prod_mat = [torch.stack([prepad(torch.sum(torch.cumprod(select(gh_p, i), dim + 1), 0), i) * aten.slice(fg, dim, i, i + 1, 1) for i in range(scan_length)], 0) for gh_p, fg in zip(grads_h_parts, flat_grads)]

# ones_t = torch.ones((scan_length, scan_length), device=outs[0].device)
# shape = grads_h_parts[0].shape
# def expand_to_equal_dims(el):
# while len(el.shape) < (len(shape) + 1):
# el = torch.unsqueeze(el, -1)
# return el
# triu = torch.unsqueeze(torch.triu(ones_t, diagonal=1), 0)
# triu = expand_to_equal_dims(triu)
# tril = torch.unsqueeze(torch.tril(ones_t, diagonal=0), 0)
# tril = expand_to_equal_dims(tril)
# tril2 = torch.unsqueeze(torch.tril(ones_t, diagonal=-1), 0)
# tril2 = expand_to_equal_dims(tril2)

# def create_grads_h_matrix(gh_p):
# h_mat = gh_p.expand(shape[0:1] + shape[1:2] + shape[1:])
# h_mat = h_mat * triu + tril
# return h_mat

# grads_h_prod_mat = [torch.sum(torch.cumprod(create_grads_h_matrix(gh_p), dim + 2) - tril2, 0) * aten.slice(fg, dim, 0, -1, 1) for gh_p, fg in zip(grads_h_parts, flat_grads)]

grads = [torch.flip(torch.sum(gh_p * gx, 0), [dim]) for gh_p, gx in zip(grads_h_prod_mat, grads_x)]
grads = expand_grads_with_None(grads, gradient_mask)

return *[None] * 4, *grads
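# Illustrative sanity check (an editor sketch under the assumption that
# combine_fn = add; not part of this diff): with addition, every partial
# derivative dy_t/dy_{t-1} and dy_t/dx_t equals 1, so the backward pass above
# should reduce to a reverse cumulative sum of the incoming output gradients.
import torch

x = torch.randn(5, requires_grad=True)
y = torch.cumsum(x, 0)  # forward: an associative scan with `add`
g = torch.randn(5)
y.backward(g)
expected = torch.flip(torch.cumsum(torch.flip(g, [0]), 0), [0])
assert torch.allclose(x.grad, expected)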
@@ -643,3 +831,30 @@ def associative_scan_functionalize(ctx, combine_fn, xs, dim):
)
ret = associative_scan_op(functional_combine_fn, unwrapped_xs, dim)
return ctx.wrap_tensors(ret)

def _fake_associative_scan(combine_fn, xs, dim, reverse=False):
inp_leaves, spec = pytree.tree_flatten(xs)
result_flat = []
num_leaves = len(inp_leaves)
op = reversed if reverse else lambda x: x

for ind in op(range(inp_leaves[0].size(dim))):
r = [
inp_leaves[leave_ind][(slice(None),) * dim + (ind,)]
for leave_ind in range(num_leaves)
]
if (ind > 0 and not reverse) or (
ind < (inp_leaves[0].size(dim) - 1) and reverse
):
r = combine_fn(
pytree.tree_unflatten(result_flat[-1], spec),
pytree.tree_unflatten(r, spec),
)
r_flat, _ = pytree.tree_flatten(r)
result_flat.append(r_flat)

results = [
torch.stack([e[leave_ind] for e in op(result_flat)], dim)
for leave_ind in range(num_leaves)
]
return pytree.tree_unflatten(results, spec)
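
# Usage sketch (illustrative, assuming an `add` combine function):
# _fake_associative_scan is an eager, loop-based reference implementation,
# so for addition it should agree with torch.cumsum.
import operator

x = torch.arange(1.0, 6.0)
ref = _fake_associative_scan(operator.add, x, dim=0)
assert torch.allclose(ref, torch.cumsum(x, 0))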