Update on "[NJT] Allow construction of NJT within graph using offsets from inputs" · pytorch/pytorch@13cc6c8

Commit 13cc6c8

Update on "[NJT] Allow construction of NJT within graph using offsets from inputs"
Creating symbolic nested ints within the graph is difficult. Using unbacked symints should solve the most important(?) cases in the meantime. See #118446

Known gaps:
- creating NJT from intermediate offsets (offsets created within the graph, as opposed to offsets passed in as inputs)
- when the same offsets tensor is also passed in as an input to the graph. We are not smart enough to realize that the offsets from that input are the same, and therefore fail when the sizes are compared ("s0 cannot be compared with u0")

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
1 parent 6c758af commit 13cc6c8
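
As a rough illustration of the pattern this commit enables, here is a minimal sketch of constructing a jagged NestedTensor inside a compiled graph from an offsets tensor passed in as a graph input. It assumes the public torch.nested.nested_tensor_from_jagged API from newer PyTorch releases; the tests in this commit construct the NJT via internal helpers instead.

import torch

def fn(values, offsets):
    # The NJT is built *within* the graph; since offsets is a graph input,
    # the ragged dimension is represented with an unbacked symint (u0).
    nt = torch.nested.nested_tensor_from_jagged(values, offsets)
    return nt.sin().values()

values = torch.randn(9, 3, requires_grad=True)  # packed ragged values
offsets = torch.tensor([0, 2, 5, 9])            # three ragged components
out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, offsets)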

File tree

3 files changed: +11 −4 lines

test/dynamo/test_subclasses.py

Lines changed: 10 additions & 3 deletions

@@ -1371,7 +1371,11 @@ def fn(values, offsets):
 
         values = nt.values().requires_grad_(True)
         out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets())
-        torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),))
+        ref_out = fn(values, nt.offsets())
+        grad, = torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),))
+        ref_grad, = torch.autograd.grad(ref_out, inputs=(values,), grad_outputs=(torch.ones_like(ref_out),))
+        self.assertEqual(out, ref_out)
+        self.assertEqual(grad, ref_grad)
 
         # Binary op
         def fn(values, offsets, offsets2):
@@ -1380,7 +1384,11 @@ def fn(values, offsets, offsets2):
             return nt1 * nt2
 
         out = torch.compile(fn, fullgraph=True, backend="aot_eager")(values, nt.offsets(), nt.offsets())
-        torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),))
+        ref_out = fn(values, nt.offsets(), nt.offsets())
+        grad, = torch.autograd.grad(out, inputs=(values,), grad_outputs=(torch.ones_like(out),))
+        ref_grad, = torch.autograd.grad(ref_out, inputs=(values,), grad_outputs=(torch.ones_like(ref_out),))
+        self.assertEqual(out, ref_out)
+        self.assertEqual(grad, ref_grad)
 
         # Not only do we recompile, we also error out on the recompile with
         # an error message mentioning data-dependent-ness.
@@ -1521,7 +1529,6 @@ def f(x):
     # view. To construct this intermediate properly, we need the associated nested int
     # to be symbolic. This view is expected to fail compilation until symbolic nested ints
     # are cached onto fake offsets to solve this problem.
-    @unittest.expectedFailure
     def test_subclass_dense_subclass_dense_view(self):
         x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
         offsets2 = x.offsets().clone().detach()
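
For contrast with the tests above, here is a hypothetical sketch of the second known gap from the commit message: an NJT built from the same offsets tensor is also passed in as a graph input, so its ragged size is a backed symint (s0) while the in-graph construction yields an unbacked one (u0), and the size comparison in the binary op fails. The nested_tensor_from_jagged API and all names here are assumptions, not the test code above.

import torch

values = torch.randn(9, 3)
offsets = torch.tensor([0, 2, 5, 9])
nt = torch.nested.nested_tensor_from_jagged(values, offsets)

def fn(values, offsets, nt_in):
    # nt2's ragged size is unbacked (u0); nt_in's is backed (s0), and we
    # are not smart enough to see they come from the same offsets tensor.
    nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
    return nt2 * nt_in

# Per the commit message, this is expected to fail with
# "s0 cannot be compared with u0":
# torch.compile(fn, fullgraph=True, backend="aot_eager")(values, offsets, nt)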

test/test_fake_tensor.py

Lines changed: 1 addition & 0 deletions

@@ -578,6 +578,7 @@ def test_same_shape_env_preserved(self):
         self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
         self.assertEqual(str(t2.size(0)), str(t1.size(0)))
 
+    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "creating NJT in the middle of graph fails in some cases")
     def test_jagged_fake_to_fake_preserved(self):
         from torch.nested._internal.nested_tensor import jagged_from_list
 
test/test_nestedtensor.py

Lines changed: 0 additions & 1 deletion

@@ -3787,7 +3787,6 @@ def test_unbind(self, device):
         for i, t in enumerate(out):
             self.assertEqual(t, tensor_list[i])
 
-    @xfailIfTorchDynamo
     def test_layer_norm_2(self, device):
         test_tensor_list = self._get_list_for_jagged_tensor(
             ((2, 3, 4), 3), device=device, requires_grad=True

0 commit comments