Improvements for associative_scan - Autograd by bohnstingl · Pull Request #136966 · pytorch/pytorch · GitHub
Improvements for associative_scan - Autograd #136966

Closed · wants to merge 16 commits
Changes from 1 commit
Working version Autograd for the binary operator
bohnstingl committed Oct 22, 2024
commit e89fdf36aff1d26e00947605ad1bdfb5c8a420f9
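For context, a minimal sketch of the end-to-end usage this commit works toward, assuming the associative_scan entry point and signature visible in the diff below (combine_fn, xs, dim, reverse, combine_mode); whether the backward call already succeeds at this commit is exactly what the changes below are about:

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def s5_operator(x, y):
    A_i, Bu_i = x
    A_j, Bu_j = y
    return A_j * A_i, A_j * Bu_i + Bu_j

timesteps, state_dim = 5, 3
A = torch.randn(timesteps, state_dim, requires_grad=True)
Bu = torch.randn(timesteps, state_dim, requires_grad=True)

# Inclusive scan along dim 0 over a pytree of two leaves.
outs = associative_scan(s5_operator, (A, Bu), 0, combine_mode="generic")
loss = sum(o.sum() for o in outs)
loss.backward()  # the autograd path exercised by ScanAutogradOp below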
50 changes: 48 additions & 2 deletions test/functorch/test_control_flow.py
@@ -90,6 +90,15 @@ def _fake_associative_scan(combine_fn, xs, dim, reverse=False):
inp_leaves, spec = pytree.tree_flatten(xs)
result_flat = []
num_leaves = len(inp_leaves)
# inp_leaves = [torch.flip(x, [dim]) for x in inp_leaves]
# from torch._higher_order_ops.associative_scan import generic_associative_scan, wrap_combine_fn_flat

# combine_fn = functools.partial(
# wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(inp_leaves)
# )
# res = generic_associative_scan(combine_fn, inp_leaves, dim)
# return res

op = reversed if reverse else lambda x: x

for ind in op(range(inp_leaves[0].size(dim))):
@@ -142,6 +151,9 @@ def s5_operator(x, y):
A_i, Bu_i = x
A_j, Bu_j = y
return A_j * A_i, A_j * Bu_i + Bu_j
# return 5 * A_j * A_i, 5 * A_j * Bu_i + Bu_j
# return A_i * A_j, Bu_i * Bu_j
# return A_i + Bu_j, A_j + Bu_i

def tuple_fct(x, y):
return (x[0] + y[0], x[1] * y[1])
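
associative_scan is only well defined when the combine function is associative; for the s5_operator above this can be checked numerically. A short sketch, assuming elementwise operands of matching shape:

import torch

def s5_operator(x, y):
    A_i, Bu_i = x
    A_j, Bu_j = y
    return A_j * A_i, A_j * Bu_i + Bu_j

# (a op b) op c must match a op (b op c) up to floating-point rounding.
a, b, c = [(torch.randn(3), torch.randn(3)) for _ in range(3)]
lhs = s5_operator(s5_operator(a, b), c)
rhs = s5_operator(a, s5_operator(b, c))
assert all(torch.allclose(left, right) for left, right in zip(lhs, rhs))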
@@ -1635,8 +1647,36 @@ def test_associative_scan_binary_operator(self, combine_mode, reverse, device, a
projected_inputs = torch.randn(
timesteps, state_dim, requires_grad=autograd, device=device
)
# projected_inputs = torch.tile(torch.unsqueeze(torch.arange(1, 2, requires_grad=autograd, device=device, dtype=torch.float32), 0), (timesteps, 1))
A = torch.randn(state_dim, requires_grad=autograd, device=device)
elements = (A.repeat((timesteps, 1)), projected_inputs)
# A = torch.tile(torch.unsqueeze(torch.arange(1, 2, requires_grad=autograd, device=device, dtype=torch.float32), 0), (timesteps, 1))
# elements = (A, projected_inputs)

# Other frameworks
# elements_np = tuple([x.cpu().detach().numpy() for x in elements])

# import tensorflow as tf
# import tensorflow_probability as tfp

# elements_tf = tuple([tf.convert_to_tensor(x) for x in elements_np])
# with tf.GradientTape() as g:
# g.watch(elements_tf)
# ret = tfp.math.scan_associative(get_scan_combine_fn("s5_operator", True), elements_tf)

# grads_tf = g.gradient(ret, elements_tf)

# JAX
# import jax.numpy as jnp
# from jax import grad, value_and_grad
# from jax import lax
# elements_jax = tuple([jnp.array(x) for x in elements_np])

# def fun(el):
# return lax.associative_scan(s5_operator, el, axis=0)
# # resul_jax = lax.associative_scan(s5_operator, elements_jax, axis=0)
# vg = value_and_grad(fun)
# ret = vg(elements_jax)

result = associative_scan(
get_scan_combine_fn("s5_operator", True),
@@ -1666,8 +1706,14 @@ def test_associative_scan_binary_operator(self, combine_mode, reverse, device, a
grads = torch.autograd.grad(
result_flatten, (*elements_flatten,), grad_out
)
print([torch.sum(g) for g in grads])
print([torch.sum(g) for g in expected_grads])
# # print([torch.sum(g) for g in grads])
# print(grads)
# print('-'*50)
# # print([torch.sum(g) for g in expected_grads])
# print(expected_grads)
# print('-'*50)
# print(grads_tf)
# print('-'*50)
self.assertEqual(grads, expected_grads)

@requires_cuda
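The commented-out TensorFlow and JAX snippets above cross-check outputs and gradients against other frameworks. The same reference values can be produced with plain PyTorch and a sequential loop; a small sketch (reference_scan is a hypothetical helper, not part of the test suite):

import torch

def s5_operator(x, y):
    A_i, Bu_i = x
    A_j, Bu_j = y
    return A_j * A_i, A_j * Bu_i + Bu_j

def reference_scan(elements):
    # Inclusive left-to-right scan over dim 0, one step at a time.
    carry = tuple(e[0] for e in elements)
    outs = [carry]
    for t in range(1, elements[0].shape[0]):
        carry = s5_operator(carry, tuple(e[t] for e in elements))
        outs.append(carry)
    return tuple(torch.stack([o[i] for o in outs], 0) for i in range(len(elements)))

timesteps, state_dim = 4, 3
A = torch.randn(timesteps, state_dim, requires_grad=True)
Bu = torch.randn(timesteps, state_dim, requires_grad=True)

ref = reference_scan((A, Bu))
ref_grads = torch.autograd.grad([r.sum() for r in ref], (A, Bu))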
131 changes: 122 additions & 9 deletions torch/_higher_order_ops/associative_scan.py
@@ -187,11 +187,11 @@ def add(x: torch.Tensor, y: torch.Tensor):
"Combine_mode must either 'pointwise' or 'generic', but got {combine_mode}"
)

if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(associative_scan, fullgraph=True)(
combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
)
# if not torch._dynamo.is_compiling():
# with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
# return torch.compile(associative_scan, fullgraph=True)(
# combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
# )

leaves, spec = pytree.tree_flatten(xs)

@@ -237,6 +237,8 @@ def add(x: torch.Tensor, y: torch.Tensor):
result_flat = generic_associative_scan(combine_fn, leaves, dim)
else:
result_flat = associative_scan_op(combine_fn, leaves, dim)

# result_flat = generic_associative_scan(combine_fn, leaves, dim)

if reverse:
result_flat = [torch.flip(elem, [dim]) for elem in result_flat]
@@ -397,7 +399,8 @@ def trace_associative_scan(

@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, xs, dim):
raise NotImplementedError("associative_scan is not implemented for eager")
# raise NotImplementedError("associative_scan is not implemented for eager")
return generic_associative_scan(combine_fn, xs, dim)

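# Illustrative aside (hypothetical helper, not part of this module): the eager
# fallback above now delegates to generic_associative_scan; its result matches a
# naive sequential scan over the flattened leaves, assuming combine_fn already
# takes and returns flat leaf lists as it does in this dispatch path.
def _naive_scan_reference(combine_fn, leaves, dim):
    num_steps = leaves[0].size(dim)
    carry = [leaf.select(dim, 0) for leaf in leaves]
    outs = [[c] for c in carry]
    for i in range(1, num_steps):
        nxt = [leaf.select(dim, i) for leaf in leaves]
        carry = list(combine_fn(*carry, *nxt))
        for acc, c in zip(outs, carry):
            acc.append(c)
    return [torch.stack(acc, dim) for acc in outs]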

class ScanAutogradOp(torch.autograd.Function):
@@ -428,6 +431,112 @@ def forward(
outs = associative_scan_op(fw_graph, xs, dim)
ctx.save_for_backward(*(*xs, *outs))

# BWD in FWD (the backward computation is replicated here inside forward for debugging; only outs are returned)
gradient_mask = ctx._gradient_mask
num_xs_masked = sum(gradient_mask)

flat_grads = [torch.ones_like(el) for el in xs]

# Helper variables
ones = torch.unsqueeze(torch.ones_like(first_slice_copy(xs[0], dim)), dim)
zeros = torch.unsqueeze(torch.zeros_like(first_slice_copy(xs[0], dim)), dim)
shifted_outs = [torch.concat([ones, aten.slice(o, dim, 0, -1, 1)], dim) for o in outs]

xs_cropped = [aten.slice(x, dim, 1, None, 1) for x in xs]
shifted_outs_cropped = [aten.slice(x, dim, 1, None, 1) for x in shifted_outs]
mapped_joint_graph = ctx._mapped_joint_graph

def expand_grads_with_None(real_grads, mask):
g_list = []
true_cnt = 0
for m in mask:
if m:
g_list.append(real_grads[true_cnt])
true_cnt += 1
else:
g_list.append(None)
return g_list

# Compute the gradients of the loss output with respect to x and h
flat_grads = [fg for fg, m in zip(flat_grads, gradient_mask) if m]

# Function to compute the gradients with respect
# *) to the inputs (xs) -> grads_xs
# *) to the previous outputs -> grads_hs
def compute_grad_hs_xs():

# Compute the partial grads_x and grads_h by only setting part of the gradients for the joint_graph to 1
def compute_part_grads(flat_grad_ind):
flat_grads_init = [torch.ones_like(x) if flat_grad_ind == ind else torch.zeros_like(x) for ind, x in enumerate(xs)]
# flat_grads_init = [torch.ones_like(x) if flat_grad_ind == ind else torch.zeros_like(x) for ind, x in enumerate(xs_cropped)]
grads = mapped_joint_graph(*flat_grads_init, *shifted_outs, *xs)
# grads = mapped_joint_graph(*flat_grads_init, *shifted_outs_cropped, *xs_cropped)
return *grads,

# Compute all the partial gradients
grad_parts = [torch.unsqueeze(g, 0) for g in compute_part_grads(0)]
# grad_parts = list(compute_part_grads(0))
# grad_parts = [torch.concat([zeros, aten.slice(g, dim, 1, None, 1)], dim) for g in grad_parts[:num_xs_masked]] + grad_parts[num_xs_masked:]
# grad_parts = [torch.unsqueeze(g, 0) for g in grad_parts]
for part_ind in range(1, num_xs):
grad_parts = [torch.concat([gp, torch.unsqueeze(g, 0)], 0) for gp, g in zip(grad_parts, compute_part_grads(part_ind))]
# grad_parts = list(compute_part_grads(part_ind))
# grad_parts = [torch.concat([zeros, aten.slice(g, dim, 1, None, 1)], dim) for g in grad_parts[:num_xs_masked]] + grad_parts[num_xs_masked:]
# grad_parts = [torch.unsqueeze(g, 0) for g in grad_parts]

return grad_parts

# Compute the grads_xs and grads_hs by collecting all the partial gradients
grads_intermediate = compute_grad_hs_xs()
grads_hs, grads_xs = grads_intermediate[:num_xs_masked], grads_intermediate[num_xs_masked:]

# The first grad_x is always 1, as the initial output is the initial xs
# print(grads_xs)
grads_hs = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_hs]
grads_xs = [torch.concat([torch.stack([torch.ones_like(gp) if ind == 0 else torch.zeros_like(gp) for ind, gp in enumerate(aten.slice(g, dim + 1, 0, 1, 1))], 0), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_xs]

# flat_grads_init = [torch.ones_like(x) for x in xs]
# tmp_grads = mapped_joint_graph(*flat_grads_init, *shifted_outs, *xs)
# grads_xs = tmp_grads[num_xs_masked:]

# Compute the cumprod of the rows of the gradient matrix and fill the remainder with zeros
def cumprod_and_prepad(fg, val, size):
return torch.concat([zeros] * max(size - 1 - val.shape[dim + 1], 0) + [ones * fg] + [fg * torch.sum(torch.cumprod(val, dim + 1), 0)], dim)
# return torch.concat([zeros] * max(size - 1 - val.shape[dim], 0) + [ones * fg] + [fg * torch.cumprod(val, dim)], dim)

# Compute the gradients for a single element of xs
# The computations are done on dim + 1, because g_x and g_h have all partial gradients on dim 0
# The partial gradients are combined in the process of this function
def compute_grad_xs(fg, g_x, g_h):
g_x = torch.flip(g_x, [dim + 1])
g_h = torch.flip(g_h, [dim + 1])
fg = torch.concat([torch.flip(fg, [dim]), ones], dim)

# Create the matrix whose row n holds the upstream gradient at step n times the cumulative products of the h-partials from step n onward
gradient_mat = [cumprod_and_prepad(aten.slice(fg, dim, n + 0, n + 1, 1), aten.slice(g_h, dim + 1, n, -1, 1), scan_length) for n in range(0, scan_length, 1)]
grads = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * torch.sum(g_x, 0), 0), [dim])

# gradient_mat = [cumprod_and_prepad(aten.slice(fg, dim, n + 0, n + 1, 1), aten.slice(g_h[0], dim, n, -1, 1), scan_length) for n in range(0, scan_length, 1)]
# grads1 = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * g_x[0], 0), [dim])
# gradient_mat = [cumprod_and_prepad(aten.slice(fg, dim, n + 0, n + 1, 1), aten.slice(g_h[1], dim, n, -1, 1), scan_length) for n in range(0, scan_length, 1)]
# grads2 = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * g_x[1], 0), [dim])
# grads = torch.sum(torch.stack([torch.sum(torch.stack([cumprod_and_prepad(aten.slice(fg, dim, n + 0, n + 1, 1), aten.slice(g_h[part], dim, n, -1, 1), scan_length) for n in range(0, scan_length, 1)], 0) * g_x[part], 0) for part in range(g_x.shape[0])], 0), 0)

# grads2 = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * g_x[1], 0), [dim])
# grads = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * g_x, 0), [dim])
# grads = torch.sum(torch.stack(gradient_mat, 0)[0:1] * g_x[0:1, :], 0)
# grads += torch.sum(torch.stack(gradient_mat, 0)[1:2] * g_x[1:2, :], 0)
# grads = torch.flip(grads, [dim])
# return grads
return grads

compute_grad_xs_mapped = torch.vmap(compute_grad_xs, 0, 0)
grads = [torch.squeeze(el, 0) for el in torch.split(compute_grad_xs_mapped(torch.stack(flat_grads, 0), torch.stack(grads_xs, 0), torch.stack(grads_hs, 0)), 1, 0)]

# Expand the gradients with Nones for partial gradient support
grads = expand_grads_with_None(grads, gradient_mask)

# print(grads)
return *outs,

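# Illustrative aside (hypothetical helper, not part of this class): the cumprod
# construction above implements the chain rule
#   dL/dx_t = sum_{s >= t} dL/dh_s * (prod_{k=t+1..s} dh_k/dh_{k-1}) * dh_t/dx_t,
# which can be sanity-checked on a single-leaf linear recurrence with plain autograd.
def _check_linear_recurrence_chain_rule(T: int = 4) -> None:
    a = torch.randn(T, dtype=torch.double)
    x = torch.randn(T, dtype=torch.double, requires_grad=True)
    h = x[0]
    for t in range(1, T):
        h = a[t] * h + x[t]  # h_t = a_t * h_{t-1} + x_t, so dh_t/dh_{t-1} = a_t
    (grad,) = torch.autograd.grad(h, x)
    # dh_{T-1}/dx_t is the product a_{t+1} * ... * a_{T-1} (empty product == 1)
    expected = torch.stack([a[t + 1:].prod() for t in range(T)])
    assert torch.allclose(grad, expected)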
@staticmethod
@@ -498,7 +607,6 @@ def g(x: torch.Tensor, y: torch.Tensor):
https://justintchiu.com/blog/pscan_diff/

"""
joint_graph = ctx._joint_graph
dim = ctx._dim
scan_length = ctx._scan_length
num_xs = ctx._num_xs
@@ -515,7 +623,7 @@ def g(x: torch.Tensor, y: torch.Tensor):
shifted_outs = [torch.concat([ones, aten.slice(o, dim, 0, -1, 1)], dim) for o in outs]

# vmap joint graph over scan dimension
mapped_joint_graph = torch.vmap(joint_graph, int(dim), int(dim))
mapped_joint_graph = ctx._mapped_joint_graph

def expand_grads_with_None(real_grads, mask):
g_list = []
@@ -554,9 +662,13 @@ def compute_part_grads(flat_grad_ind):
grads_intermediate = compute_grad_hs_xs()
grads_hs, grads_xs = grads_intermediate[:num_xs_masked], grads_intermediate[num_xs_masked:]

grads_hs = [torch.concat([torch.zeros_like(aten.slice(g, dim + 1, 0, 1, 1)), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_hs]
grads_xs = [torch.concat([torch.stack([torch.ones_like(gp) if ind == 0 else torch.zeros_like(gp) for ind, gp in enumerate(aten.slice(g, dim + 1, 0, 1, 1))], 0), aten.slice(g, dim + 1, 1, None, 1)], dim + 1) for g in grads_xs]

# Compute the cumprod of the rows of the gradient matrix and fill the remainder with zeros
def cumprod_and_prepad(fg, val, size):
return torch.concat([zeros] * max(size - 1 - val.shape[dim + 1], 0) + [ones * fg] + [fg * torch.sum(torch.cumprod(val, dim + 1), 0)], dim)
# return torch.concat([zeros] * max(size - 1 - val.shape[dim], 0) + [ones * fg] + [fg * torch.cumprod(val, dim)], dim)

# Compute the gradients for a single element of xs
# The computations are done on dim + 1, because g_x and g_h have all partial gradients on dim 0
@@ -569,7 +681,8 @@ def compute_grad_xs(fg, g_x, g_h):
# Create the matrix whose row n holds the upstream gradient at step n times the cumulative products of the h-partials from step n onward
gradient_mat = [cumprod_and_prepad(aten.slice(fg, dim, n + 0, n + 1, 1), aten.slice(g_h, dim + 1, n, -1, 1), scan_length) for n in range(0, scan_length, 1)]
grads = torch.flip(torch.sum(torch.stack(gradient_mat, 0) * torch.sum(g_x, 0), 0), [dim])

# grads = torch.sum(torch.stack([torch.sum(torch.stack([cumprod_and_prepad(aten.slice(fg, dim, n + 0, n + 1, 1), aten.slice(g_h[part], dim, n, -1, 1), scan_length) for n in range(0, scan_length, 1)], 0) * g_x[part], 0) for part in range(g_x.shape[0])], 0), 0)
# grads = torch.flip(grads, [dim])
return grads

compute_grad_xs_mapped = torch.vmap(compute_grad_xs, 0, 0)
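For context on the mapped_joint_graph used in both forward and backward: the joint graph computes per-step partial derivatives of the combine function, and torch.vmap applies it independently at every position along the scan dimension. A minimal sketch with a hypothetical single-leaf combine function and a hand-written joint (the real joint graph is traced, not hand-written):

import torch

def combine_fn(h_prev, x_cur):
    return h_prev * x_cur  # toy single-leaf combine

def joint_fn(grad_out, h_prev, x_cur):
    # Partials of combine_fn w.r.t. its two operands, scaled by the incoming
    # gradient, mirroring what the traced joint graph returns for one step.
    return grad_out * x_cur, grad_out * h_prev

# vmap over dim 0 evaluates the per-step joint at every scan position at once.
mapped_joint_fn = torch.vmap(joint_fn, in_dims=0, out_dims=0)

T = 6
grad_out = torch.ones(T)
h_prev = torch.randn(T)
x_cur = torch.randn(T)
d_h_prev, d_x_cur = mapped_joint_fn(grad_out, h_prev, x_cur)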