8000 [NJT] Inline through torch.nested.nested_tensor_from_jagged instead o… · pytorch/pytorch@cf5ca58 · GitHub
[go: up one dir, main page]

Skip to content

Commit cf5ca58

Browse files
soulitzerpytorchmergebot
authored andcommitted
[NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343)
Pull Request resolved: #124343 Approved by: https://github.com/jbschlosser
1 parent acbf888 commit cf5ca58

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

test/dynamo/test_subclasses.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,14 @@ def fn(x):
13611361
self._check_recompiles(fn, (nt,), (nt2,), False)
13621362
self._check_recompiles(fn, (nt,), (nt3,), True)
13631363

1364+
def test_inline_nested_tensor_from_jagged(self):
1365+
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
1366+
1367+
def fn(x):
1368+
return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())
1369+
1370+
torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
1371+
13641372
def _get_views(self):
13651373
# Test all cases with both an NT base and a dense base
13661374
# Subclass -> Subclass

test/profiler/test_profiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ def test_execution_trace_no_capture(self):
634634
found_root_node = True
635635
assert found_root_node
636636

637+
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
637638
def test_execution_trace_nested_tensor(self):
638639
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
639640
fp.close()

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
"torch.nn.Parameter": TorchInGraphFunctionVariable,
174174
"torch._nested_tensor_from_mask": SkipFunctionVariable,
175175
"torch._nested_from_padded": SkipFunctionVariable,
176+
"torch.nested.nested_tensor_from_jagged": UserFunctionVariable,
176177
# symbol operators implemented in Python
177178
"torch.sym_not": TorchInGraphFunctionVariable,
178179
"torch.sym_float": TorchInGraphFunctionVariable,

0 commit comments

Comments
 (0)
0