diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 9525508d750706..e6f0d988157df4 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -1793,10 +1793,43 @@ def select_int(func, *args, **kwargs):
         inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
     )

-    # handle batch dim slicing via unbind() for now
-    # TODO: make this more efficient
     if operating_on_batch:
-        return inp.unbind()[new_kwargs["index"]]
+        index = new_kwargs["index"]
+        size = inp._values.size(inp._ragged_idx - 1)
+        if inp.size(new_kwargs["dim"]) == 1:
+            # I think this shortcut is necessary: when the guards
+            # below are added, test_compile_backward_select will try
+            # to guard (or rather test?) on Eq(s1, u8), which is
+            # Eq(size, length), and that in turn only holds when
+            # inp.numel() == 1.
+            return inp._values
+
+        if inp._lengths is not None:
+            begin = inp._offsets.select(0, index)
+            begin = begin.item()
+            length = inp._lengths.select(0, index)
+            length = length.item()
+        else:
+            begin, end = inp._offsets.narrow(0, index, 2)
+            length = (end - begin).item()
+            begin = begin.item()
+
+        # As stated above, (inp.numel() == 1) implies length == size;
+        # in any other case length < size.
+        # Or do we support 0-length elements in NJTs?
+
+        # Eq(u8, u0) corresponds to Eq(length, begin)
+        torch._check(begin >= 0)
+        torch._check(length >= 1)
+        torch._check(length > 0)
+        torch._check(length < size)
+        torch._check(begin + length <= size)
+        torch._check(begin < size)
+        torch._check_is_size(begin + length)
+        torch._check_is_size(begin)
+        torch._check_is_size(length)
+
+        return inp._values.narrow(inp._ragged_idx - 1, begin, length)

     if inp._lengths is not None:
         raise ValueError(
@@ -2474,9 +2507,12 @@ def _nested_select_backward_default(func, *args, **kwargs):

     inp = new_kwargs.pop("input")
     grad_output = new_kwargs.pop("grad_output")

+    ragged_dim = inp._ragged_idx - 1
     grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
-    grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
+    grad_input_view = grad_input.select(new_kwargs["dim"], new_kwargs["index"])
+    torch._check(grad_input_view.size(ragged_dim) == grad_output.size(ragged_dim))
+    grad_input_view.copy_(grad_output)

     return grad_input
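
For context, here is a quick sketch (not part of the diff) of how the batch-dim select rewritten in the first hunk is reached from user code; the component shapes are made up for illustration.

import torch

# Jagged NJT with batch size 3 and ragged row counts 2, 3, 1.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)],
    layout=torch.jagged,
)

# select() along the batch dim used to go through unbind(); with this patch it
# becomes a narrow() into the packed values buffer.
row = nt.select(0, 1)
print(row.shape)      # torch.Size([3, 4])
print(nt.offsets())   # tensor([0, 2, 5, 6]); begin = 2, length = 5 - 2 = 3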
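
The narrow(0, index, 2) trick in the contiguous (lengths is None) branch can also be read in isolation. A minimal sketch, assuming a hole-free NJT where consecutive offsets imply the lengths:

import torch

offsets = torch.tensor([0, 2, 5, 6])       # batch size 3, values buffer of size 6
index = 1

# Adjacent offsets [begin, end) delimit batch item `index` in the values buffer.
begin, end = offsets.narrow(0, index, 2)   # tensor(2), tensor(5)
length = (end - begin).item()              # 3
begin = begin.item()                       # 2
print(begin, length)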
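
And a sketch of the backward path the second hunk touches: gradients for batch items that were not selected should come back as zeros. This usage is an assumption for illustration, not taken from the diff.

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)],
    layout=torch.jagged,
    requires_grad=True,
)

out = nt.select(0, 0)   # dense (2, 4) view of the first batch item
out.sum().backward()

# nt.grad is itself a jagged NJT: ones for item 0, zeros for item 1.
print(nt.grad.values())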