From 8063269cae5cc2d445c6d0ab5deb0feef90cce41 Mon Sep 17 00:00:00 2001 From: fleonce Date: Tue, 4 Mar 2025 23:22:06 +0100 Subject: [PATCH 1/5] Implement a faster access to individual elements of jagged nested tensors --- torch/nested/_internal/ops.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 9525508d750706..8cf7d677046b8e 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1793,10 +1793,15 @@ def select_int(func, *args, **kwargs): inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True ) - # handle batch dim slicing via unbind() for now - # TODO: make this more efficient if operating_on_batch: - return inp.unbind()[new_kwargs["index"]] + index = new_kwargs["index"] + begin, end = inp._offsets[[index, index+1]] + if inp._lengths is not None: + # if the tensor has a hole, we must include the size of the jagged dim for this element + index_len = inp._lengths[index] + return inp._values[begin:end, :index_len] + # if tensor has no holes, we can just select from the start and end pos + return inp._values[begin:end] if inp._lengths is not None: raise ValueError( From 8ea0a8a9d31bb78085cd974c36f779dab76c5316 Mon Sep 17 00:00:00 2001 From: fleonce Date: Mon, 10 Mar 2025 08:59:59 +0100 Subject: [PATCH 2/5] Fix non-compile test cases by using narrow on the appropriate ragged dim --- torch/nested/_internal/ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8cf7d677046b8e..8e06507794780b 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1797,11 +1797,12 @@ def select_int(func, *args, **kwargs): index = new_kwargs["index"] begin, end = inp._offsets[[index, index+1]] if inp._lengths is not None: - # if the tensor has a hole, we must include the size of the jagged dim for this element - index_len = inp._lengths[index] - return inp._values[begin:end, :index_len] + length = inp._lengths[index] + else: + length = end - begin # if tensor has no holes, we can just select from the start and end pos - return inp._values[begin:end] + return inp._values.narrow(inp._ragged_idx - 1, begin, length) +# return inp._values[begin:end] if inp._lengths is not None: raise ValueError( From 2e40b3e7d938b29bc328c2797f78d323a2cdb364 Mon Sep 17 00:00:00 2001 From: fleonce Date: Mon, 10 Mar 2025 13:50:37 +0100 Subject: [PATCH 3/5] Adding more guards for now, should probably be removed later --- torch/nested/_internal/ops.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8e06507794780b..357261ac4bcf6a 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1795,12 +1795,34 @@ def select_int(func, *args, **kwargs): if operating_on_batch: index = new_kwargs["index"] - begin, end = inp._offsets[[index, index+1]] + begin, end = inp._offsets.narrow(0, index, 2) + size = inp._values.size(inp._ragged_idx - 1) + begin = begin.item() + end = end.item() +# torch._check_is_size(begin) +# torch._check_is_size(end) + torch._check(begin >= 0) + torch._check(begin < size) if inp._lengths is not None: - length = inp._lengths[index] + length = inp._lengths[index].item() + torch._check(length >= 0) + torch._check(begin + length < size) + torch._check(begin < size) +# torch._check_is_size(begin + length) else: + torch._check(begin >= 0) + torch._check(end >= 0) + torch._check(end >= begin) + torch._check(end < size) length = end - begin + torch._check(length >= 0) + torch._check(begin + length == end) + torch._check(begin + length < size) + torch._check(end - length == begin) + torch._check(end - length < size) # if tensor has no holes, we can just select from the start and end pos + torch._check_is_size(begin) + torch._check_is_size(length) return inp._values.narrow(inp._ragged_idx - 1, begin, length) # return inp._values[begin:end] From dfa68cbf044f85b3288a30e8a6f3184a048df4f6 Mon Sep 17 00:00:00 2001 From: fleonce Date: Tue, 11 Mar 2025 00:25:00 +0100 Subject: [PATCH 4/5] Remove compile_forward test case errors by adding (the correct!) guards, start working on the backward compile test failures --- torch/nested/_internal/ops.py | 58 ++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 357261ac4bcf6a..f9e9846a458dd2 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1795,36 +1795,41 @@ def select_int(func, *args, **kwargs): if operating_on_batch: index = new_kwargs["index"] - begin, end = inp._offsets.narrow(0, index, 2) size = inp._values.size(inp._ragged_idx - 1) - begin = begin.item() - end = end.item() -# torch._check_is_size(begin) -# torch._check_is_size(end) - torch._check(begin >= 0) - torch._check(begin < size) + if size <= 1: + # i think this shortcut is necessary: + # when adding the guards below, test_compile_backward_select will + # try to guard (or rather test?) on the following + # Eq(s1, u8) which equals to Eq(size, length) + # which in turn is only true if inp.numel == 1 + return inp._values + if inp._lengths is not None: - length = inp._lengths[index].item() - torch._check(length >= 0) - torch._check(begin + length < size) - torch._check(begin < size) -# torch._check_is_size(begin + length) + begin = inp._offsets.select(0, index) + begin = begin.item() + length = inp._lengths.select(0, index) + length = length.item() else: - torch._check(begin >= 0) - torch._check(end >= 0) - torch._check(end >= begin) - torch._check(end < size) - length = end - begin - torch._check(length >= 0) - torch._check(begin + length == end) - torch._check(begin + length < size) - torch._check(end - length == begin) - torch._check(end - length < size) - # if tensor has no holes, we can just select from the start and end pos + begin, end = inp._offsets.narrow(0, index, 2) + length = (end - begin).item() + begin = begin.item() + + # as stated above, (inp.numel() == 1) implies length == size + # but in any other case, length < size + # or do we support 0 length elements in NJTs? + + # Eq(u8, u0) equals Eq(length, begin) + torch._check(begin >= 0) + torch._check(length >= 1) + torch._check(length > 0) + torch._check(length < size) + torch._check(begin + length <= size) + torch._check(begin < size) + torch._check_is_size(begin + length) torch._check_is_size(begin) torch._check_is_size(length) + return inp._values.narrow(inp._ragged_idx - 1, begin, length) -# return inp._values[begin:end] if inp._lengths is not None: raise ValueError( @@ -2502,9 +2507,12 @@ def _nested_select_backward_default(func, *args, **kwargs): inp = new_kwargs.pop("input") grad_output = new_kwargs.pop("grad_output") + ragged_dim = inp._ragged_idx - 1 grad_input = torch.zeros_like(inp, dtype=grad_output.dtype) - grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output) + grad_input_view = grad_input.select(new_kwargs["dim"], new_kwargs["index"]) + torch._check(grad_input_view.size(ragged_dim) == grad_output.size(ragged_dim)) + grad_input_view.copy_(grad_output) return grad_input From 5dbff0fa68c8289e84b8632d336d227291dba0b0 Mon Sep 17 00:00:00 2001 From: fleonce Date: Tue, 11 Mar 2025 08:41:29 +0100 Subject: [PATCH 5/5] Use the correct size to check for size==1 (use the input tensor, not the size of the ragged dim) --- torch/nested/_internal/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index f9e9846a458dd2..e6f0d988157df4 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1796,7 +1796,7 @@ def select_int(func, *args, **kwargs): if operating_on_batch: index = new_kwargs["index"] size = inp._values.size(inp._ragged_idx - 1) - if size <= 1: + if inp.size(new_kwargs["dim"]) == 1: # i think this shortcut is necessary: # when adding the guards below, test_compile_backward_select will # try to guard (or rather test?) on the following