[NJT] can only chunk if the 2nd dimension is ragged #153238
Comments
Thanks for the report, @mdeff! Support for a ragged dimension not immediately next to the batch dimension (as in your second example) is newer and buggy wrt input validation. The validation logic here still uses an older definition of "contiguous" that required the ragged dimension to sit immediately after the batch dimension (see pytorch/torch/nested/_internal/ops.py, lines 114 to 116 at commit dc47295).
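Roughly speaking, that older check amounted to the following (a hedged paraphrase based on this thread, not the literal ops.py source):

```python
# Hedged paraphrase, not the actual ops.py code: "contiguous" used to mean
# "the ragged dim sits right after the batch dim AND there are no explicit
# lengths (i.e. no holes)". _ragged_idx/_lengths are NJT internals.
def _is_contiguous_old(nt) -> bool:
    return nt._ragged_idx == 1 and nt._lengths is None
```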
The easiest fix for this particular case is to change the schema string at pytorch/torch/nested/_internal/ops.py, line 969 at commit dc47295, and to error for truly non-contiguous inputs instead. We'd accept a PR implementing this if you'd like to contribute.
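Going by the error message quoted in the report below, that schema string tags `self` as `jt`, which the NJT op registration treats as "contiguous, with the ragged dim right after batch". A loosened version that validates inside the op might look roughly like this (a sketch under those assumptions, not the actual ops.py code):

```python
import torch

# Hedged sketch of the suggested validation: accept any NJT regardless of
# which dim is ragged, and raise a clearer error only for truly
# non-contiguous inputs (e.g. ones with holes).
def _check_chunk_input(nt: torch.Tensor) -> None:
    if nt.layout != torch.jagged:
        raise RuntimeError("chunk(): expected a jagged layout NestedTensor")
    if not nt.is_contiguous():
        raise RuntimeError(
            "chunk(): expected a contiguous jagged layout NestedTensor; "
            "inputs with holes are not supported"
        )
```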
Hi @jbschlosser, I just submitted a PR for this issue; please have a look.
Thanks for the explanation and guidance, @jbschlosser!

In general, there seems to be confusion/ambiguity between contiguity (`is_contiguous()`) and "hollowness" (whether the NJT has holes). In #153237 I identified the contiguity–hollowness mismatch, which was updated in #153529 to mean contiguity. But maybe "hollowness" is more relevant in some of these checks.

Finally, it's assumed throughout the code that an NJT with `lengths` set is non-contiguous (i.e. has holes), which isn't necessarily the case:

```python
import torch

values = torch.randn(6, 5)
offsets = torch.tensor([0, 2, 3, 6])
lengths = torch.tensor([2, 1, 3])
x = torch.nested.nested_tensor_from_jagged(values, offsets, lengths)
assert x.lengths() is not None
assert not x.is_contiguous()  # but it has no holes and is contiguous
assert torch.equal(x.lengths(), x.offsets().diff())
```

I see two fixes:
@mdeff Extremely solid analysis!
100% correct. I'll mention that support for holes was tacked on later, with the intent of supporting "a ragged view of a dense tensor" usable for kv-cache logic.
Also correct. In theory, we *should* allow this.
I like 1 better, but even so, it's not ideal to compute the data-dependent, GPU-sync-incurring check of `lengths` against `offsets.diff()`.
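For concreteness, a sketch of the check being discussed (assuming it boils down to comparing `lengths` with `offsets.diff()`):

```python
import torch

# Deciding "no holes" is data-dependent: both tensors live on the device,
# and torch.equal returns a Python bool, so using the result in control
# flow forces a device-to-host sync.
def has_holes(nt) -> bool:
    lengths = nt.lengths()
    if lengths is None:
        return False  # offsets-only NJT: contiguous by construction
    return not torch.equal(lengths, nt.offsets().diff())
```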
I think there's a good argument to be made for including this in the
@jbschlosser thanks!
Makes sense! Nice to tackle that with NJTs.
Wasn't aware of it; thanks for educating me. With that in mind, the current "technically too strict" approach seems to strike a good tradeoff. :)
🐛 Describe the bug
The following works as expected:
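For instance, something along these lines (a reconstruction with assumed shapes; the ragged dim is at its default position, dim 1):

```python
import torch

values = torch.randn(9, 16)
offsets = torch.tensor([0, 2, 5, 9])
x = torch.nested.nested_tensor_from_jagged(values, offsets)  # shape (3, j1, 16)
chunks = x.chunk(2, dim=-1)  # chunking a non-ragged dim works
```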
While the following fails:
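Same data, but with the ragged dim moved to position 2 (again a reconstruction with assumed shapes):

```python
import torch

values = torch.randn(4, 9, 16)
offsets = torch.tensor([0, 2, 5, 9])
x = torch.nested.nested_tensor_from_jagged(values, offsets, jagged_dim=2)  # shape (3, 4, j1, 16)
x.chunk(2, dim=-1)  # raises, even though dim=-1 is not ragged
```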
With the unhelpful exception

```
NestedTensor chunk_default(self: jt, chunks: any, dim: any?): expected self to be a contiguous jagged layout NestedTensor
```

Not a problem for the strided layout.
In real-world use, that would allow simplifying code that currently works around the restriction into a direct `chunk()` call on the NJT.
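As a hypothetical illustration (not the reporter's actual code; names and shapes are made up), the workaround today is to chunk the flat values and rewrap each piece, where a direct chunk would do:

```python
import torch

values = torch.randn(4, 9, 16)  # (D1, total_len, D2)
offsets = torch.tensor([0, 2, 5, 9])
x = torch.nested.nested_tensor_from_jagged(values, offsets, jagged_dim=2)

# Workaround: chunk the underlying values, then rewrap each piece as an NJT.
a, b = (
    torch.nested.nested_tensor_from_jagged(v, offsets, jagged_dim=2)
    for v in x.values().chunk(2, dim=-1)
)

# Desired: chunk the NJT directly along a non-ragged dim.
# a, b = x.chunk(2, dim=-1)  # currently raises when the ragged dim != 1
```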
Versions
PyTorch 2.7.0+cu126 from PyPI.
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ