Implement fast access to individual elements of jagged nested tensors by fleonce · Pull Request #148497 · pytorch/pytorch
Open · wants to merge 5 commits into main

Conversation

@fleonce (Contributor) commented Mar 4, 2025

I removed the dependency on tensor.unbind() discussed in #148379 and replaced it with basic indexing ops on the values tensor based on the inputs.

Feedback would be greatly appreciated. I'm not sure I got the part with the lengths right; I wasn't able to find much documentation on jagged tensors, so I hope I understood NestedTensor._lengths correctly.

Fixes #148379
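
For context, a rough sketch of the idea using the public jagged-layout API (illustrative only; the actual change lives in the NJT select implementation):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)

# Before: unbind() materializes views for all constituents,
# even though select() only needs one of them.
elem_slow = nt.unbind()[1]

# After: slice the packed values tensor directly, using the
# offsets that delimit each constituent along the ragged dim.
begin, end = nt.offsets()[1].item(), nt.offsets()[2].item()
elem_fast = nt.values()[begin:end]

assert torch.equal(elem_slow, elem_fast)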

pytorch-bot (bot) commented Mar 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148497

Note: Links to docs will display an error until the docs builds have been completed.

❌ 15 New Failures, 1 Unrelated Failure

As of commit 5dbff0f with merge base f30776c:

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fleonce marked this pull request as draft March 4, 2025 22:27
@jbschlosser (Contributor) left a comment

Hey @fleonce, thanks for the contribution!

Sorry for the lack of docs around _lengths - it's a more recent addition to allow for "non-contiguous with holes" views. Op support for that type of NJT is certainly less complete than for contiguous NJTs.

The idea is that _offsets define the start points within the packed dimension and _lengths define the lengths. Together, this is enough information to define a ragged structure with holes.
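
As a concrete illustration with plain tensors (not the real NJT internals, just the layout convention described above):

import torch

# values packed along dim 0 for three ragged elements
values = torch.arange(20).reshape(10, 2)

# contiguous case: element i spans offsets[i] : offsets[i + 1]
offsets = torch.tensor([0, 3, 7, 10])

# "with holes": lengths[i] can be shorter than the gap between
# consecutive offsets, leaving unused rows in between
lengths = torch.tensor([2, 3, 3])

# element 1 covers rows 3..5; row 6 is a hole
elem_1 = values[offsets[1] : offsets[1] + lengths[1]]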

I added a couple of comments for correctness, and I suggest running the pre-existing tests locally to verify the refactor. Thanks again!

index_len = inp._lengths[index]
return inp._values[begin:end, :index_len]
# if tensor has no holes, we can just select from the start and end pos
return inp._values[begin:end]
jbschlosser (Contributor)

Note that it's possible that the packed dim is not dim=0 of values. An NJT's _ragged_idx specifies which dim is ragged, and the dim at _ragged_idx - 1 is the packed dim within values.

The indexing here should ideally take this into account; you could probably use something like inp._values.narrow() on dim _ragged_idx - 1 to get what you need.
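
A minimal sketch of that suggestion (assuming begin and end are the offset-derived bounds from the surrounding code):

# _values may be packed along a dim other than 0; _ragged_idx - 1
# gives the packed dim within _values, and narrow() returns a view
ragged_dim = inp._ragged_idx - 1
return inp._values.narrow(ragged_dim, begin, end - begin)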

fleonce (Contributor Author)

I totally missed that; I'll look into it, thanks for pointing it out! I didn't know narrow() existed, I'll give it a try.

if inp._lengths is not None:
# if the tensor has a hole, we must include the size of the jagged dim for this element
index_len = inp._lengths[index]
return inp._values[begin:end, :index_len]
jbschlosser (Contributor)

I think this isn't quite correct: you want the end to be defined by the lengths. So offsets[index] defines the beginning and offsets[index] + lengths[index] defines the end.
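
In code, the corrected bounds could look something like this (sketch only):

begin = inp._offsets[index].item()
if inp._lengths is not None:
    # non-contiguous with holes: the length, not the next offset, defines the end
    length = inp._lengths[index].item()
else:
    length = inp._offsets[index + 1].item() - begin
return inp._values.narrow(inp._ragged_idx - 1, begin, length)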

btw there are pre-existing tests that you can use to verify correctness for this refactor:

python test/test_nestedtensor.py -k test_forward_select
python test/test_nestedtensor.py -k test_backward_select
python test/test_nestedtensor.py -k test_compile_forward_select
python test/test_nestedtensor.py -k test_compile_backward_select

These should all pass, or if there are reasonable limitations, we should define the appropriate xfails / skips here.

fleonce (Contributor Author)

Thank you, I will update the code accordingly.

@jbschlosser added the topic: performance and release notes: nested tensor labels Mar 7, 2025
@fleonce marked this pull request as ready for review March 10, 2025 08:00
@fleonce (Contributor Author) commented Mar 10, 2025

I managed to eliminate the non-torch.compile test errors; however, I'm not sure how to approach the compile errors. All I'm seeing is something along these lines:

torch._dynamo.exc.UserError: Could not extract specialized integer from data-dependent expression u1 (unhinted: u1).  (Size-like symbols: none)

Caused by: return op_fn(*args, **kwargs)  # test/test_nestedtensor.py:8687 in f (_ops.py:799 in decompose)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u1"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "test/test_nestedtensor.py", line 8687, in f
    return op_fn(*args, **kwargs)

@fleonce (Contributor Author) commented Mar 10, 2025

I'm starting to see where the issue lies, but I'm still confused: if I guard all possible conditions for begin, end, and length

begin = begin.item()
end = end.item()
torch._check(begin >= 0)
torch._check(end >= 0)
torch._check(end >= begin)
torch._check(begin < size)
torch._check(end < size)
length = end - begin
torch._check(length >= 0)
torch._check(begin + length == end)
torch._check(begin + length < size)
torch._check(end - length == begin)
torch._check(end - length < size)
torch._check_is_size(begin)
torch._check_is_size(length)
return inp._values.narrow(inp._ragged_idx - 1, begin, length)

I get another error:

torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u1} not in returned outputs FakeTensor(..., size=(-u0 + u1, 4), dtype=torch.int64) ((4, 1), 4*u0).
Did you accidentally call new_dynamic_size() or item() more times than you needed to in your fake implementation?
For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit

for this simple repro:

import torch

torch._dynamo.config.capture_scalar_outputs = True

def fn(nt, *args, **kwargs):
    return nt.select(*args, **kwargs)

nt = torch.nested.nested_tensor(
    [torch.full((3, 4), 4), torch.full((5, 4), 5), torch.full((3, 4), 4), torch.full((5, 4), 5)],
    layout=torch.jagged,
)

compiled_f = torch.compile(fn, fullgraph=True, backend="aot_eager_decomp_partition")

# select element 3 along the batch dim
compiled_f(nt, 0, 3)

@jbschlosser (Contributor) commented Mar 10, 2025

Hey @fleonce, sorry that you're running into these unbacked symint errors. This is certainly the most difficult part of dealing with NJT support within torch.compile.

As far as this error goes:

torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u1} not in returned outputs FakeTensor(..., size=(-u0 + u1, 4), dtype=torch.int64) ((4, 1), 4*u0).
Did you accidentally call new_dynamic_size() or item() more times than you needed to in your fake implementation?
For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit

there are some limitations as to how unbacked SymInts are intended to be used. In particular, it's expected that any unbacked SymInt allocated during execution of an op can be directly pulled from the sizes, strides, or storage offsets of one of the tensors returned by the op. In this case, the returned tensor has size = (-u0 + u1, 4), and the system is not smart enough to pull out u1 from the expression -u0 + u1. This can be worked around by calculating length as an unbacked SymInt instead of end:

length = (end - begin).item()
begin = begin.item()

and using those in the call to narrow(). See my PR #142063, which did this sort of thing, for more context.
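
Putting that together, a sketch of the fake-friendly version (assuming begin and end are 0-dim tensors derived from _offsets/_lengths as discussed above):

# allocate `length` itself as the unbacked SymInt so it appears
# verbatim in the output size, rather than as `-u0 + u1`
length = (end - begin).item()
begin = begin.item()
torch._check_is_size(begin)
torch._check_is_size(length)
return inp._values.narrow(inp._ragged_idx - 1, begin, length)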

@soulitzer added the triaged label Mar 10, 2025
@fleonce (Contributor Author) commented Mar 10, 2025

Hey @jbschlosser, thank you again for responding! That makes a lot of sense to me; lots to learn, I see :)

I'll give it a try and keep you posted in case I run into any additional issues!

…ds, start working on the backward compile test failures
@fleonce (Contributor Author) commented Mar 10, 2025

Hello again @jbschlosser!

I managed to properly remove the test errors for the forward compile; the backward compile, however, has proven very tricky. I needed to introduce another guard in the backward to ensure the compiler knows that the size of grad_input_view in the inp._ragged_idx - 1 dimension equals the size of grad_output in the same dimension, as I was getting another guard failure otherwise:

inp = new_kwargs.pop("input")
grad_output = new_kwargs.pop("grad_output")
ragged_dim = inp._ragged_idx - 1
# scatter the incoming gradient into a zero NJT shaped like the input
grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
grad_input_view = grad_input.select(new_kwargs["dim"], new_kwargs["index"])
# guard: the selected view and grad_output must match along the ragged dim
torch._check(grad_input_view.size(ragged_dim) == grad_output.size(ragged_dim))
grad_input_view.copy_(grad_output)
return grad_input

When doing that, I ran into another very interesting error that I'm afraid I need some additional help with:

FAIL: test_compile_backward_select_cpu_float32 (__main__.TestNestedTensorOpInfoCPU.test_compile_backward_select_cpu_float32) (sample='4D_noncontig_transposed_with_seqlen_cache: batch_dim, index=5', idx=62)
----------------------------------------------------------------------
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager_decomp_partition' raised:
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(IsNonOverlappingAndDenseIndicator(5, 3, u8, 81, 27, 1), 1) (unhinted: Eq(IsNonOverlappingAndDenseIndicator(5, 3, u8, 3*s1, s1, 1), 1)).  (Size-like symbols: u8)

where u8 is the size of the selected element along dim inp._ragged_idx - 1:

grad_input=NestedTensor(size=(6, 5, 3, s0), offsets=FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(7,), dtype=torch.int64))), contiguous=True) grad_output=FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(5, 3, u0))))

github-actions (bot)
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label May 10, 2025
@jbschlosser (Contributor) commented

Hey @fleonce, sorry for the delay - I haven't had time recently to debug this in detail. Data-dependent errors like the one you're seeing are quite hard to figure out and fix, unfortunately. There's a guide that may or may not help out. Wish I had more insight for you but unfortunately this work has been deprioritized on our end. If you're able to figure it out, I can help you get the change merged.

@jbschlosser removed the Stale label May 12, 2025
Labels: open source, release notes: nested tensor, topic: performance, triaged
Development

Successfully merging this pull request may close these issues: Improve nested jagged tensor select performance on batch dim

4 participants