8000 Remove compile_forward test case errors by adding (the correct!) guar… · pytorch/pytorch@dfa68cb · GitHub
[go: up one dir, main page]

Skip to content

Commit dfa68cb

Browse files
committed
Remove compile_forward test case errors by adding (the correct!) guards, start working on the backward compile test failures
1 parent 2e40b3e commit dfa68cb

File tree

1 file changed

+33
-25
lines changed

1 file changed

+33
-25
lines changed

torch/nested/_internal/ops.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,36 +1795,41 @@ def select_int(func, *args, **kwargs):
17951795

17961796
if operating_on_batch:
17971797
index = new_kwargs["index"]
1798-
begin, end = inp._offsets.narrow(0, index, 2)
17991798
size = inp._values.size(inp._ragged_idx - 1)
1800-
begin = begin.item()
1801-
end = end.item()
1802-
# torch._check_is_size(begin)
1803-
# torch._check_is_size(end)
1804-
torch._check(begin >= 0)
1805-
torch._check(begin < size)
1799+
if size <= 1:
1800+
# i think this shortcut is necessary:
1801+
# when adding the guards below, test_compile_backward_select will
1802+
# try to guard (or rather test?) on the following
1803+
# Eq(s1, u8) which equals to Eq(size, length)
1804+
# which in turn is only true if inp.numel == 1
1805+
return inp._values
1806+
18061807
if inp._lengths is not None:
1807-
length = inp._lengths[index].item()
1808-
torch._check(length >= 0)
1809-
torch._check(begin + length < size)
1810-
torch._check(begin < size)
1811-
# torch._check_is_size(begin + length)
1808+
begin = inp._offsets.select(0, index)
1809+
begin = begin.item()
1810+
length = inp._lengths.select(0, index)
1811+
length = length.item()
18121812
else:
1813-
torch._check(begin >= 0)
1814-
torch._check(end >= 0)
1815-
torch._check(end >= begin)
1816-
torch._check(end < size)
1817-
length = end - begin
1818-
torch._check(length >= 0)
1819-
torch._check(begin + length == end)
1820-
torch._check(begin + length < size)
1821-
torch._check(end - length == begin)
1822-
torch._check(end - length < size)
1823-
# if tensor has no holes, we can just select from the start and end pos
1813+
begin, end = inp._offsets.narrow(0, index, 2)
1814+
length = (end - begin).item()
1815+
begin = begin.item()
1816+
1817+
# as stated above, (inp.numel() == 1) implies length == size
1818+
# but in any other case, length < size
1819+
# or do we support 0 length elements in NJTs?
1820+
1821+
# Eq(u8, u0) equals Eq(length, begin)
1822+
torch._check(begin >= 0)
1823+
torch._check(length >= 1)
1824+
torch._check(length > 0)
1825+
torch._check(length < size)
1826+
torch._check(begin + length <= size)
1827+
torch._check(begin < size)
1828+
torch._check_is_size(begin + length)
18241829
torch._check_is_size(begin)
18251830
torch._check_is_size(length)
1831+
18261832
return inp._values.narrow(inp._ragged_idx - 1, begin, length)
1827-
# return inp._values[begin:end]
18281833

18291834
if inp._lengths is not None:
18301835
raise ValueError(
@@ -2502,9 +2507,12 @@ def _nested_select_backward_default(func, *args, **kwargs):
25022507

25032508
inp = new_kwargs.pop("input")
25042509
grad_output = new_kwargs.pop("grad_output")
2510+
ragged_dim = inp._ragged_idx - 1
25052511

25062512
grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
2507-
grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
2513+
grad_input_view = grad_input.select(new_kwargs["dim"], new_kwargs["index"])
2514+
torch._check(grad_input_view.size(ragged_dim) == grad_output.size(ragged_dim))
2515+
grad_input_view.copy_(grad_output)
25082516

25092517
return grad_input
25102518

0 commit comments

Comments
 (0)
0