diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 8005d6e3a28c44..ec6134cb636377 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1361,6 +1361,47 @@ def fn(x): self._check_recompiles(fn, (nt,), (nt2,), False) self._check_recompiles(fn, (nt,), (nt3,), True) + def test_construct_from_jagged_with_offsets_from_inputs(self): + # Basic case + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + + def fn(values, offsets): + return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2 + + values = nt.values().requires_grad_(True) + out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets()) + ref_out = fn(values, nt.offsets()) + grad, = torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),)) + ref_grad, = torch.autograd.grad(ref_out, inputs=(values,), grad_outputs=(torch.ones_like(ref_out),)) + self.assertEqual(out, ref_out) + self.assertEqual(grad, ref_grad) + + # Binary op + def fn(values, offsets, offsets2): + nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) + nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) + return nt1 * nt2 + + out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets(), nt.offsets()) + ref_out = fn(values, nt.offsets(), nt.offsets()) + grad, = torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),)) + ref_grad, = torch.autograd.grad(ref_out, inputs=(values,), grad_outputs=(torch.ones_like(ref_out),)) + self.assertEqual(out, ref_out) + self.assertEqual(grad, ref_grad) + + # Not only do we recompile, we also error out on the recompile with + # an error message mentioning data-dependent-ness. + with self.assertRaisesRegex(RuntimeError, "data-dependent"): + out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets(), nt2.offsets()) + + values = values.detach() + # Offsets which is an intermediate works without autograd + def fn(values, offsets): + return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) * 2 + + out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets()) + def test_inline_nested_tensor_from_jagged(self): nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) @@ -1488,7 +1529,6 @@ def f(x): # view. To construct this intermediate properly, we need the associated nested int # to be symbolic. This view is expected to fail compilation until symbolic nested ints # are cached onto fake offsets to solve this problem. - @unittest.expectedFailure def test_subclass_dense_subclass_dense_view(self): x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() offsets2 = x.offsets().clone().detach() diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 1ffb0a6cb3ed16..cda49bda10433e 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -578,6 +578,7 @@ def test_same_shape_env_preserved(self): self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env) self.assertEqual(str(t2.size(0)), str(t1.size(0))) + @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "creating NJT in the middle of graph fails in some cases") def test_jagged_fake_to_fake_preserved(self): from torch.nested._internal.nested_tensor import jagged_from_list diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 6b7f7aa7ca13f9..a1daef1befc982 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3787,7 +3787,6 @@ def test_unbind(self, device): for i, t in enumerate(out): self.assertEqual(t, tensor_list[i]) - @xfailIfTorchDynamo def test_layer_norm_2(self, device): test_tensor_list = self._get_list_for_jagged_tensor( ((2, 3, 4), 3), device=device, requires_grad=True diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 5494633b09bbe6..6d432d8339b96b 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -5,6 +5,8 @@ from torch._prims_common import is_expandable_to from torch.fx.experimental.symbolic_shapes import has_free_symbols from torch.utils.weak import WeakTensorKeyDictionary +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor from typing import * # noqa: F403 _tensor_id_counter = 0 @@ -15,11 +17,19 @@ def get_tensor_symint(tensor, *, coeff=1): global _tensor_id_counter tensor_symint = _tensor_symint_registry.get(tensor) if tensor_symint is None: - tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff) - _tensor_id_counter += 1 - _tensor_symint_registry[tensor] = tensor_symint - return tensor_symint + if isinstance(tensor, FunctionalTensor): + tensor = torch._from_functional_tensor(tensor.elem) + return get_tensor_symint(tensor, coeff=coeff) + elif isinstance(tensor, FakeTensor): + shape_env = tensor.fake_mode.shape_env + tensor_symint = shape_env.create_unbacked_symint() + # Do we need to constrain as size? + else: + tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff) + _tensor_id_counter += 1 + _tensor_symint_registry[tensor] = tensor_symint # cache the symint + return tensor_symint # SDPA metadata; max / min seqlens are needed for e.g. flash def _get_sdpa_extreme_seqlen(func, tensor): @@ -190,7 +200,10 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): # during aot autograd, FunctionalTensors are not fake but hold # symbolic sizes. ragged_source = offsets if lengths is None else lengths - if has_free_symbols(ragged_source) or has_free_symbols(values): + + # If we are constructing a NestedTensor from within the graph, the + # values may not be dynamic. TODO: what would happen in this case? + if isinstance(values, FakeTensor) or isinstance(values, FunctionalTensor): # Associate offsets or lengths (possibly fake, possibly functionalized) # with the ragged_size. ragged_size = outer_size[ragged_idx]