@@ -1795,36 +1795,41 @@ def select_int(func, *args, **kwargs):
1795
1795
1796
1796
if operating_on_batch :
1797
1797
index = new_kwargs ["index" ]
1798
- begin , end = inp ._offsets .narrow (0 , index , 2 )
1799
1798
size = inp ._values .size (inp ._ragged_idx - 1 )
1800
- begin = begin .item ()
1801
- end = end .item ()
1802
- # torch._check_is_size(begin)
1803
- # torch._check_is_size(end)
1804
- torch ._check (begin >= 0 )
1805
- torch ._check (begin < size )
1799
+ if size <= 1 :
1800
+ # i think this shortcut is necessary:
1801
+ # when adding the guards below, test_compile_backward_select will
1802
+ # try to guard (or rather test?) on the following
1803
+ # Eq(s1, u8) which equals to Eq(size, length)
1804
+ # which in turn is only true if inp.numel == 1
1805
+ return inp ._values
1806
+
1806
1807
if inp ._lengths is not None :
1807
- length = inp ._lengths [index ].item ()
1808
- torch ._check (length >= 0 )
1809
- torch ._check (begin + length < size )
1810
- torch ._check (begin < size )
1811
- # torch._check_is_size(begin + length)
1808
+ begin = inp ._offsets .select (0 , index )
1809
+ begin = begin .item ()
1810
+ length = <
10000
span class=pl-s1>inp._lengths .select (0 , index )
1811
+ length = length .item ()
1812
1812
else :
1813
- torch ._check (begin >= 0 )
1814
- torch ._check (end >= 0 )
1815
- torch ._check (end >= begin )
1816
- torch ._check (end < size )
1817
- length = end - begin
1818
- torch ._check (length >= 0 )
1819
- torch ._check (begin + length == end )
1820
- torch ._check (begin + length < size )
1821
- torch ._check (end - length == begin )
1822
- torch ._check (end - length < size )
1823
- # if tensor has no holes, we can just select from the start and end pos
1813
+ begin , end = inp ._offsets .narrow (0 , index , 2 )
1814
+ length = (end - begin ).item ()
1815
+ begin = begin .item ()
1816
+
1817
+ # as stated above, (inp.numel() == 1) implies length == size
1818
+ # but in any other case, length < size
1819
+ # or do we support 0 length elements in NJTs?
1820
+
1821
+ # Eq(u8, u0) equals Eq(length, begin)
1822
+ torch ._check (begin >= 0 )
1823
+ torch ._check (length >= 1 )
1824
+ torch ._check (length > 0 )
1825
+ torch ._check (length < size )
1826
+ torch ._check (begin + length <= size )
1827
+ torch ._check (begin < size )
1828
+ torch ._check_is_size (begin + length )
1824
1829
torch ._check_is_size (begin )
1825
1830
torch ._check_is_size (length )
1831
+
1826
1832
return inp ._values .narrow (inp ._ragged_idx - 1 , begin , length )
1827
- # return inp._values[begin:end]
1828
1833
1829
1834
if inp ._lengths is not None :
1830
1835
raise ValueError (
@@ -2502,9 +2507,12 @@ def _nested_select_backward_default(func, *args, **kwargs):
2502
2507
2503
2508
inp = new_kwargs .pop ("input" )
2504
2509
grad_output = new_kwargs .pop ("grad_output" )
2510
+ ragged_dim = inp ._ragged_idx - 1
2505
2511
2506
2512
grad_input = torch .zeros_like (inp , dtype = grad_output .dtype )
2507
- grad_input .select (new_kwargs ["dim" ], new_kwargs ["index" ]).copy_ (grad_output )
2513
+ grad_input_view = grad_input .select (new_kwargs ["dim" ], new_kwargs ["index" ])
2514
+ torch ._check (grad_input_view .size (ragged_dim ) == grad_output .size (ragged_dim ))
2515
+ grad_input_view .copy_ (grad_output )
2508
2516
2509
2517
return grad_input
2510
2518
0 commit comments