diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 9349e9c103f2bf..0b166aeef9a1be 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -113,7 +113,7 @@ def compile_mode_helper(fct, compile_mode): ] -def get_scan_combine_fn(name, associative=True): +def get_scan_combine_fn(name, associative=True, parameters=None): def add(x: torch.Tensor, y: torch.Tensor): return x + y @@ -156,6 +156,18 @@ def non_pointwise(x: torch.Tensor, y: torch.Tensor): W = torch.diag(torch.ones(2, device=x.device)) return x @ W + y @ W + def RNN(x: torch.Tensor, y: torch.Tensor): + c_new = y @ parameters[0] + parameters[1] + h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3]) + return h_new, h_new + + def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor): + h_new = torch.tanh(x[0] + x[1] + y) + c2 = x[1] + y + with torch.no_grad(): + c1 = x[0] + y + return (c1, c2), h_new + if name == "add": fct = add elif name == "adds": @@ -174,6 +186,10 @@ def non_pointwise(x: torch.Tensor, y: torch.Tensor): fct = complex_pointwise elif name == "non_pointwise": fct = non_pointwise + elif name == "RNN": + fct = RNN + elif name == "fct_c1_no_grad": + fct = fct_c1_no_grad else: raise ValueError("Combine_fn name unknown!") @@ -442,6 +458,18 @@ def setUp(self): torch._dynamo.reset() super().setUp() + def check_autograd(self, result, result_exp, params): + params_flatten = pytree.tree_leaves(params) + result_flatten = pytree.tree_leaves(result) + result_exp_flatten = pytree.tree_leaves(result_exp) + grad_exp_init = [torch.ones_like(el) for el in result_exp_flatten] + expected_grads = torch.autograd.grad( + result_exp_flatten, params_flatten, grad_exp_init + ) + grad_init = [torch.ones_like(el) for el in result_flatten] + grads = torch.autograd.grad(result_flatten, params_flatten, grad_init) + self.assertEqual(grads, expected_grads, atol=6e-05, rtol=6e-06) + def test_cond_no_trace(self): def true_fn(x): return x.sin() @@ -1529,11 +1557,12 @@ def combine_fn(carry, x): @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_compile(self, reverse, compile_mode, device): + @parametrize("autograd", [False, True]) + def test_scan_compile(self, reverse, compile_mode, device, autograd): def add2(x: torch.Tensor, y: torch.Tensor): return x * y, x + y - x = torch.randn(3, 10, 2, device=device) + x = torch.randn(3, 10, 2, device=device, requires_grad=autograd) scan_fct = compile_mode_helper(scan, compile_mode) @@ -1541,12 +1570,12 @@ def add2(x: torch.Tensor, y: torch.Tensor): ( get_scan_combine_fn("add", False), torch.cumsum, - torch.zeros(10, 2, device=device), + torch.zeros(10, 2, device=device, requires_grad=autograd), ), ( get_scan_combine_fn("mul", False), torch.cumprod, - torch.ones(10, 2, device=device), + torch.ones(10, 2, device=device, requires_grad=autograd), ), ]: result = scan_fct(op, init, x, dim=0, reverse=reverse) @@ -1556,6 +1585,9 @@ def add2(x: torch.Tensor, y: torch.Tensor): result_exp_PT = op_pt(x, 0) self.assertEqual(result[1], result_exp_PT) + if autograd: + self.check_autograd(result, result_exp, (init, x)) + # Jax Examples x = torch.arange(0, 4, device=device, dtype=torch.int64) init = torch.zeros(1, device=device, dtype=torch.int64) @@ -1607,8 +1639,10 @@ def add2(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result, result_exp) # Non associative operation - x = torch.arange(0, 5, device=device, 
dtype=torch.float32) - init = torch.ones(1, device=device, dtype=torch.float32) + x = torch.arange( + 0, 5, device=device, dtype=torch.float32, requires_grad=autograd + ) + init = torch.ones(1, device=device, dtype=torch.float32, requires_grad=autograd) result = scan_fct( get_scan_combine_fn("div", False), init, @@ -1625,6 +1659,9 @@ def add2(x: torch.Tensor, y: torch.Tensor): ) self.assertEqual(result, result_exp) + if autograd: + self.check_autograd(result, result_exp, (init, x)) + # TODO: provide an implementation for all compile modes and re-enable all test @skipIfTorchDynamo("don't test compile on compile") @requires_cuda @@ -1698,32 +1735,37 @@ def test_scan_dtype(self, reverse, compile_mode, device, dtype): ], ) + @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_dim(self, reverse, device): + @parametrize("autograd", [False, True]) + def test_scan_dim(self, reverse, compile_mode, device, autograd): import random - num_dims = [random.randint(2, 5) for _ in range(10)] + scan_fct = compile_mode_helper(scan, compile_mode) + + num_dims = [random.randint(2, 5) for _ in range(5)] for num_dim in num_dims: shapes = [random.randint(1, 10) for _ in range(num_dim)] rnd_scan_dim = random.randint(0, num_dim - 1) - x = torch.randn(*shapes, device=device) + x = torch.randn(*shapes, device=device, requires_grad=autograd) init_shapes = shapes[:rnd_scan_dim] + shapes[rnd_scan_dim + 1 :] for op, op_pt, init in [ ( get_scan_combine_fn("add", False), torch.cumsum, - torch.zeros(*init_shapes, device=device), + torch.zeros(*init_shapes, device=device, requires_grad=autograd), ), ( get_scan_combine_fn("mul", False), torch.cumprod, - torch.ones(*init_shapes, device=device), + torch.ones(*init_shapes, device=device, requires_grad=autograd), ), ]: - result = scan(op, init, x, dim=rnd_scan_dim, reverse=reverse) + result = scan_fct(op, init, x, dim=rnd_scan_dim, reverse=reverse) result_exp = _fake_scan( op, init=init, xs=x, dim=rnd_scan_dim, reverse=reverse ) @@ -1734,33 +1776,41 @@ def test_scan_dim(self, reverse, device): res_list[1] = res_list[1].movedim(0, rnd_scan_dim) self.assertEqual(res_list[1], result_exp_PT) + if autograd: + self.check_autograd(result, result_exp, (init, x)) + + @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_binary_operator(self, reverse, device): + @parametrize("autograd", [False, True]) + def test_scan_binary_operator(self, reverse, compile_mode, device, autograd): state_dim = 20 timesteps = 10 + scan_fct = compile_mode_helper(scan, compile_mode) + projected_inputs = torch.randn( - timesteps, state_dim, requires_grad=True, device=device + timesteps, state_dim, requires_grad=autograd, device=device ) - A = torch.randn(state_dim, requires_grad=True, device=device) + A = torch.randn(state_dim, requires_grad=autograd, device=device) elements = (A.repeat((timesteps, 1)), projected_inputs) init = tuple( [ torch.ones_like( torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1), - requires_grad=True, + requires_grad=autograd, ) ] + [ torch.zeros_like( torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1), - requires_grad=True, + requires_grad=autograd, ) ] ) - result = scan( + result = scan_fct( get_scan_combine_fn("s5_operator", 
False), init, elements, @@ -1776,17 +1826,37 @@ def test_scan_binary_operator(self, reverse, device): ) self.assertEqual(result, expected_result) + if autograd: + init_flatten, _ = pytree.tree_flatten(init) + elements_flatten, _ = pytree.tree_flatten(elements) + + result_flatten, _ = pytree.tree_flatten(result) + result_exp_flatten, _ = pytree.tree_flatten(expected_result) + grad_out = [torch.ones_like(el) for el in result_exp_flatten] + expected_grads = torch.autograd.grad( + result_exp_flatten, (*init_flatten, *elements_flatten), grad_out + ) + grads = torch.autograd.grad( + result_flatten, (*init_flatten, *elements_flatten), grad_out + ) + self.assertEqual(grads, expected_grads) + @skipIfRocm(msg="Unsupported on ROCM yet") + @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_tuple(self, reverse, device): - x = torch.randn(3, 2, 2, device=device) - y = torch.randn(3, 2, 2, device=device) + @parametrize("autograd", [False, True]) + def test_scan_tuple(self, reverse, compile_mode, device, autograd): + x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) + y = torch.randn(3, 2, 2, device=device, requires_grad=autograd) inp = (x, y) init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp) - result_same = scan( + scan_fct = compile_mode_helper(scan, compile_mode) + + result_same = scan_fct( get_scan_combine_fn("tuple_fct", False), init, inp, @@ -1802,6 +1872,9 @@ def test_scan_tuple(self, reverse, device): ) self.assertEqual(result_same, expected_result) + if autograd: + self.check_autograd(result_same, expected_result, (init, inp)) + def fct_different_output_tuple(x, y): return ((x[0] + y[0], x[1] * y[1]), (x[1] * y[1])) @@ -1817,6 +1890,9 @@ def fct_different_output_tuple(x, y): self.assertEqual(result_diff, expected_result) self.assertEqual(result_diff[1], result_same[1][1]) + if autograd: + self.check_autograd(result_diff, expected_result, (init, inp)) + def test_scan_wrong_pytree(self): # Init and input have same pytree def fct_wrong_pytree(x, y): @@ -1869,19 +1945,23 @@ def fct_float_output(x, y): @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_complex_pytree(self, reverse, device): + @parametrize("autograd", [False, True]) + def test_scan_complex_pytree(self, reverse, compile_mode, device, autograd): # Init and input have same pytree - x = torch.randn(3, 2, 2, device=device) - y = torch.randn(3, 2, 2, device=device) - z = torch.randn(3, 2, 2, device=device) + scan_fct = compile_mode_helper(scan, compile_mode) + + x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) + y = torch.randn(3, 2, 2, device=device, requires_grad=autograd) + z = torch.randn(3, 2, 2, device=device, requires_grad=autograd) inp = {"i": x, "j": ([y], [{"o": z}])} inp_flat, inp_spec = pytree.tree_flatten(inp) init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] init = pytree.tree_unflatten(init_flat, inp_spec) - result = scan( + result = scan_fct( get_scan_combine_fn("complex_pointwise", False), init, inp, @@ -1897,6 +1977,9 @@ def test_scan_complex_pytree(self, reverse, device): ) self.assertEqual(result, expected_result) + if autograd: + self.check_autograd(result, expected_result, (init, inp)) + # TODO: Does not work because of the usage of vmap 
witin associative_scan # The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail # Fails with: AssertionError: scan is not an OpOverload @@ -1946,9 +2029,10 @@ def body(x, y): @parametrize("compile_mode", ["none", "eager"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device): - inp = torch.randn(3, 10, 2, device=device) - init = torch.randn(3, 2, device=device) + @parametrize("autograd", [False, True]) + def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device, autograd): + inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd) + init = torch.randn(3, 2, device=device, requires_grad=autograd) for ind in range(2): # Chain with matmul @@ -1975,18 +2059,24 @@ def chain_fct(inp): result = fct_cmp(inp) self.assertEqual(result, expected_result) + if autograd: + self.check_autograd(result, expected_result, (init, inp)) + # TODO: provide an implementation for all compile modes and re-enable all test @skipIfTorchDynamo("don't test compile on compile") @requires_cuda @parametrize("compile_mode", ["none", "eager"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_downstream_scan_scan_dim(self, compile_mode, reverse, device): - inp = torch.randn(3, 10, 2, device=device) - init = torch.randn(3, 2, device=device) + @parametrize("autograd", [False, True]) + def test_scan_downstream_scan_scan_dim( + self, compile_mode, reverse, device, autograd + ): + inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd) + init = torch.randn(3, 2, device=device, requires_grad=autograd) # Chain with scan on different dim - init2 = torch.randn(1, 10, 2, device=device) + init2 = torch.randn(1, 10, 2, device=device, requires_grad=autograd) def chain_fct_different_dim(inp): o1 = scan( @@ -2026,13 +2116,21 @@ def chain_fct_different_dim(inp): result = fct_cmp(inp) self.assertEqual(result, expected_result) + if autograd: + self.check_autograd(result, expected_result, (init, init2, inp)) + + @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_non_pointwise(self, reverse, device): - x = torch.randn(3, 10, 2, device=device) - init = torch.randn(10, 2, device=device) - result_expected = _fake_scan( + @parametrize("autograd", [False, True]) + def test_scan_non_pointwise(self, reverse, compile_mode, device, autograd): + scan_fct = compile_mode_helper(scan, compile_mode) + + x = torch.randn(3, 10, 2, device=device, requires_grad=autograd) + init = torch.randn(10, 2, device=device, requires_grad=autograd) + expected_result = _fake_scan( get_scan_combine_fn("non_pointwise", False), init=init, xs=x, @@ -2040,14 +2138,17 @@ def test_scan_non_pointwise(self, reverse, device): reverse=reverse, ) - out = scan( + result = scan_fct( get_scan_combine_fn("non_pointwise", False), init, x, dim=0, reverse=reverse, ) - self.assertEqual(out, result_expected) + self.assertEqual(result, expected_result) + + if autograd: + self.check_autograd(result, expected_result, (init, x)) @requires_cuda @parametrize("reverse", [False, True]) @@ -2181,26 +2282,20 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 6) @skipIfTorchDynamo("don't test compile on 
compile") - @requires_cuda - @parametrize("compile_mode", ["none", "eager"]) - def test_scan_init_scanned_0(self, compile_mode): - scan_fct = compile_mode_helper(scan, compile_mode) - + def test_scan_init_scanned_0(self): # Only init and no input - x = torch.randn(3, 1, 2) - init = torch.randn(3, 2) + x = torch.randn(3, 1, 2, device=torch.device("cpu")) + init = torch.randn(3, 2, device=torch.device("cpu")) dim = 1 # Scan dimension is 0 init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) with self.assertRaisesRegex( - # RuntimeError, - # "scan\(\) operator doesn't support.*", - torch._dynamo.exc.UncapturedHigherOrderOpError, - "scan must be captured completely with.*", + RuntimeError, + "All xs leaves must at least have.*", ): - scan_fct( + scan( get_scan_combine_fn("add", False), init, inp, @@ -2208,29 +2303,14 @@ def test_scan_init_scanned_0(self, compile_mode): ) @skipIfTorchDynamo("don't test compile on compile") - @requires_cuda - @parametrize("compile_mode", ["none", "eager"]) - def test_scan_init_non_tensor(self, compile_mode): - scan_fct = compile_mode_helper(scan, compile_mode) - - x = torch.randn(3, 1, 2) + def test_scan_init_non_tensor(self): + x = torch.randn(3, 1, 2, device=torch.device("cpu")) dim = 1 # Init is a float and not a tensor init = 1.0 - if compile_mode == "none": - with self.assertRaisesRegex( - RuntimeError, - "All init leaves must be a Tensor", - ): - scan_fct(get_scan_combine_fn("add", False), init, x, dim=dim) - else: - with self.assertRaisesRegex( - # Should be: RuntimeError, "Init leaves must be a Tensor" - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - scan_fct(get_scan_combine_fn("add", False), init, x, dim=dim) + with self.assertRaisesRegex(RuntimeError, "All init leaves must be a Tensor.*"): + scan(get_scan_combine_fn("add", False), init, x, dim=dim, reverse=False) @skipIfTorchDynamo("don't test compile on compile") def test_scan_init_wrong_shape(self): @@ -2347,12 +2427,12 @@ def no_carry(x: torch.Tensor, y: torch.Tensor): @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_init(self, reverse, compile_mode, device): + @parametrize("autograd", [False, True]) + def test_scan_init(self, reverse, compile_mode, device, autograd): scan_fct = compile_mode_helper(scan, compile_mode) # Only init and no input - x = torch.randn(3, 1, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + x = torch.randn(3, 1, 2, device=device, requires_grad=autograd) dim = 1 op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) @@ -2365,15 +2445,17 @@ def test_scan_init(self, reverse, compile_mode, device): self.assertEqual(result_init, result_exp) self.assertEqual(result_init[0], init) - x = torch.randn(3, 5, 2, device=device) - init = torch.randn(3, 5, 2, device=device) + if autograd: + self.check_autograd(result, result_exp, (init,)) + + x = torch.randn(3, 5, 2, device=device, requires_grad=autograd) dim = 0 op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) # Init tensor scalar - init = torch.ones(1, device=device) + init = torch.ones(1, device=device, requires_grad=autograd) def add_scalar_carry(x: torch.Tensor, y: torch.Tensor): return x + 1.0, x + y @@ -2385,8 +2467,11 @@ def add_scalar_carry(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result_init, result_exp) self.assertEqual(result_init[0], 
torch.tensor([3.0], device=device)) + if autograd: + self.check_autograd(result_init, result_exp, (init, inp)) + # Init tensor entirely different shape than inp - init = torch.randn(7, 8, device=device) + init = torch.randn(7, 8, device=device, requires_grad=autograd) def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor): return x + 1.0, x[: y.shape[0], : y.shape[1]] + y @@ -2408,6 +2493,9 @@ def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result_init, result_exp) self.assertEqual(result_init[0].shape, torch.Size([2, 5, 2])) + if autograd: + self.check_autograd(result_init, result_exp, (init, inp)) + init = torch.tile(init, (1, 2, 1)) def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): @@ -2423,17 +2511,14 @@ def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result_init[0].shape, torch.Size([2, 10, 2])) self.assertEqual(result_init[1].shape, torch.Size([2, 2, 5, 2])) + if autograd: + self.check_autograd(result_init, result_exp, (init, inp)) + # Correct case op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) - x = torch.randn(3, 2, 2, device=device) - dim = 1 - - if reverse: - init = torch.zeros_like(torch.select_copy(x, -1, 0)) - inp = torch._ops.ops.aten.slice(x, dim, 0, -1, 1) - else: - init = torch.zeros_like(torch.select_copy(x, 1, 0)) - inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) + init = torch.zeros(3, 2, device=device, requires_grad=autograd) + dim = 2 result = scan_fct(op, init, x, dim=dim, reverse=reverse) result_exp = _fake_scan(op, init=init, xs=x, dim=dim, reverse=reverse) @@ -2442,9 +2527,12 @@ def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): if not reverse: result_exp_PT = op_pt(x, dim) result = list(result) - result[1] = pytree.tree_map(lambda t: t.movedim(0, dim), result[1]) + result[1] = pytree.tree_map(lambda t: torch.movedim(t, 0, dim), result[1]) self.assertEqual(result[1], result_exp_PT) + if autograd: + self.check_autograd(result, result_exp, (init, x)) + @requires_cuda @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) @@ -2481,10 +2569,13 @@ def test_scan_init_wrong_pytree_complex(self, reverse, device): reverse=reverse, ) + @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_init_pytree_complex(self, reverse, device): + @parametrize("autograd", [False, True]) + def test_scan_init_pytree_complex(self, reverse, compile_mode, device, autograd): def fct_pointwise_different_output(x, y): return ( { @@ -2495,7 +2586,7 @@ def fct_pointwise_different_output(x, y): ), }, ( - y["i"], + y["i"] * 2, { "o": x["i"] * y["i"], "j": ( @@ -2511,13 +2602,13 @@ def fct_pointwise_different_carry(x, y): { "i": x["i"] * y["i"], "j": ( - x["i"], + x["i"] * 2, [x["j"][1][0] * y["j"][0][0]], [{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}], ), }, ( - y["i"], + y["i"] * 2, { "o": x["i"] * y["i"] + x["j"][0][0], "j": ( @@ -2528,9 +2619,11 @@ def fct_pointwise_different_carry(x, y): ), ) - x = torch.randn(3, 2, 2, device=device) - y = torch.randn(3, 2, 2, device=device) - z = torch.randn(3, 2, 2, device=device) + scan_fct = compile_mode_helper(scan, compile_mode) + + x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) + y = torch.randn(3, 2, 2, device=device, 
requires_grad=autograd) + z = torch.randn(3, 2, 2, device=device, requires_grad=autograd) if reverse: init_start, init_end = -1, None @@ -2554,7 +2647,7 @@ def fct_pointwise_different_carry(x, y): [{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}], ), } - result = scan( + result = scan_fct( get_scan_combine_fn("complex_pointwise", False), init, inp, @@ -2570,8 +2663,15 @@ def fct_pointwise_different_carry(x, y): ) self.assertEqual(result, expected_result) + if autograd: + init_flat = pytree.tree_leaves(init) + inp_flat = pytree.tree_leaves(inp) + self.check_autograd(result, expected_result, (*init_flat, *inp_flat)) + # Pytree of output is different - result = scan(fct_pointwise_different_output, init, inp, dim=0, reverse=reverse) + result = scan_fct( + fct_pointwise_different_output, init, inp, dim=0, reverse=reverse + ) expected_result = _fake_scan( fct_pointwise_different_output, init=init, xs=inp, dim=0, reverse=reverse ) @@ -2593,12 +2693,19 @@ def fct_pointwise_different_carry(x, y): [{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}], ), } - result = scan(fct_pointwise_different_carry, init, inp, dim=0, reverse=reverse) + result = scan_fct( + fct_pointwise_different_carry, init, inp, dim=0, reverse=reverse + ) expected_result = _fake_scan( fct_pointwise_different_carry, init=init, xs=inp, dim=0, reverse=reverse ) self.assertEqual(result, expected_result) + if autograd: + init_flat = pytree.tree_leaves(init) + inp_flat = pytree.tree_leaves(inp) + self.check_autograd(result, expected_result, (*init_flat, *inp_flat)) + @skipIfTorchDynamo("don't test compile on compile") @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref @@ -2652,42 +2759,484 @@ def forward(self, child: "f32[1, 10, 2]", child_1: "f32[1, 10, 2]", child_2: "f3 """, # noqa: B950 ) - def test_scan_RNN(self): + @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("autograd", [False, True]) + def test_scan_closure_RNN(self, compile_mode, autograd): dim = 1 device = torch.device("cpu") + scan_fct = compile_mode_helper(scan, compile_mode) rnn = torch.nn.RNN( input_size=5, hidden_size=7, + batch_first=True, ) rnn = rnn.to(device=device) - x = torch.randn(1, 2, 5, device=device) - h = torch.randn(1, 2, 7, device=device) - - new_state_dict = { - "weight_ih_l0": torch.ones_like(rnn.weight_ih_l0), - "bias_ih_l0": torch.ones_like(rnn.bias_ih_l0), - "weight_hh_l0": torch.ones_like(rnn.weight_hh_l0), - "bias_hh_l0": torch.ones_like(rnn.bias_hh_l0), - } - rnn.load_state_dict(new_state_dict) + x = torch.randn(3, 10, 5, device=device, requires_grad=autograd) + h = torch.randn(3, 7, device=device, requires_grad=autograd) + + W_ih = rnn.weight_ih_l0.T.clone() + b_ih = rnn.bias_ih_l0.clone() + W_hh = rnn.weight_hh_l0.T.clone() + b_hh = rnn.bias_hh_l0.clone() + + if not autograd: + W_ih = W_ih.detach() + b_ih = b_ih.detach() + W_hh = W_hh.detach() + b_hh = b_hh.detach() + + expected_result = rnn(x, torch.unsqueeze(h, 0)) + expected_result_out = expected_result[0] + expected_result_state = expected_result[1][0, :] + + result = scan_fct( + get_scan_combine_fn("RNN", True, parameters=[W_ih, b_ih, W_hh, b_hh]), + h, + x, + dim=dim, + reverse=False, + ) + result_cmp = [result[0], torch.movedim(result[1], 0, dim)] + self.assertEqual(result_cmp[0], expected_result_state) + self.assertEqual(result_cmp[1], expected_result_out) + + if autograd: + result_flat = 
pytree.tree_leaves(result) + result_exp_flat = [expected_result_state, expected_result_out] + + grad_out_expected = [torch.ones_like(r) for r in result_exp_flat] + expected_grads = torch.autograd.grad( + result_exp_flat, + ( + h, + x, + rnn.weight_ih_l0, + rnn.bias_ih_l0, + rnn.weight_hh_l0, + rnn.bias_hh_l0, + ), + grad_out_expected, + ) + expected_add_input_grads = list(expected_grads[2:]) + expected_grads = expected_grads[:2] + + grad_out = [torch.ones_like(r) for r in result] + grads = torch.autograd.grad( + result_flat, (h, x, W_ih, b_ih, W_hh, b_hh), grad_out + ) + add_input_grads = list(grads[2:]) + add_input_grads[0] = add_input_grads[0].T + add_input_grads[2] = add_input_grads[2].T + grads = grads[:2] + self.assertEqual(grads, expected_grads) + self.assertEqual(add_input_grads, expected_add_input_grads) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize( + "partial_grad", ["xs", "init", "additional_inputs", "complex", "random"] + ) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_closure_RNN_partial_autograd( + self, reverse, compile_mode, partial_grad, device + ): + dim = 1 + scan_fct = compile_mode_helper(scan, compile_mode) + + # The first two booleans are the xs + # The second two are the inits + # The last four are the additional_inputs + autograds = [] + + if partial_grad == "xs": + # xs tests + autograds.append([True, False, True, True, True, True, True, True]) + autograds.append([False, False, True, True, True, True, True, True]) + elif partial_grad == "init": + # init tests + autograds.append([True, True, False, True, True, True, True, True]) + autograds.append([True, True, False, False, True, True, True, True]) + elif partial_grad == "additional_inputs": + # additional input tests + autograds.append([True, True, True, True, False, True, False, True]) + autograds.append([True, True, True, True, False, False, False, False]) + elif partial_grad == "complex": + # complex cases + autograds.append([True, False, False, False, False, False, False, True]) + autograds.append([False, False, True, True, False, False, False, True]) + elif partial_grad == "random": + # random tests + import random + + for _ in range(5): + autograds.append([bool(random.randint(0, 1)) for _ in range(8)]) + + for autograd in autograds: + x = torch.randn(3, 10, 5, device=device, requires_grad=autograd[0]) + x1 = torch.randn(3, 10, 5, device=device, requires_grad=autograd[1]) + h = torch.randn(3, 7, device=device, requires_grad=autograd[2]) + h_1 = torch.randn(3, 7, device=device, requires_grad=autograd[3]) + W_ih = torch.randn(5, 7, device=device, requires_grad=autograd[4]) + b_ih = torch.randn(7, device=device, requires_grad=autograd[5]) + W_hh = torch.randn(7, 7, device=device, requires_grad=autograd[6]) + b_hh = torch.randn(7, device=device, requires_grad=autograd[7]) + + params = [ + p + for p, a in zip([x, x1, h, h_1, W_ih, b_ih, W_hh, b_hh], autograd) + if a + ] + + def RNN(x: torch.Tensor, y: torch.Tensor): + c_new_0 = x[0] + 1 + c_new_1 = x[1] + 1 + h_new = ( + torch.tanh(c_new_1 + x[0] @ W_hh + b_hh) + + y[0] @ W_ih + + y[1] @ W_ih + + b_ih + + x[1] + ) + return (c_new_0, c_new_1), h_new + + inits = (h, h_1) + result = scan_fct(RNN, inits, (x, x1), dim=dim, reverse=reverse) + result_exp = _fake_scan(RNN, (h, h_1), (x, x1), dim=dim, reverse=reverse) + self.assertEqual(result, result_exp) + + if autograd: + result_flat = 
pytree.tree_leaves(result) + result_exp_flat = pytree.tree_leaves(result_exp) + exp_grad_mask = [ + True if r.requires_grad else False for r in result_exp_flat + ] + self.check_autograd( + [r for r, m in zip(result_flat, exp_grad_mask) if m], + [r for r, m in zip(result_exp_flat, exp_grad_mask) if m], + params, + ) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_combine_fn_with_no_grad_init_carries_unequal_grad( + self, reverse, compile_mode, device, autograd + ): + dim = 1 + scan_fct = compile_mode_helper(scan, compile_mode) + x = torch.randn(3, 10, 7, device=device, requires_grad=autograd) + h1 = torch.randn(3, 7, device=device, requires_grad=autograd) + h2 = torch.randn(3, 7, device=device, requires_grad=autograd) + + result = scan_fct( + get_scan_combine_fn("fct_c1_no_grad", True), + (h1, h2), + x, + dim=dim, + reverse=reverse, + ) + result_exp = _fake_scan( + get_scan_combine_fn("fct_c1_no_grad", True), + (h1, h2), + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(result, result_exp) + + if autograd: + # TODO: Ideally we should be able to select the results that require gradients like this + # [leaf for leaf in pytree.tree_leaves(result) if leaf.requires_grad == True] + # However, for the scan operator this does not work, as all outputs always have + # grad_fn= + res_req_grad_flat = pytree.tree_leaves(result)[1:] + res_exp_req_grad_flat = pytree.tree_leaves(result_exp)[1:] + self.check_autograd(res_req_grad_flat, res_exp_req_grad_flat, (x, h2)) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_combine_fn_with_no_grad_init_carries_equal_grad( + self, reverse, compile_mode, device, autograd + ): + dim = 1 + scan_fct = compile_mode_helper(scan, compile_mode) + x = torch.randn(3, 10, 7, device=device, requires_grad=autograd) + h1 = torch.randn(3, 7, device=device, requires_grad=False) + h2 = torch.randn(3, 7, device=device, requires_grad=autograd) + + result = scan_fct( + get_scan_combine_fn("fct_c1_no_grad", True), + (h1, h2), + x, + dim=dim, + reverse=reverse, + ) + result_exp = _fake_scan( + get_scan_combine_fn("fct_c1_no_grad", True), + (h1, h2), + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(result, result_exp) + + if autograd: + # TODO: Ideally we should be able to select the results that require gradients like this + # [leaf for leaf in pytree.tree_leaves(result) if leaf.requires_grad == True] + # However, for the scan operator this does not work, as all outputs always have + # grad_fn= + res_req_grad_flat = pytree.tree_leaves(result)[1:] + res_exp_req_grad_flat = pytree.tree_leaves(result_exp)[1:] + self.check_autograd(res_req_grad_flat, res_exp_req_grad_flat, (x, h2)) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_combine_fn_with_no_grad_for_out( + self, reverse, compile_mode, device, autograd + ): + dim = 1 + scan_fct = 
compile_mode_helper(scan, compile_mode) + x = torch.randn(3, 10, 7, device=device, requires_grad=autograd) + h1 = torch.randn(3, 7, device=device, requires_grad=autograd) + h2 = torch.randn(3, 7, device=device, requires_grad=autograd) + + def fct_ys_no_grad(x: torch.Tensor, y: torch.Tensor): + c1 = x[0] + y + c2 = x[1] + y + with torch.no_grad(): + h_new = torch.tanh(x[0] + x[1] + y) + return (c1, c2), h_new + + result = scan_fct(fct_ys_no_grad, (h1, h2), x, dim=dim, reverse=reverse) + result_exp = _fake_scan(fct_ys_no_grad, (h1, h2), x, dim=dim, reverse=reverse) + self.assertEqual(result, result_exp) + + if autograd: + self.check_autograd(result[0], result_exp[0], (x, h1, h2)) - def RNN(x: torch.Tensor, y: torch.Tensor): - W_ih = torch.ones((5, 7), device=device) - b_ih = torch.ones((7), device=device) - W_hh = torch.ones((7, 7), device=device) - b_hh = torch.ones((7), device=device) - c_new = y @ W_ih + b_ih - h_new = torch.tanh(c_new + x @ W_hh + b_hh) - return h_new, h_new + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_combine_fn_with_no_grad_additional_inputs_partial( + self, reverse, compile_mode, device, autograd + ): + dim = 1 + scan_fct = compile_mode_helper(scan, compile_mode) + x = torch.randn(3, 10, 7, device=device, requires_grad=autograd) + h = torch.randn(3, 7, device=device, requires_grad=autograd) + W_ih = torch.randn(7, 7, device=device, requires_grad=autograd) + b_ih = torch.randn(7, device=device, requires_grad=autograd) + W_hh = torch.randn(7, 7, device=device, requires_grad=autograd) + b_hh = torch.randn(7, device=device, requires_grad=autograd) + + def fct_no_grad_bhh_Whh(x: torch.Tensor, y: torch.Tensor): + c_new = y @ W_ih + b_ih + x + + h_new = c_new + 1 + with torch.no_grad(): + h_new_no_grad = torch.tanh(x @ W_hh + b_hh) + h_new2 = h_new + h_new_no_grad + + return c_new, h_new2 + + result = scan_fct(fct_no_grad_bhh_Whh, h, x, dim=dim, reverse=reverse) + result_exp = _fake_scan(fct_no_grad_bhh_Whh, h, x, dim=dim, reverse=reverse) + self.assertEqual(result, result_exp) + + if autograd: + self.check_autograd(result[1], result_exp[1], (h, x, W_ih, b_ih)) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_combine_fn_with_no_grad_additional_inputs_all( + self, reverse, compile_mode, device, autograd + ): + dim = 1 + scan_fct = compile_mode_helper(scan, compile_mode) + x = torch.randn(3, 10, 7, device=device, requires_grad=autograd) + h = torch.randn(3, 7, device=device, requires_grad=autograd) + W_ih = torch.randn(7, 7, device=device, requires_grad=autograd) + b_ih = torch.randn(7, device=device, requires_grad=autograd) + W_hh = torch.randn(7, 7, device=device, requires_grad=autograd) + b_hh = torch.randn(7, device=device, requires_grad=autograd) + + def fct_no_grad_bih_Wih_bhh_Whh(x: torch.Tensor, y: torch.Tensor): + c_new = x + y + h_new = c_new + x + with torch.no_grad(): + c_new_no_grad = y @ W_ih + b_ih + h_new_no_grad = torch.tanh(x @ W_hh + b_hh) + c_new2 = c_new + c_new_no_grad + h_new2 = h_new + h_new_no_grad + return c_new2, h_new2 - expected_result = rnn( - torch.permute(x, (1, 0, 
2)), torch.unsqueeze(h[:, 0, :], 0) + result = scan_fct(fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse) + result_exp = _fake_scan( + fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse ) - expected_result_state = torch.permute(expected_result[1], (1, 0, 2)) - result = scan(RNN, init=torch.select_copy(h, dim, 0), xs=x, dim=dim) - self.assertEqual(result[0].unsqueeze(0), expected_result_state) - self.assertEqual(result[1], expected_result[0]) + self.assertEqual(result, result_exp) + + if autograd: + self.check_autograd(result[1], result_exp[1], (h, x)) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_combine_fn_carries_ys_same_grad( + self, reverse, compile_mode, device, autograd + ): + dim = 1 + scan_fct = compile_mode_helper(scan, compile_mode) + x = torch.randn(3, 10, 7, device=device, requires_grad=autograd) + h = torch.randn(3, 7, device=device, requires_grad=autograd) + W_ih = torch.randn(7, 7, device=device, requires_grad=autograd) + b_ih = torch.randn(7, device=device, requires_grad=autograd) + W_hh = torch.randn(7, 7, device=device, requires_grad=autograd) + b_hh = torch.randn(7, device=device, requires_grad=autograd) + + def fct_no_grad_bih_Wih_bhh_Whh(x: torch.Tensor, y: torch.Tensor): + c_new = x + y + h_new = c_new + 1 + with torch.no_grad(): + c_new_no_grad = y @ W_ih + b_ih + h_new_no_grad = torch.tanh(x @ W_hh + b_hh) + c_new2 = c_new + c_new_no_grad + h_new2 = h_new + h_new_no_grad + return c_new2, h_new2 + + result = scan_fct(fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse) + result_exp = _fake_scan( + fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse + ) + self.assertEqual(result, result_exp) + + if autograd: + self.check_autograd(result[1], result_exp[1], (h, x)) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize("autograd", [False, True]) + def test_scan_closure_nested(self, reverse, compile_mode, device, autograd): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Simple non-nested case + x = torch.randn(3, 20, 5, device=device, requires_grad=autograd) + h = torch.randn(3, 7, device=device, requires_grad=autograd) + W = torch.randn(5, 7, device=device, requires_grad=autograd) + b = torch.randn(7, device=device, requires_grad=autograd) + + def f1(x: torch.Tensor, y: torch.Tensor): + c_new = y @ W + b + h_new = torch.tanh(c_new + x) + return c_new, h_new + + result = scan_fct(f1, h, x, dim=1, reverse=reverse) + result_exp = _fake_scan(f1, h, x, dim=1, reverse=reverse) + self.assertEqual(result, result_exp) + + if autograd: + self.check_autograd(result, result_exp, (h, x, W, b)) + + # Nested case + def chain_fct(fct, f_1, f_2, xs, h_1, h_2): + o1 = fct( + f_1, + h_1, + xs, + dim=1, + reverse=reverse, + ) + o2 = fct( + f_2, + h_2, + o1[1], + dim=0, + reverse=reverse, + ) + return o2 + + x1 = torch.ones(3, 20, 5, device=device, requires_grad=autograd) + h1 = torch.zeros(3, 7, device=device, requires_grad=autograd) + h2 = torch.zeros(3, 3, device=device, requires_grad=autograd) + W_1 = torch.randn(5, 7, device=device, requires_grad=autograd) + b_1 = torch.randn(7, device=device, 
requires_grad=autograd) + W_2 = torch.randn(7, 3, device=device, requires_grad=autograd) + b_2 = torch.randn(3, device=device, requires_grad=autograd) + + def f1(x: torch.Tensor, y: torch.Tensor): + c_new = y @ W_1 + b_1 + h_new = torch.tanh(c_new + x) + return c_new, h_new + + def f2(x: torch.Tensor, y: torch.Tensor): + c_new = y @ W_2 + b_2 + h_new = torch.tanh(c_new + x) + return c_new, h_new + + result1 = chain_fct(scan_fct, f1, f2, x1, h1, h2) + expected_result = chain_fct(_fake_scan, f1, f2, x1, h1, h2) + self.assertEqual(result1, expected_result) + + if autograd: + self.check_autograd(result1, expected_result, (h1, h2, x1, W_1, b_1)) + + # Complex case + x1 = torch.randn(3, 20, 3, device=device, requires_grad=autograd) + h1 = torch.randn(3, 3, device=device, requires_grad=autograd) + h2 = torch.randn(3, 3, device=device, requires_grad=autograd) + W_1 = torch.randn(3, 3, device=device, requires_grad=autograd) + b_1 = torch.randn(3, device=device, requires_grad=autograd) + W_2 = torch.randn(3, 3, device=device, requires_grad=autograd) + b_2 = torch.randn(3, device=device, requires_grad=autograd) + + def f1(x: torch.Tensor, y: torch.Tensor): + c_new = y @ W_1 + b_1 + h_new = torch.tanh(c_new + x) + return c_new, h_new + + def f2(x: torch.Tensor, y: torch.Tensor): + c_new = y @ W_2 + b_2 * b_1 + y @ W_1 + h_new = torch.tanh(c_new + x) + return c_new, h_new + + result1 = chain_fct(scan_fct, f1, f2, x1, h1, h2) + expected_result = chain_fct(_fake_scan, f1, f2, x1, h1, h2) + self.assertEqual(result1, expected_result) + + if autograd: + self.check_autograd( + result1, expected_result, (h1, h2, x1, W_1, b_1, W_2, b_2) + ) @skipIfNoDynamoSupport def test_scan_simple_graph_wrong_dtype(self): @@ -2728,12 +3277,12 @@ def f(fct, init, xs): def forward(self, fct_1, init_1, xs_1): permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2]) flip = torch.ops.aten.flip.default(permute, [0]); permute = None - sym_size_int = torch.ops.aten.sym_size.int(init_1, 1) - sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 2) - sym_size_int_2 = torch.ops.aten.sym_size.int(xs_1, 1) - sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None + sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1) + sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2) + sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1) + sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None scan_combine_graph_0 = self.scan_combine_graph_0 - scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); scan_combine_graph_0 = init_1 = flip = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None + scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], (sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4)); scan_combine_graph_0 = init_1 = flip = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None getitem = scan[0] getitem_1 = scan[1]; scan = None flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None @@ -6715,7 +7264,6 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ out = scan[1]; scan = None return (carry, out)""", # noqa: B950 ) - else: self.assertExpectedInline( backend.graphs[0].code.strip(), diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 0d69141e6190a7..5a90783e0c3eaa 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -1,21 +1,23 @@ # mypy: 
allow-untyped-defs import functools import itertools -from typing import Any, Callable +from collections.abc import Sequence +from typing import Any, Callable, Optional import torch import torch._prims_common as utils -import torch._subclasses.functional_tensor import torch.utils._pytree as pytree from torch._C import DispatchKey +from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, _maybe_compile_and_run_fn, - autograd_not_implemented, check_meta_consistency, first_slice_copy, reenter_make_fx, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, unique_graph_id, UnsupportedAliasMutationException, validate_subgraph_args_types, @@ -45,7 +47,66 @@ def wrap_combine_fn_flat( def _extract_carry_and_out(flat_out: list[Any], num_carry: int): - return flat_out[:num_carry], flat_out[num_carry:] + return split_into_chunks(flat_out, [num_carry, len(flat_out) - num_carry]) + + +# We also do a clone with contiguous_format. This is to be consistent with +# eager semantic of scan, which stacks the outputs. The result is contiguous +# as a result of the stack operation. +def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor: + return ( + y.unsqueeze(0) + .repeat(*([scan_length] + [1] * y.ndim)) + .clone(memory_format=torch.contiguous_format) + ) + + +# NOTE: These functions can be reused in associative_scan and eventually moved to +# torch._higher_order_ops.utils +def get_tensor_mask(tensor_list: list[Any]) -> list[bool]: + # Returns a mask whether a list element is a tensor or not + return [True if isinstance(v, torch.Tensor) else False for v in tensor_list] + + +def mask_list( + mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None +) -> list[Any]: + # Masks elements on an `inp` list. + # If other is None, then the elements of the `inp` list where the mask is False are removed + # If other is not None, then the elements of the `inp` list where the mask is False are + # replaced with the elements of the `other` list + assert len(mask) == len( + inp + ), "The length of the mask needs to be identical to the length of the input" + if other is not None: + assert len(inp) == len( + other + ), "If an input and an other list is provided, they need to have the same length" + return [i if m else o for m, i, o in zip(mask, inp, other)] + else: + return [i for m, i in zip(mask, inp) if m] + + +def first_slice_copy_with_grad(li: list[Any]) -> list[Any]: + # First_slice_copy does not keep the original requires_grad flag, + # but we need it for materialize_as_graph + # in order to compute the correct gradients + # The reason why first_slice_copy doesn't keep requires_grad flag is + # because it's called in torch.autograd.Function.backward/forward. + slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li] + return slc + + +def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: + it = iter(iterable) + assert sum(chunk_sizes) == len( + iterable + ), "the sum of all chunks needs to match the length of the iterable." 
+ return [list(itertools.islice(it, size)) for size in chunk_sizes] + + +def call_operator(operator, *args): + return pytree.tree_leaves(operator(*args)) def scan( @@ -134,6 +195,14 @@ def _validate_input(cfn, lxs, linit, d, r): for x in lxs: if not isinstance(x, torch.Tensor): raise RuntimeError(f"All xs leaves must be a Tensor but got {x}") + if any(x.ndim <= d for x in lxs): + raise RuntimeError( + "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0" + ) + if any(x.shape[d] == 0 for x in lxs): + raise RuntimeError( + "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0" + ) ndim = leaves_xs_orig[0].ndim dim = utils.canonicalize_dim(ndim, dim) @@ -149,7 +218,6 @@ def _validate_input(cfn, lxs, linit, d, r): leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs] # TODO: Support _inductor lowering - # TODO: Support Autograd # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. combine_fn = functools.partial( @@ -202,9 +270,6 @@ def __call__(self, combine_fn, init, xs, additional_inputs): def generic_scan(operator, init, xs, dim=0, additional_inputs=()): - def call_operator(*args): - return pytree.tree_leaves(operator(*args)) - def _scan(init, xs): """Perform scan on `elems` using `elems_init.""" carry = init @@ -218,6 +283,7 @@ def _scan(init, xs): num_init_leaves = len(init) dummy_carry, dummy_out = _extract_carry_and_out( call_operator( + operator, *carry, *[first_slice_copy(elem, dim) for elem in xs], *additional_inputs, @@ -225,24 +291,26 @@ def _scan(init, xs): num_init_leaves, ) + out_tensor_mask = get_tensor_mask(dummy_out) + dummy_out_masked = mask_list(out_tensor_mask, dummy_out) + # Pre-alocate # outs -> Output matrix # idxs -> Index matrix for scatter_ # out: (num_elems, M, N, ...) # idx: (1, M, N) - outs, idxs = zip( - *[ - [ - torch.zeros( - [num_elems] + list(e.size()), - dtype=e.dtype, - device=e.device, - ), - torch.ones_like(e, dtype=torch.int64).unsqueeze(0), - ] - for i, e in enumerate(dummy_out) - ] - ) + outs = [ + torch.zeros( + [num_elems] + list(e.size()), + dtype=e.dtype, + device=e.device, + ) + for i, e in enumerate(dummy_out_masked) + ] + idxs = [ + torch.ones_like(e, dtype=torch.int64).unsqueeze(0) + for i, e in enumerate(dummy_out_masked) + ] def store_out_in_outs(out, ind): # Store the intermediate out in the outs matrix @@ -257,6 +325,7 @@ def store_out_in_outs(out, ind): ind = i carry, out = _extract_carry_and_out( call_operator( + operator, *carry, *[elem.select(dim, ind) for elem in xs], *additional_inputs, @@ -265,25 +334,17 @@ def store_out_in_outs(out, ind): ) # Store the inits in the outs matrix. - store_out_in_outs(out, ind) + store_out_in_outs(mask_list(out_tensor_mask, out), ind) + + # Expand outs with None depending on the tensor mask of the output + outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask] - return [*carry, *list(outs)] + return [*carry, *outs_expanded] scans = _scan(init, xs) return scans -# We also do a clone with contiguous_format. This is to be consistent with -# eager semantic of scan, which stacks the outputs. The result is contiguous -# as a result of the stack operation. 
-def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor: - return ( - y.unsqueeze(0) - .repeat(*([scan_length] + [1] * y.ndim)) - .clone(memory_format=torch.contiguous_format) - ) - - def trace_scan( proxy_mode, func_overload, @@ -353,9 +414,423 @@ def scan_op_dense(combine_fn, init, xs, additional_inputs): return generic_scan(combine_fn, init, xs, additional_inputs=additional_inputs) -scan_op.py_impl(DispatchKey.Autograd)( - autograd_not_implemented(scan_op, deferred_error=True) -) +class ScanAutogradOp(torch.autograd.Function): + """ + Example :: + + def combine_fn(x: torch.Tensor, y: torch.Tensor): + next_carry = y = x * y + return next_carry, y + + The ``combine_fn_bw``, computing the gradients for x and y of ``combine_fn`` is computed as: + def combine_fn_bw(x: torch.Tensor, y: torch.Tensor, g_carry: torch.Tensor, g_y: torch.Tensor): + return g_y * y + g_carry * y, g_y * x + g_carry * x + + Note: In a real usecase of scan, there may be additional_inputs that participate in the + forward as well as in the backward of the scan operator. For the sake of readability those inputs + have been omitted in the following example, but are included in the subsequent detailed description below + + The forward output of scan is computed as: + carry, ys = scan(combine_fn, init, xs). + + This computation can be unpacked as + c_0, ys_0 = combine_fn(init, xs_0) + c_1, ys_1 = combine_fn(carry_0, xs_1) + c_2, ys_2 = combine_fn(carry_1, xs_2) + ... + c_T, ys_T = combine_fn(carry_(T-1), xs_T) + + We collect c_0, c_1, ..., c_T into a vector of carries that we save for the backward, + but we only output (c_T, ys), + where ys is the vector of all intermediate outputs [y_0, y_1, ..., y_T]. + + Given the carries and the ys, the gradients for xs and for init can be computed as follows: + We receive the upstream gradients in torch.autograd.Function, i.e., we get g_c_T and g_ys, + where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T] + + We then proceed to compute the gradients for the init (g_init) and the xs (g_xs) by running a + scan operation reverse over time. For example, + + g_c_(T-1), g_xs_T = combine_fn_bw(c_(T-1), xs_T, g_c_T, g_ys_T) + g_c_(T-2), g_xs_(T-1) = combine_fn_bw(c_(T-2), xs_(T-1), g_c_(T-1), g_ys_(T-1)) + g_c_(T-3), g_xs_(T-2) = combine_fn_bw(c_(T-3), xs_(T-2), g_c_(T-2), g_ys_(T-2)) + ... + g_init, g_xs_1 = combine_fn_bw(c_0, xs_1, g_c_0, g_ys_1) + 0 , g_xs_0 = combine_fn_bw(init, xs_0, g_init, g_ys_0), + + where combine_fn_bw takes the forward inputs of step t (i.e. c_(t-1), xs_t), + the gradients of the carry of step t (i.e. g_c_t) and + the upstream gradient of the output of step t (i.e. g_ys_T) + and returns the gradient of xs_t -> g_xs_t, as well as the gradient for the carry of step t-1 -> g_c_(t-1). + + Through this procedure we end up with the + gradients for the init -> g_init, + the gradients for the xs -> g_xs. + + + NOTE: [scan autograd implementation] + + The forward of scan can be computed as: + 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``: + To use a scan operation for the backward path as well, we need access to the carries from all steps. + Thus, the function ``combine_fn`` is wrapped such that it returns all carries and not only the last carry. 
+ In particular, we define ``combine_fn_with_carry_checkpoint``: + def combine_fn_with_carry_checkpoint(x: torch.Tensor, y: torch.Tensor): + carry, y = combine_fn(x, y) + return carry, (carry, y) + + The scan operator will stack all outputs along the scan dimension. + Thus, by putting next_carry also into outputs of ``combine_fn_with_carry_checkpoint``, + the carries from all steps will be stacked and hence gives us chekpointed_carries + + 2.) Compute all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``: + c_T, (carries, ys) = scan_op(combine_fn_with_carry_checkpoint, init, xs, additional_inputs), + Where c_T (last carry) and ys (all outputs) are the original results of scan with the ``combine_fn``. + However, carries are checkpointed carries from all steps. + As a result of the forward, only the last carry c_T and the ys are returned, + while all carries are saved for the backward. + + The backward of scan can be computed as: + + 3.) Prepare the backward graph: + We prepare the backward graph to be used in the backward function. + We utilize ``create_bw_fn`` to generate the joint function, i.e., + ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands), where fw_operands = [init, xs_0, additional_inputs] + + The ctx._combine_fn_bw requires the primals (operands) + followed by the tangents (upstream gradients) from a single step + and produces the gradients of that step, i.e., + g_c_(T-1), g_xs_T, g_additional_input_T = ctx._combine_fn_bw(c_(T-1), xs_T, additional_inputs, g_c_T, g_ys_T). + + 4.) Create a wrapper of the ``combine_fn_bw``, i.e., ``combine_fn_bw_grad_accumulation``: + In the forward, there may be additional inputs that participate in every forward step. + The gradients for those additional inputs are also computed at every step and need to be accumulated over all steps, + which is taken care of in this wrapper. For example: + def combine_fn_bw_grad_accumulation(*args): + carried_g_additional_input = args[:num_additional_inputs] + inputs_bw_fn = args[num_additional_inputs:] + g_c_(t-1), g_xs_t, g_additional_input_t = ctx._combine_fn_bw(*inputs_bw_fn) + new_g_additional_inputs = carried_g_additional_input + g_additional_input_t + # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator + # The ``g_xs_t`` is encoded as the output of the backward scan operator + return [*new_g_additional_inputs, *g_c_t, *g_xs_t] + + 5.) Perform the backward scan as + g_additional_inputs, g_init, g_xs = scan_op(combine_fn_bw_grad_accumulation, bw_init, bw_xs), where + bw_init consists of the initial gradient carry for the additional_inputs (initialized with 0s): + initial_g_additional_inputs, and the gradient of the last carry: g_c_T. Thus: + bwd_init = [*initial_g_additional_inputs, *g_c_T]. + + bw_xs consists of the combination of the upstream gradients g_ys, + the forward carries prepended with the fw_init, i.e., bw_carries = concat([fw_init, fw_carries[:-1]]) and + the fw_xs. In particular, + bwd_xs = [*g_ys, *bw_carries, *fw_xs]. 
+ + Note: g_c_T and g_ys are provided through the torch.autograd.Function.backward's input + + As demonstrated in the Example above, this backward scan then yields the gradient for the init -> g_init + and the gradient for the xs -> g_xs + + NOTE: [scan partial grad handling] + If any element of init, of xs, of the outputs or of the additional_inputs does not require gradients, + i.e., requires_grad=False, there will be still gradients returned for those elements, + but those gradients will be a tensor filled with zeros of the same shape as the element itself. + + A special case are additional_inputs that are not tensors. Such inputs can occur for example with symbolic tracing, + where the shape symbol (SymInt) becomes an additional_input. + For such cases, we compute a ``additional_inputs_tensor_mask``, which is True for elements of additional_inputs + that are tensors and False otherwise. Gradients of additional_inputs are only accumulated if this mask is True, + otherwise, the value of initial_g_additional_inputs is passed, which is None for non-Tensor values. + """ + + @staticmethod + def forward( + ctx, + combine_fn, + num_leaves_init, + num_leaves_xs, + num_additional_inputs, + *operands, + ): + ctx._num_leaves_init = num_leaves_init + ctx._num_leaves_xs = num_leaves_xs + ctx._num_additional_inputs = num_additional_inputs + ctx._combine_fn = combine_fn + init, xs, additional_inputs = split_into_chunks( + operands, [num_leaves_init, num_leaves_xs, num_additional_inputs] + ) + additional_inputs_tensor_mask = get_tensor_mask(additional_inputs) + ctx._additional_inputs_tensor_mask = additional_inputs_tensor_mask + + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + # 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint`` + # The wrapper of the forward graph returns carries from all iterations, + # not just from the last iteration. These are required in the backward path + def combine_fn_with_carry_checkpoint(*args): + carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init) + return [ + *carry, + # We additionally checkpoint all the intemediate carry outputs for backward. + *[ + n_c.clone().detach() if isinstance(n_c, torch.Tensor) else n_c + for n_c in carry + ], + *y, + ] + + with torch._C._AutoDispatchBelowAutograd(): + # 2.) Compute the all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint`` + c_T, carries_ys = _extract_carry_and_out( + scan_op( + combine_fn_with_carry_checkpoint, + init, + xs, + additional_inputs, + ), + num_leaves_init, + ) + + # Collect the carries for each time step from the outs + # and save them for the backward path + carries = list(carries_ys[:num_leaves_init]) + ys = list(carries_ys[num_leaves_init:]) + save_tensors_and_symints_for_backward(ctx, list(operands) + carries + ys) + ctx._num_leaves_ys = len(ys) + + return (*c_T, *ys) + + @staticmethod + def backward(ctx, *flat_grads): + r""" + This function computes the gradients of the scan operation. + It does so by using a scan operator using all carries and the upstream gradients (see description above) + + Args: + flat_grads (torch.Tensor): The tensor of flattened upstream gradients. 
+ """
+
+ # Collect the saved items from the forward
+ num_leaves_init = ctx._num_leaves_init
+ num_leaves_xs = ctx._num_leaves_xs
+ num_leaves_ys = ctx._num_leaves_ys
+ num_additional_inputs = ctx._num_additional_inputs
+ additional_inputs_tensor_mask = ctx._additional_inputs_tensor_mask
+
+ def prepend_init_to_carries(init, carries):
+ # Prepare the carries for the backward path.
+ # This requires concatenating the init and the carries
+ return [
+ torch.cat([torch.unsqueeze(i, 0), c[:-1]], dim=0)
+ for i, c in zip(init, carries)
+ ]
+
+ def initialize_g_additional_inputs(
+ additional_inputs,
+ ):
+ # The initial gradients for the additional_inputs are all zeros
+ g_additional_inputs = [
+ torch.zeros_like(ai) if ai_tm else None
+ for ai_tm, ai in zip(additional_inputs_tensor_mask, additional_inputs)
+ ]
+ return g_additional_inputs
+
+ # Retrieve the forward inputs and the forward outputs and dissect them
+ flat_args = saved_tensors_and_symints(ctx)
+ fw_init, fw_xs, additional_inputs, fw_carries, fw_ys = split_into_chunks(
+ flat_args,
+ [
+ num_leaves_init,
+ num_leaves_xs,
+ num_additional_inputs,
+ num_leaves_init,
+ num_leaves_ys,
+ ],
+ )
+
+ # 3.) Prepare the backward graph
+ fw_operands = (
+ *fw_init,
+ *[first_slice_copy(xs) for xs in fw_xs],
+ *additional_inputs,
+ )
+ ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)
+
+ # 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
+ def combine_fn_bw_grad_accumulation(*args):
+ # Dissect args and re-order them for the ``ctx._combine_fn_bw``
+ # The content of ``combine_fn_bw_tangents`` is [*carries_g, *outs_g]
+ # The content of ``combine_fn_bw_primals`` is [*init, *xs, *additional_inputs]
+ (
+ carried_g_additional_input,
+ combine_fn_bw_tangents,
+ combine_fn_bw_primals,
+ ) = split_into_chunks(
+ args,
+ [
+ num_additional_inputs,
+ num_leaves_init + num_leaves_ys,
+ num_leaves_init + num_leaves_xs + num_additional_inputs,
+ ],
+ )
+ combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)
+
+ g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
+ ctx._combine_fn_bw(*combine_fn_bw_args),
+ [num_leaves_init, num_leaves_xs, num_additional_inputs],
+ )
+
+ new_g_additional_inputs = [
+ # If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
+ carr_g + curr_g if add_inp_tm else carr_g
+ for add_inp_tm, carr_g, curr_g in zip(
+ additional_inputs_tensor_mask,
+ carried_g_additional_input,
+ g_additional_inputs_t,
+ )
+ ]
+
+ # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
+ # The ``g_xs_t`` is encoded as the output of the backward scan operator
+ return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
+
+ # Materialize the ``combine_fn_bw_grad_accumulation``
+ def construct_args_single_step_bw():
+ # This function constructs the arguments for a single step of the backward scan.
+ # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
+ # The order of the returned arguments is identical to the order in which the
+ # backward scan operator provides them
+
+ # The following arguments are used for the backward part of the joint graph
+ # The first argument relates to the gradient accumulation of the additional inputs.
+ # Because only tensor elements of additional inputs can have requires_grad=True,
+ # the values for non-tensor elements of additional inputs are None
+ masked_additional_inputs = [
+ a.clone() if add_inp_tm else None
+ for add_inp_tm, a in zip(
+ additional_inputs_tensor_mask, additional_inputs
+ )
+ ]
+
+ # The second argument relates to the gradients of the carries.
+ # Because the arguments are for a single step only,
+ # only the first slice of the carries is used.
+ sliced_carries = [first_slice_copy(c) for c in fw_carries]
+
+ # The third argument relates to the gradients of the ys.
+ # Because the arguments are for a single step only,
+ # only the first slice of the ys is used.
+ sliced_ys = [first_slice_copy(o) for o in fw_ys]
+
+ # The following arguments are used for the forward part of the joint graph
+ # The fourth argument relates to the init for the forward.
+ # I.e., fw_init
+
+ # The fifth argument relates to the xs for the forward.
+ # Because the arguments are for a single step only,
+ # only the first slice of the xs is used.
+ # Note: It is important to preserve the requires_grad flag of xs
+ # and thus we use the wrapper function ``first_slice_copy_with_grad``
+ fw_xs_slice = first_slice_copy_with_grad(fw_xs)
+
+ # The last argument relates to the additional inputs for the forward.
+ # I.e., additional_inputs
+
+ return (
+ *masked_additional_inputs,
+ *sliced_carries,
+ *sliced_ys,
+ *fw_init,
+ *fw_xs_slice,
+ *additional_inputs,
+ )
+
+ args_single_step_bw = construct_args_single_step_bw()
+
+ # TODO: we need to materialize the bw graphs because dynamo is unable to
+ # trace through the joint function when torch.compile is applied to torch.autograd.grad.
+ combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
+ combine_fn_bw_grad_accumulation,
+ args_single_step_bw,
+ ctx._fw_include_key_set,
+ ctx._fw_exclude_key_set,
+ force_enable_grad=True,
+ )
+
+ # Decompose the flat_grads into g_c_T, g_ys
+ g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
+
+ # Initialize the g_additional_inputs with zero-tensors.
+ # This step is necessary because the gradients of the additional inputs are accumulated in
+ # ``combine_fn_bw_grad_accumulation`` and thus need a zero-initialized starting point
+ initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
+
+ # Prepend the inits to the carries.
+ # This is needed because, when computing the gradients, the last carry is not needed,
+ # but the first carry, the init, is required.
+ bw_carries = prepend_init_to_carries(fw_init, fw_carries)
+
+ # Prepare the xs for the backward scan.
+ bwd_xs = [*g_ys, *bw_carries, *fw_xs]
+
+ # The flipping of the ``bwd_xs`` is necessary because the scan_op in the backward is always performed in reverse
+ bwd_xs = [torch.flip(elem, [0]) for elem in bwd_xs]
+
+ # Prepare the bwd_init
+ bwd_init = [*initial_g_additional_inputs, *g_c_T]
+
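+ # Illustrative layout sketch (the leaf shapes below are hypothetical and only meant for orientation):
+ # if a carry leaf has shape [C], an xs leaf has shape [T, X] and a ys leaf has shape [T, Y], then
+ #   bwd_init = [*initial_g_additional_inputs, *g_c_T]   # additional-input-shaped and carry-shaped leaves, e.g. [C]
+ #   bwd_xs   = [*g_ys, *bw_carries, *fw_xs]             # leaves scanned over dim 0, e.g. [T, Y], [T, C], [T, X]
+ # Since bwd_xs has already been flipped along dim 0, the scan below effectively
+ # walks the original time steps in reverse.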
+ # 5.) Perform the backward scan:
+ # The materialized ``combine_fn_bw_grad_accumulation_gm`` receives the
+ # initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
+ # gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
+ gradients = scan_op(
+ combine_fn_bw_grad_accumulation_gm,
+ bwd_init,
+ bwd_xs,
+ additional_inputs,
+ )
+
+ # Unpack the computed gradients
+ g_additional_inputs, g_init, g_xs = split_into_chunks(
+ gradients, [num_additional_inputs, num_leaves_init, num_leaves_xs]
+ )
+
+ # The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
+ g_xs = [torch.flip(elem, [0]) for elem in g_xs]
+
+ return *[None] * 4, *g_init, *g_xs, *g_additional_inputs
+
+
+@scan_op.py_impl(DispatchKey.Autograd)
+def scan_autograd(combine_fn, init, xs, additional_inputs):
+ if not any(
+ el.requires_grad
+ for el in (tuple(init) + tuple(xs) + additional_inputs)
+ if isinstance(el, torch.Tensor)
+ ):
+ with torch._C._AutoDispatchBelowAutograd():
+ return scan_op(
+ combine_fn,
+ init,
+ xs,
+ additional_inputs,
+ )
+
+ num_leaves_init = len(init)
+ num_leaves_xs = len(xs)
+ num_additional_inputs = len(additional_inputs)
+
+ flat_out = ScanAutogradOp.apply(
+ combine_fn,
+ num_leaves_init,
+ num_leaves_xs,
+ num_additional_inputs,
+ *(tuple(init) + tuple(xs) + additional_inputs),
+ )
+ return *flat_out[:num_leaves_init], *flat_out[num_leaves_init:]


@scan_op.py_impl(ProxyTorchDispatchMode)