8000 [map] always turn on dynamo for map (#152041) · pytorch/pytorch@ceb009b · GitHub
[go: up one dir, main page]

Skip to content

Commit ceb009b

Browse files
ydwu4pytorchmergebot
authored andcommitted
[map] always turn on dynamo for map (#152041)
Summary: X-link: pytorch/executorch#10409 Reland D72896450 Make map consistent with other control flow ops. After the change, map is able to support accessing closures in the map fn. Test Plan: See existing tests. Reviewed By: zou3519 Differential Revision: D73138427 Pull Request resolved: #152041 Approved by: https://github.com/zou3519
1 parent c5b4dc9 commit ceb009b

29 files changed

+232
-179
lines changed

test/dynamo/test_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1977,7 +1977,7 @@ def body(x):
19771977
xs = torch.randn(0, 2)
19781978
with self.assertRaisesRegex(
19791979
torch._dynamo.exc.Unsupported,
1980-
"zero-sized tensor",
1980+
"Observed exception",
19811981
):
19821982
torch._dynamo.export(mod)(xs)
19831983

test/dynamo/test_higher_order_ops.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,18 +1748,17 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
17481748
l_y_ = L_y_
17491749
map_body_1 = self.map_body_1
17501750
map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
1751-
getitem_1 = map_impl[0]; map_impl = None
1752-
return (getitem_1,)""",
1751+
getitem = map_impl[0]; map_impl = None
1752+
return (getitem,)""",
17531753
)
17541754
self.assertExpectedInline(
17551755
body_graph,
17561756
"""\
17571757
def forward(self, child : torch.Tensor, l_y_ : torch.Tensor):
1758-
child_1 = child[0]; child_1 = None
17591758
map_body_0 = self.map_body_0
17601759
map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None
1761-
getitem_1 = map_impl[0]; map_impl = None
1762-
return (getitem_1,)""",
1760+
getitem = map_impl[0]; map_impl = None
1761+
return (getitem,)""",
17631762
)
17641763

17651764
def test_map_multi_return(self):
@@ -1777,9 +1776,9 @@ def forward(self, L_x_ : torch.Tensor):
17771776
l_x_ = L_x_
17781777
map_body_0 = self.map_body_0
17791778
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
1780-
getitem_1 = map_impl[0]
1781-
getitem_2 = map_impl[1]; map_impl = None
1782-
return (getitem_1, getitem_2)""",
1779+
getitem = map_impl[0]
1780+
getitem_1 = map_impl[1]; map_impl = None
1781+
return (getitem, getitem_1)""",
17831782
)
17841783
self.assertExpectedInline(
17851784
body_graph,
@@ -1811,14 +1810,14 @@ def forward(self, L_x_ : torch.Tensor):
18111810
l_x_ = L_x_
18121811
map_body_0 = self.map_body_0
18131812
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
1814-
getitem_1 = map_impl[0]
1815-
getitem_2 = map_impl[1]
1816-
getitem_3 = map_impl[2]
1817-
getitem_4 = map_impl[3]
1818-
getitem_5 = map_impl[4]
1819-
getitem_6 = map_impl[5]
1813+
getitem = map_impl[0]
1814+
getitem_1 = map_impl[1]
1815+
getitem_2 = map_impl[2]
1816+
getitem_3 = map_impl[3]
1817+
getitem_4 = map_impl[4]
1818+
getitem_5 = map_impl[5]
18201819
value = map_impl[6]; map_impl = None
1821-
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""",
1820+
return (getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, value)""",
18221821
)
18231822
self.assertExpectedInline(
18241823
body_graph,
@@ -1857,8 +1856,8 @@ def forward(self, L_x_ : torch.Tensor):
18571856
l_x_ = L_x_
18581857
map_body_0 = self.map_body_0
18591858
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
1860-
getitem_1 = map_impl[0]; map_impl = None
1861-
return (getitem_1,)""",
1859+
getitem = map_impl[0]; map_impl = None
1860+
return (getitem,)""",
18621861
)
18631862
self.assertExpectedInline(
18641863
body_graph,
@@ -1888,8 +1887,8 @@ def forward(self, L_x_ : torch.Tensor):
18881887
l_x_ = L_x_
18891888
map_body_0 = self.map_body_0
18901889
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
1891-
getitem_1 = map_impl[0]; map_impl = None
1892-
return (getitem_1,)""",
1890+
getitem = map_impl[0]; map_impl = None
1891+
return (getitem,)""",
18931892
)
18941893
self.assertExpectedInline(
18951894
body_graph,
@@ -2279,15 +2278,12 @@ def body(x):
22792278
mod = Module()
22802279

22812280
mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
2282-
mod_for_eager = Module()
22832281

2284-
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2285-
# There is graph break right when we enter body of map
2286-
# Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
2287-
self.assertEqual(len(backend.graphs), 8)
2288-
self.assertEqual(
2289-
res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2290-
)
2282+
with self.assertRaisesRegex(
2283+
torch._dynamo.exc.UncapturedHigherOrderOpError,
2284+
"map doesn't work unless it is captured completely with torch.compile",
2285+
):
2286+
mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
22912287

22922288
def test_map_side_effect(self):
22932289
backend = EagerAndRecordGraphs()
@@ -2312,17 +2308,12 @@ def body(x):
23122308
mod = Module()
23132309

23142310
mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
2315-
mod_for_eager = Module()
2316-
2317-
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2318-
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
23192311

2320-
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2321-
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2322-
2323-
# Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
2324-
self.assertEqual(len(backend.graphs), 9)
2325-
self.assertEqual(res, eager)
2312+
with self.assertRaisesRegex(
2313+
torch._dynamo.exc.UncapturedHigherOrderOpError,
2314+
"map doesn't work unless it is captured completely with torch.compile",
2315+
):
2316+
mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
23262317

23272318
def test_wrap_subgraph_name_is_valid(self):
23282319
backend = EagerAndRecordGraphs()
@@ -2923,7 +2914,10 @@ def inner2(x, y):
29232914
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
29242915
self.assertExpectedInline(
29252916
pprint.pformat(actual_stack),
2926-
"""{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
2917+
"""\
2918+
{'add': ['map_impl', 'map_impl', 'add'],
2919+
'cos': ['map_impl', 'cos'],
2920+
'sin': ['sin']}""",
29272921
)
29282922

29292923
def test_grad_source_fn_stack(self):

test/dynamo/test_misc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5825,11 +5825,15 @@ def body(x):
58255825

58265826
error_message = ""
58275827
if torch._dynamo.config.inline_inbuilt_nn_modules:
5828-
error_message = r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)"
5828+
error_message = (
5829+
"map doesn't work unless it is captured completely with torch.compile"
5830+
)
58295831
else:
58305832
error_message = "Can't inplace modify module params/buffers"
58315833

5832-
with self.assertRaisesRegex(Unsupported, error_message):
5834+
with self.assertRaisesRegex(
5835+
torch._dynamo.exc.UncapturedHigherOrderOpError, error_message
5836+
):
58335837
opt_fn = torch.compile(mod, backend="eager", fullgraph=True)
58345838
opt_fn(torch.randn(3, 2))
58355839

test/dynamo_expected_failures/ExportTests.test_map_cond_param_buffer_lifted

Whitespace-only changes.

test/dynamo_expected_failures/FuncTorchHigherOrderOpTests.test_vmap_recompile_with_randomness

Whitespace-only changes.

test/dynamo_expected_failures/TestControlFlow.test_map_autograd_nested_list

Whitespace-only changes.

test/dynamo_expected_failures/TestControlFlow.test_map_autograd_no_grad_output

Whitespace-only changes.

test/dynamo_expected_failures/TestControlFlow.test_map_dict_in_out

Whitespace-only changes.

test/dynamo_expected_failures/TestControlFlow.test_map_list_in_out

Whitespace-only changes.

test/dynamo_expected_failures/TestControlFlowTraced.test_map_functionalized

Whitespace-only changes.

0 commit comments

Comments
 (0)
0