@@ -1748,18 +1748,17 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
     l_y_ = L_y_
     map_body_1 = self.map_body_1
     map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
         self.assertExpectedInline(
             body_graph,
             """\
 def forward(self, child : torch.Tensor, l_y_ : torch.Tensor):
-    child_1 = child[0]; child_1 = None
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )

     def test_map_multi_return(self):
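
Note on the expectation churn above (not part of the patch): fx names nodes per call target in creation order, so once the dead `child_1 = child[0]` access is gone, the first `operator.getitem` result is named `getitem` rather than `getitem_1`, and every later index shifts down by one. A minimal sketch of that naming behavior using stock `torch.fx`:

```python
import operator

import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
# Node names derive from the call target; duplicates get _1, _2, ... suffixes.
a = g.call_function(operator.getitem, (x, 0))
b = g.call_function(operator.getitem, (x, 1))
print(a.name, b.name)  # -> getitem getitem_1
```
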
@@ -1777,9 +1776,9 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]
-    getitem_2 = map_impl[1]; map_impl = None
-    return (getitem_1, getitem_2)""",
+    getitem = map_impl[0]
+    getitem_1 = map_impl[1]; map_impl = None
+    return (getitem, getitem_1)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -1811,14 +1810,14 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]
-    getitem_2 = map_impl[1]
-    getitem_3 = map_impl[2]
-    getitem_4 = map_impl[3]
-    getitem_5 = map_impl[4]
-    getitem_6 = map_impl[5]
+    getitem = map_impl[0]
+    getitem_1 = map_impl[1]
+    getitem_2 = map_impl[2]
+    getitem_3 = map_impl[3]
+    getitem_4 = map_impl[4]
+    getitem_5 = map_impl[5]
     value = map_impl[6]; map_impl = None
-    return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""",
+    return (getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, value)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -1857,8 +1856,8 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -1888,8 +1887,8 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -2279,15 +2278,12 @@ def body(x):
         mod = Module()

         mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
-        mod_for_eager = Module()

-        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        # There is graph break right when we enter body of map
-        # Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
-        self.assertEqual(len(backend.graphs), 8)
-        self.assertEqual(
-            res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        )
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "map doesn't work unless it is captured completely with torch.compile",
+        ):
+            mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))

     def test_map_side_effect(self):
         backend = EagerAndRecordGraphs()
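
Behavioral note on this hunk and the matching one in `test_map_side_effect` below: rather than falling back to partial capture (the old 8- and 9-graph expectations), dynamo now refuses to trace a `map` whose body graph-breaks. A hedged repro sketch; the import path and the print-induced graph break are assumptions, only the exception type and message come from the diff:

```python
import torch
from functorch.experimental.control_flow import map  # map HOP used by these tests


def body(x):
    print("side effect")  # triggers a graph break inside the mapped body
    return x.sin()


@torch.compile(backend="eager", fullgraph=False)
def f(xs):
    return map(body, xs)


# Expected after this change:
# torch._dynamo.exc.UncapturedHigherOrderOpError:
#     map doesn't work unless it is captured completely with torch.compile
f(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
```
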
@@ -2312,17 +2308,12 @@ def body(x):
         mod = Module()

         mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
-        mod_for_eager = Module()
-
-        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))

-        eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-
-        # Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
-        self.assertEqual(len(backend.graphs), 9)
-        self.assertEqual(res, eager)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "map doesn't work unless it is captured completely with torch.compile",
+        ):
+            mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))

     def test_wrap_subgraph_name_is_valid(self):
         backend = EagerAndRecordGraphs()
@@ -2923,7 +2914,10 @@ def inner2(x, y):
         actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
         self.assertExpectedInline(
             pprint.pformat(actual_stack),
-            """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
+            """\
+{'add': ['map_impl', 'map_impl', 'add'],
+ 'cos': ['map_impl', 'cos'],
+ 'sin': ['sin']}""",
         )

     def test_grad_source_fn_stack(self):
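
Why the expected string in the last hunk went multiline: the source_fn stack now records the `map_impl` higher-order op (one entry per nesting level) instead of `map`, and the longer names push the dict past `pprint.pformat`'s default 80-column width, so it wraps per key. A standalone check of just the formatting, using the values from the diff:

```python
import pprint

actual_stack = {
    "add": ["map_impl", "map_impl", "add"],  # add sits under two nested maps
    "cos": ["map_impl", "cos"],
    "sin": ["sin"],  # outside any map
}
# The one-line repr is 84 chars > width=80, so pprint wraps it:
# {'add': ['map_impl', 'map_impl', 'add'],
#  'cos': ['map_impl', 'cos'],
#  'sin': ['sin']}
print(pprint.pformat(actual_stack))
```
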