Implement fast access to individual elements of jagged nested tensors #148497
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148497
Note: Links to docs will display an error until the docs builds have been completed.
❌ 15 New Failures, 1 Unrelated Failure as of commit 5dbff0f with merge base f30776c.
NEW FAILURES - the following jobs have failed. UNSTABLE - the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @fleonce, thanks for the contribution!

Sorry for the lack of docs around `_lengths` - it's a more recent addition to allow for "non-contiguous with holes" views. Op support for that type of NJT is certainly less complete than for contiguous NJTs.

The idea is that `_offsets` define the start points within the packed dimension and `_lengths` define the lengths. Together, this is enough information to define a ragged structure with holes.
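For intuition, here's a small illustrative sketch of those semantics (the numbers are made up for illustration; only the `offsets`/`lengths` roles mirror the internals described above):

```python
import torch

# Packed values for a jagged NJT: 10 rows total along the packed dim.
values = torch.arange(10 * 4, dtype=torch.float32).reshape(10, 4)

# Contiguous NJT: offsets alone define components [0:3], [3:7], [7:10].
offsets = torch.tensor([0, 3, 7, 10])

# "Non-contiguous with holes": each length may be shorter than the gap to
# the next offset, leaving unused rows ("holes") in between.
lengths = torch.tensor([2, 3, 3])  # component i spans offsets[i] : offsets[i] + lengths[i]

for i in range(3):
    start = offsets[i].item()
    end = start + lengths[i].item()
    print(values[start:end].shape)  # rows actually belonging to component i
```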
I added a couple comments for correctness, and I suggest running the pre-existing tests locally to ensure correctness of the refactor. Thanks again!
torch/nested/_internal/ops.py (Outdated)
```python
index_len = inp._lengths[index]
return inp._values[begin:end, :index_len]
# if tensor has no holes, we can just select from the start and end pos
return inp._values[begin:end]
```
Note that it's possible that the packed dim is not dim=0 of `values`. An NJT's `_ragged_idx` specifies which dim is ragged, and the dim at `_ragged_idx - 1` is the packed dim within `values`.

The indexing here should ideally take this into account... you could probably use something like `inp._values.narrow()` on `_ragged_idx - 1` to get what you need.
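A hedged sketch of that suggestion (the helper and its arguments are hypothetical, standing in for the values in scope in the real op):

```python
import torch

def select_packed(values: torch.Tensor, ragged_idx: int, begin: int, end: int) -> torch.Tensor:
    # The packed dim within values sits at ragged_idx - 1, not necessarily dim 0.
    packed_dim = ragged_idx - 1
    return values.narrow(packed_dim, begin, end - begin)

# e.g. for a values tensor packed along dim 0 (ragged_idx == 1):
values = torch.arange(10 * 4).reshape(10, 4)
print(select_packed(values, 1, 3, 7).shape)  # torch.Size([4, 4])
```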
I totally missed that, I'll look into it, thanks for pointing it out! I didn't know `narrow()` existed, I'll give it a try.
torch/nested/_internal/ops.py (Outdated)
```python
if inp._lengths is not None:
    # if the tensor has a hole, we must include the size of the jagged dim for this element
    index_len = inp._lengths[index]
    return inp._values[begin:end, :index_len]
```
I think this isn't quite correct - you want the end to be defined by the lengths. So `offsets[index]` defines the beginning and `offsets[index] + lengths[index]` defines the end.
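A minimal sketch of the corrected bounds (a hypothetical helper mirroring the internals named above, not the actual op code):

```python
import torch

def element_bounds(offsets: torch.Tensor, lengths, index: int):
    begin = offsets[index]
    if lengths is not None:
        # "holes" layout: the end comes from lengths, not the next offset
        end = begin + lengths[index]
    else:
        # contiguous layout: the next offset marks the end
        end = offsets[index + 1]
    return begin, end

offsets = torch.tensor([0, 3, 7, 10])
lengths = torch.tensor([2, 3, 3])
print(element_bounds(offsets, lengths, 1))  # (tensor(3), tensor(6))
print(element_bounds(offsets, None, 1))     # (tensor(3), tensor(7))
```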
btw there are pre-existing tests that you can use to verify correctness for this refactor:

```
python test/test_nestedtensor.py -k test_forward_select
python test/test_nestedtensor.py -k test_backward_select
python test/test_nestedtensor.py -k test_compile_forward_select
python test/test_nestedtensor.py -k test_compile_backward_select
```
these should all pass, or if there are reasonable limitations, we should define the appropriate xfails / skips here.
Thank you, I will update the code accordingly.
I managed to eliminate the non-[…]
I'm starting to get where the issue lies, but am confused anyway:

```python
begin = begin.item()
end = end.item()
torch._check(begin >= 0)
torch._check(end >= 0)
torch._check(end >= begin)
torch._check(begin < size)
torch._check(end < size)
length = end - begin
torch._check(length >= 0)
torch._check(begin + length == end)
torch._check(begin + length < size)
torch._check(end - length == begin)
torch._check(end - length < size)
torch._check_is_size(begin)
torch._check_is_size(length)
return inp._values.narrow(inp._ragged_idx - 1, begin, length)
```

I get another error:
for this simple repro:

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

def fn(nt, *args, **kwargs):
    return nt.select(*args, **kwargs)

nt = torch.nested.nested_tensor(
    [torch.full((3, 4), 4), torch.full((5, 4), 5), torch.full((3, 4), 4), torch.full((5, 4), 5)],
    layout=torch.jagged,
)
compiled_f = torch.compile(fn, fullgraph=True, backend="aot_eager_decomp_partition")
compiled_f(nt, 0, 3)
```
Hey @fleonce, sorry that you're running into these unbacked symint errors. This is certainly the most difficult part of dealing with NJT support within torch.compile. As far as this error goes:
there are some limitations as to how unbacked SymInts are intended to be used. In particular, it's expected that any unbacked SymInt allocated during execution of an op can be directly pulled from the sizes, strides, or storage offsets of one of the tensors returned by the op. In this case, the returned tensor has […] and using those in the call to […]
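To make the stated constraint concrete, here is a hedged toy sketch (not the actual NJT kernel; `slice_by_tensor` and its inputs are hypothetical): `narrow()` returns a tensor whose size along dim 0 is exactly the unbacked `length`, so the SymInt is recoverable from the output's sizes.

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True, backend="aot_eager")
def slice_by_tensor(x, bounds):
    begin = bounds[0].item()   # allocates an unbacked SymInt
    length = bounds[1].item()  # allocates another unbacked SymInt
    torch._check_is_size(begin)
    torch._check_is_size(length)
    torch._check(begin + length <= x.size(0))
    # The output's size(0) *is* `length`, satisfying the requirement that
    # unbacked SymInts be pulled from the returned tensor's metadata.
    return x.narrow(0, begin, length)

x = torch.arange(10.0)
print(slice_by_tensor(x, torch.tensor([2, 5])))  # tensor([2., 3., 4., 5., 6.])
```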
Hey @jbschlosser, I'll give it a try and keep you posted in case I run into any additional issues!
…ds, start working on the backward compile test failures
Hello again @jbschlosser! I managed to properly remove the test errors for the forward compile; the backward compile, however, has proven very tricky. I needed to introduce another guard in the backward to ensure the compiler knows the gradient size (see torch/nested/_internal/ops.py, lines 2508 to 2517 in dfa68cb).
When doing that, I managed to find another very interesting error that, I fear, I need some additional help with: […]
where […]
…the size of the ragged dim)
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Hey @fleonce, sorry for the delay - I haven't had time recently to debug this in detail. Data-dependent errors like the one you're seeing are quite hard to figure out and fix, unfortunately. There's a guide that may or may not help out. Wish I had more insight for you but unfortunately this work has been deprioritized on our end. If you're able to figure it out, I can help you get the change merged.
I removed the dependency on `tensor.unbind()` discussed in #148379 and replaced it with basic indexing ops on the values tensor based on the inputs.

Feedback would be greatly appreciated. I am not sure I got the part with the lengths right - I wasn't able to find a lot of documentation on jagged tensors, so I hope I understood `NestedTensor._lengths` correctly.

Fixes #148379
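For reference, a schematic before/after of the approach described here (a hedged sketch of the contiguous case only; `nt`, `i`, and the `_offsets`/`_values` access mirror the internals discussed in the review above):

```python
# before: unbind materializes a view for every component just to select one
component = nt.unbind()[i]

# after (schematic): index the packed values directly via offsets
begin, end = nt._offsets[i], nt._offsets[i + 1]
component = nt._values[begin:end]
```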