[NJT] Allow construction of NJT within graph using offsets from inputs by soulitzer · Pull Request #124624 · pytorch/pytorch · GitHub
Draft · wants to merge 4 commits into base: gh/soulitzer/296/base
Changes from all commits
42 changes: 41 additions & 1 deletion test/dynamo/test_subclasses.py
@@ -1361,6 +1361,47 @@ def fn(x):
         self._check_recompiles(fn, (nt,), (nt2,), False)
         self._check_recompiles(fn, (nt,), (nt3,), True)

+    def test_construct_from_jagged_with_offsets_from_inputs(self):
+        # Basic case
+        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
+        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
+
+        def fn(values, offsets):
+            return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2
+
+        values = nt.values().requires_grad_(True)
+        out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets())
+        ref_out = fn(values, nt.offsets())
+        grad, = torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),))
+        ref_grad, = torch.autograd.grad(ref_out, inputs=(values,), grad_outputs=(torch.ones_like(ref_out),))
+        self.assertEqual(out, ref_out)
+        self.assertEqual(grad, ref_grad)
+
+        # Binary op
+        def fn(values, offsets, offsets2):
+            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
+            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
+            return nt1 * nt2
+
+        out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets(), nt.offsets())
+        ref_out = fn(values, nt.offsets(), nt.offsets())
+        grad, = torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),))
+        ref_grad, = torch.autograd.grad(ref_out, inputs=(values,), grad_outputs=(torch.ones_like(ref_out),))
+        self.assertEqual(out, ref_out)
+        self.assertEqual(grad, ref_grad)
+
+        # Not only do we recompile, we also error out on the recompile with
+        # an error message mentioning data-dependent-ness.
+        with self.assertRaisesRegex(RuntimeError, "data-dependent"):
+            out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets(), nt2.offsets())
+
+        values = values.detach()
+        # Offsets which is an intermediate works without autograd
+        def fn(values, offsets):
+            return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) * 2
+
+        out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets())
+
     def test_inline_nested_tensor_from_jagged(self):
         nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

@@ -1488,7 +1529,6 @@ def f(x):
     # view. To construct this intermediate properly, we need the associated nested int
     # to be symbolic. This view is expected to fail compilation until symbolic nested ints
     # are cached onto fake offsets to solve this problem.
-    @unittest.expectedFailure
     def test_subclass_dense_subclass_dense_view(self):
         x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
         offsets2 = x.offsets().clone().detach()
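For readers unfamiliar with the API exercised by the new test, torch.nested.nested_tensor_from_jagged builds a jagged-layout NestedTensor from a dense values buffer plus an offsets tensor marking row boundaries. A minimal eager-mode sketch, with illustrative shapes chosen to match the ((2, 3, 4), 5) pattern used above:

import torch

# Three variable-length rows of width 5 packed into one buffer: 2 + 3 + 4 = 9 rows total.
values = torch.randn(9, 5)
offsets = torch.tensor([0, 2, 5, 9])  # row i spans values[offsets[i]:offsets[i + 1]]

nt = torch.nested.nested_tensor_from_jagged(values, offsets)
print([t.shape for t in nt.unbind()])  # [torch.Size([2, 5]), torch.Size([3, 5]), torch.Size([4, 5])]

# Elementwise ops such as nt * 2 act on the packed values buffer, which is what the
# compiled functions in the test rely on.
doubled = nt * 2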
1 change: 1 addition & 0 deletions test/test_fake_tensor.py
@@ -578,6 +578,7 @@ def test_same_shape_env_preserved(self):
         self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
         self.assertEqual(str(t2.size(0)), str(t1.size(0)))

+    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "creating NJT in the middle of graph fails in some cases")
     def test_jagged_fake_to_fake_preserved(self):
         from torch.nested._internal.nested_tensor import jagged_from_list

1 change: 0 additions & 1 deletion test/test_nestedtensor.py
@@ -3787,7 +3787,6 @@ def test_unbind(self, device):
         for i, t in enumerate(out):
             self.assertEqual(t, tensor_list[i])

-    @xfailIfTorchDynamo
     def test_layer_norm_2(self, device):
         test_tensor_list = self._get_list_for_jagged_tensor(
             ((2, 3, 4), 3), device=device, requires_grad=True
23 changes: 18 additions & 5 deletions torch/nested/_internal/nested_tensor.py
@@ -5,6 +5,8 @@
 from torch._prims_common import is_expandable_to
 from torch.fx.experimental.symbolic_shapes import has_free_symbols
 from torch.utils.weak import WeakTensorKeyDictionary
+from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.functional_tensor import FunctionalTensor
 from typing import *  # noqa: F403

 _tensor_id_counter = 0
@@ -15,11 +17,19 @@ def get_tensor_symint(tensor, *, coeff=1):
     global _tensor_id_counter
     tensor_symint = _tensor_symint_registry.get(tensor)
     if tensor_symint is None:
-        tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
-        _tensor_id_counter += 1
-        _tensor_symint_registry[tensor] = tensor_symint
-    return tensor_symint
+        if isinstance(tensor, FunctionalTensor):
+            tensor = torch._from_functional_tensor(tensor.elem)
+            return get_tensor_symint(tensor, coeff=coeff)
+        elif isinstance(tensor, FakeTensor):
+            shape_env = tensor.fake_mode.shape_env
+            tensor_symint = shape_env.create_unbacked_symint()
+            # Do we need to constrain as size?
Comment on lines +23 to +26 (Contributor):
why would this path ever be hit if we manually update the registry for fake tensors during tensor_unflatten?
+        else:
+            tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
+            _tensor_id_counter += 1
+        _tensor_symint_registry[tensor] = tensor_symint  # cache the symint

+    return tensor_symint

 # SDPA metadata; max / min seqlens are needed for e.g. flash
 def _get_sdpa_extreme_seqlen(func, tensor):
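Regarding the review question above: the FakeTensor branch can be reached when the offsets tensor is created inside the compiled region, so no __tensor_unflatten__ call has had a chance to put it into _tensor_symint_registry. A minimal sketch of that situation, mirroring the offsets.clone() case in the new test (the comments describe assumed tracing behavior, not verified internals):

import torch

values = torch.randn(9, 5)
offsets = torch.tensor([0, 2, 5, 9])

@torch.compile(fullgraph=True, backend="aot_eager")
def fn(values, offsets):
    # offsets.clone() is a fresh intermediate during tracing; assuming it was never
    # seen by __tensor_unflatten__, the registry lookup misses and get_tensor_symint
    # falls into the FakeTensor branch, minting an unbacked symint for the ragged dim.
    return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) * 2

out = fn(values, offsets)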
@@ -190,7 +200,10 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
         # during aot autograd, FunctionalTensors are not fake but hold
         # symbolic sizes.
         ragged_source = offsets if lengths is None else lengths
-        if has_free_symbols(ragged_source) or has_free_symbols(values):
+
+        # If we are constructing a NestedTensor from within the graph, the
+        # values may not be dynamic. TODO: what would happen in this case?
+        if isinstance(values, FakeTensor) or isinstance(values, FunctionalTensor):
             # Associate offsets or lengths (possibly fake, possibly functionalized)
             # with the ragged_size.
             ragged_size = outer_size[ragged_idx]
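The unbacked symint minted in the FakeTensor branch has no concrete hint, which is presumably what surfaces as the "data-dependent" error the new test expects when NJTs built from two distinct offsets tensors are mixed. A rough illustration of that property using internal APIs (assumed usage, for intuition only):

from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u = shape_env.create_unbacked_symint()
print(u)  # a symbol such as u0, with no backing hint

# Any decision that needs a concrete value, e.g. bool(u > 0), would raise a
# data-dependent guard error instead of silently picking a value.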