Support narrow() on batch dim for NJT #142063
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142063
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures as of commit dedcddf with merge base 46390e9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@soulitzer / @cpuhrsch this is ready now - the PR works in eager / compile without graph breaks for contiguous NJTs
Looks great! Some small comments
end_val += inp._values.size(dim)
start_val = max(min(start_val, inp._values.size(dim)), 0)
end_val = max(min(end_val, inp._values.size(dim)), 0)
length_val = max(min(length_val, end_val - start_val), 0)
Unfortunate that we have to duplicate narrow()'s input checking/manipulation logic here, but I guess it's maybe hard to avoid given the as_strided use.
Yeah, agreed - I don't like this at all :( but since we aren't redispatching to narrow(), we can't reuse the checks there.
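For reference, here's a minimal standalone sketch of the narrow()-style argument normalization that ends up duplicated (the function name and free-floating form are illustrative, not the actual ops.py code):

```python
def normalize_narrow_args(start: int, length: int, dim_size: int):
    # Mirror narrow() semantics: a negative start wraps around the dim,
    # like Python indexing.
    if start < 0:
        start += dim_size
    end = start + length
    # Clamp both endpoints into [0, dim_size] so the downstream
    # as_strided call can never read out of bounds, then recompute a
    # non-negative length.
    start = max(min(start, dim_size), 0)
    end = max(min(end, dim_size), 0)
    return start, max(end - start, 0)
```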
if operating_on_batch:
    # batch dim narrowing requires custom logic involving offsets
    out_kwargs = extract_kwargs(inp)
    start_val, length_val = new_kwargs["start"], new_kwargs["length"]
Trying to think of weird edge cases (totally fine if normal narrow doesn't cover this either, but I guess we want parity?)
- What happens if length_val is negative? I didn't see any explicit check, and the docs say that it must be "weakly positive", whatever that means 🤔.
- What happens if start_val is negative, but length_val > the size of the tensor?
good call, will add test cases for these
it must "weakly positive" whatever that means
my understanding is that this allows for length=0? idk for sure though
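To make the parity target concrete, here's how dense-tensor narrow() handles these cases today (a quick sketch; worth re-verifying against current eager behavior):

```python
import torch

x = torch.randn(5, 3)

# length=0 is allowed ("weakly positive" presumably meaning non-negative)
assert x.narrow(0, 2, 0).shape == (0, 3)

# a negative start wraps around: start=-2 is equivalent to start=3 here
assert torch.equal(x.narrow(0, -2, 2), x[3:5])

# a negative length, or start + length running past the end, should raise
for start, length in [(0, -1), (-2, 4)]:
    try:
        x.narrow(0, start, length)
    except RuntimeError as e:
        print(f"narrow(0, {start}, {length}) raised: {e}")
```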
test/test_nestedtensor.py
Outdated
@@ -8495,6 +8567,18 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
name="unimplemented_view_as_real",
),
# narrow(): unbacked SymInt bug with non-contig transposed inputs |
Oof. Just curious - is there an issue filed somewhere? If there is, maybe worth linking.
torch/nested/_internal/ops.py
Outdated
start_val += inp._values.size(dim)
if end_val < 0:
    end_val += inp._values.size(dim)
start_val = max(min(start_val, inp._values.size(dim)), 0)
Hmm, it feels like it shouldn't be possible for this to be negative here (prior to the max with 0).
test/test_nestedtensor.py
Outdated
self.assertEqual(out3_comp, nt_comp)

# length past the end
with self.assertRaisesRegex(RuntimeError, "exceeds dimension size"):
In theory this is the type of thing we could also handle via XFail with sample_match_fn, right?
Is it fair to say that we don't want to clutter that too much with basic input-validity checks, and instead reserve it for actual bugs and features that are not implemented?
Yeah, that's right - I have a mental TODO to use the error_inputs_func feature of OpInfos to handle this type of expected error-checking failure. Ideally, as you mentioned, I'd want to avoid cluttering the xfails with things that we'll never address and that don't actually represent bugs.
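A rough sketch of what that could look like with error_inputs_func (the sample construction below is hypothetical, not actual test-suite code):

```python
import torch
from torch.testing._internal.opinfo.core import ErrorInput, SampleInput

# Hypothetical error_inputs_func for an NJT narrow OpInfo entry.
def error_inputs_narrow(op_info, device, **kwargs):
    nt = torch.nested.nested_tensor(
        [torch.randn(2, 5), torch.randn(3, 5)],
        layout=torch.jagged,
        device=device,
    )
    # length past the end of the batch dim should raise
    yield ErrorInput(
        SampleInput(nt, args=(0, 1, 10)),
        error_type=RuntimeError,
        error_regex="exceeds dimension size",
    )
```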
test/test_nestedtensor.py
Outdated
@torch._dynamo.utils.disable_cache_limit()
@dtypes(torch.float32)
@parametrize("env", ["eager", "compile", "compile_dynamic"])
def test_narrow_on_batch_dim(self, device, dtype, env):
I wonder how much overlap there is between this test and the narrow OpInfo one.
Like, could the "first few", "middle", and "last" batch items be formulated as sample inputs?
And then we could reformulate this one into two smaller tests that do more specific things, e.g. test_narrow_on_batch_dim_input_validation and test_narrow_on_batch_dim_narrow_of_narrow.
> Like, could the "first few", "middle", and "last" batch items be formulated as sample inputs?
Definitely and I should do that :) will fix
Edit: I realized I'm kind of already doing this for the generated sample inputs on non-ragged dims
> And then we could reformulate this one into two smaller tests that do more specific things, e.g. test_narrow_on_batch_dim_input_validation and test_narrow_on_batch_dim_narrow_of_narrow.
Yeah, these are harder to test with OpInfo, so I think it makes sense to break them out.
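For the narrow-of-narrow piece, something shaped like this (a hypothetical sketch of a method inside the existing test class, not the final code):

```python
@dtypes(torch.float32)
def test_narrow_on_batch_dim_narrow_of_narrow(self, device, dtype):
    nt = torch.nested.nested_tensor(
        [torch.randn(n, 8) for n in (2, 3, 4, 5)],
        layout=torch.jagged, device=device, dtype=dtype,
    )
    first = nt.narrow(0, 1, 3)      # batch items 1..3
    second = first.narrow(0, 1, 2)  # items 2..3 of the original
    # each narrowed component should match the original batch item
    for out_c, ref_c in zip(second.unbind(), nt.unbind()[2:4]):
        self.assertEqual(out_c, ref_c)
```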
Oh, I just realized the compile tests for narrow-on-narrow weren't actually compiling :p
Changing them to actually compile makes them fail; investigating.
Okay, I tracked this down to incorrect clamping: the old logic was clamping (start, end) in the inner values-dim space, but we want to clamp in the outer batch-dim space. Fixing this resolved the data-dependent guard errors.
There's still some work to be done for non-contiguous NJTs apparently, but I've been deprioritizing those in general, so I'll just land this as-is.
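In other words, something like this (all names illustrative; a sketch of the idea rather than the actual ops.py change): clamp against the number of batch items, then derive the narrowed values/offsets from the clamped batch range.

```python
import torch

def narrow_batch_dim(values, offsets, start, length):
    batch_size = offsets.numel() - 1  # outer batch dim size, NOT values.size(0)
    if start < 0:
        start += batch_size
    end = start + length
    # clamp (start, end) in batch-dim space
    start = max(min(start, batch_size), 0)
    end = max(min(end, batch_size), 0)
    # the narrowed NJT covers the values rows for batch items [start, end)
    new_values = values[offsets[start] : offsets[end]]
    new_offsets = offsets[start : end + 1] - offsets[start]
    return new_values, new_offsets

# e.g. components of lengths [2, 3, 4] -> keep the last two
values, offsets = torch.randn(9, 5), torch.tensor([0, 2, 5, 9])
v, o = narrow_batch_dim(values, offsets, 1, 2)
assert v.shape == (7, 5) and o.tolist() == [0, 3, 7]
```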
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
ghstack-source-id: 0f65635
Pull Request resolved: pytorch/pytorch#142063
Stack from ghstack (oldest at bottom):
Requested in #136270
cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv