8000 Support narrow() on batch dim for NJT by jbschlosser · Pull Request #142063 · pytorch/pytorch · GitHub

Closed · wants to merge 20 commits

Conversation

jbschlosser (Contributor) commented Dec 4, 2024

pytorch-bot commented Dec 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142063

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit dedcddf with merge base 46390e9:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the release notes: fx label (Dec 4, 2024)
jbschlosser added a commit that referenced this pull request Dec 4, 2024
ghstack-source-id: 11452bd
Pull Request resolved: #142063
jbschlosser added the module: nestedtensor, topic: improvements, and release notes: nested tensor labels (Dec 4, 2024)
albanD removed their request for review (Dec 4, 2024 19:48)
cc cpuhrsch bhosmer drisspg soulitzer davidberard98 YuqingJ ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv

Requested in #136270

jbschlosser added commits that referenced this pull request (Dec 5 to Dec 12, 2024)
@jbschlosser jbschlosser requested a review from cpuhrsch December 23, 2024 21:13
jbschlosser (Contributor, Author) commented:
@soulitzer / @cpuhrsch this is ready now - the PR works in eager / compile without graph breaks for contiguous NJTs

soulitzer (Contributor) left a comment:

Looks great! Some small comments.

end_val += inp._values.size(dim)
start_val = max(min(start_val, inp._values.size(dim)), 0)
end_val = max(min(end_val, inp._values.size(dim)), 0)
length_val = max(min(length_val, end_val - start_val), 0)
soulitzer (Contributor):

Unfortunate that we have to duplicate narrow input checking/manipulation logic here, but I guess maybe hard to avoid due to as_strided use.

jbschlosser (Contributor, Author):

Yeah agreed, I don't like this at all :( but indeed we aren't redispatching to narrow(), so we can't utilize the checks there

if operating_on_batch:
# batch dim narrowing requires custom logic involving offsets
out_kwargs = extract_kwargs(inp)
start_val, length_val = new_kwargs["start"], new_kwargs["length"]
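For intuition, batch-dim narrowing of a contiguous NJT can be expressed by slicing the offsets array and viewing the corresponding slice of the packed values buffer (which is why the PR reaches for as_strided). A pure-Python sketch, assuming the usual (values, offsets) NJT representation; the helper name is hypothetical, not the PR's actual code:

```python
def narrow_batch_dim(values, offsets, start, length):
    """Narrow an NJT's batch dim by slicing offsets.

    `values` is the packed ragged buffer; `offsets[i]` marks where batch
    item i begins. Assumes a contiguous layout and already-validated
    (start, length). Sketch only, not PyTorch internals.
    """
    new_offsets = offsets[start : start + length + 1]
    lo, hi = new_offsets[0], new_offsets[-1]
    # The narrowed NJT views values[lo:hi]; rebase offsets to start at 0.
    return values[lo:hi], [o - lo for o in new_offsets]

# Three batch items of lengths 3, 2, 4 packed into one buffer:
values = list(range(9))
offsets = [0, 3, 5, 9]
v, o = narrow_batch_dim(values, offsets, 1, 2)  # keep batch items 1..2
# v == [3, 4, 5, 6, 7, 8]; o == [0, 2, 6]
```

The nonzero `lo` is why the real implementation needs a view at a nonzero storage offset into the values buffer.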
soulitzer (Contributor):

Trying to think of weird edge cases (totally fine if normal narrow doesn't cover this either, but I guess we want parity?)

  • What happens if length_val is negative? I didn't see any explicit check, and the docs say that it must be "weakly positive", whatever that means 🤔.

  • What happens if start_val is negative, but length_val > size of the tensor?

jbschlosser (Contributor, Author):

good call, will add test cases for these

it must "weakly positive" whatever that means

my understanding is that this allows for length=0? idk for sure though
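For reference, dense narrow() answers both edge cases above: a negative start wraps around once, a negative length is rejected, and the window must fit after wrapping ("weakly positive" presumably meaning non-negative, so length == 0 is allowed). A sketch of that validation; the function name is hypothetical, not PyTorch API:

```python
def check_narrow_args(size, start, length):
    """Mirror dense narrow(dim, start, length) validation on a dim of the
    given size. Sketch only; error messages are illustrative."""
    if length < 0:
        raise ValueError("narrow(): length must be non-negative")
    if start < 0:
        start += size  # a negative start counts from the end
    if start < 0 or start + length > size:
        raise ValueError("narrow(): start + length exceeds dimension size")
    return start, length

print(check_narrow_args(4, -1, 1))  # (3, 1): the last element
```

With this ordering, a negative start plus an oversized length (e.g. size=4, start=-1, length=6) wraps start to 3 and then fails the start + length bound, matching the second edge case.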

@@ -8495,6 +8567,18 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
name="unimplemented_view_as_real",
),
# narrow(): unbacked SymInt bug with non-contig transposed inputs
soulitzer (Contributor):

Oof. Just curious - Is there an issue somewhere? If there is, maybe worth linking.

start_val += inp._values.size(dim)
if end_val < 0:
end_val += inp._values.size(dim)
start_val = max(min(start_val, inp._values.size(dim)), 0)
soulitzer (Contributor):

Hmm feels like it shouldn't be possible to be negative here (prior to max with 0)

self.assertEqual(out3_comp, nt_comp)

# length past the end
with self.assertRaisesRegex(RuntimeError, "exceeds dimension size"):
soulitzer (Contributor):

In theory this is the type of thing we could also handle via XFail with sample_match_fn, right?

Is it fair to say that we don't want to clutter that too much with basic input-validity checks, and reserve xfails for actual bugs and features that are not implemented?

jbschlosser (Contributor, Author):

Yeah, that's right. I have a mental TODO to use the error_inputs_func feature of OpInfos to handle this type of expected-error checking. Ideally, as you mentioned, I'd want to avoid cluttering the xfails with things that we'll never address and that don't actually represent bugs.
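The error-inputs idea amounts to pairing each invalid input with the error message it should produce. A framework-agnostic sketch of the pattern, using a stand-in validator so it is self-contained; none of these names are actual OpInfo machinery:

```python
import re

def narrow_error_inputs():
    """Hypothetical error_inputs-style generator: each case is
    ((size, start, length), expected_error_regex) for inputs that
    narrow() should reject."""
    yield (4, 0, 6), r"exceeds dimension size"   # length past the end
    yield (4, 0, -1), r"non-negative"            # negative length

def check_error_inputs(validate, cases):
    """Assert each bad input fails with the expected message."""
    for args, pattern in cases:
        try:
            validate(*args)
        except Exception as exc:
            assert re.search(pattern, str(exc)), f"wrong error for {args}"
        else:
            raise AssertionError(f"expected {args} to fail")

def validate(size, start, length):
    # Minimal stand-in validator so the sketch runs on its own.
    if length < 0:
        raise ValueError("narrow(): length must be non-negative")
    if start + length > size:
        raise RuntimeError("narrow(): start + length exceeds dimension size")

check_error_inputs(validate, narrow_error_inputs())
```

Kept alongside the op's sample inputs, this lets expected-error cases live with the OpInfo instead of accumulating as xfails.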

@torch._dynamo.utils.disable_cache_limit()
@dtypes(torch.float32)
@parametrize("env", ["eager", "compile", "compile_dynamic"])
def test_narrow_on_batch_dim(self, device, dtype, env):
soulitzer (Contributor) commented Jan 9, 2025:

I wonder how much overlap there is between this test and the narrow OpInfo one.
Like, could the "first few", "middle", and "last" batch items be formulated as sample inputs?

And then we could reformulate this one into two smaller tests that do more specific things, e.g. test_narrow_on_batch_dim_input_validation and test_narrow_on_batch_dim_narrow_of_narrow.

jbschlosser (Contributor, Author) commented Jan 14, 2025:

Like could the "first few", "middle", and "last" batch items be formulated as sample inputs?

Definitely and I should do that :) will fix

Edit: I realized I'm kind of already doing this for the generated sample inputs on non-ragged dims

And then we could reformalate this one to two smaller tests that do more specific things, e.g. test_narrow_on_batch_dim_input_validation and test_narrow_on_batch_dim_narrow_of_narrow.

yeah these are harder to test with OpInfo so I think it makes sense to break them out

jbschlosser (Contributor, Author):

Oh, I just realized the compile tests for narrow-on-narrow weren't actually compiling :p

Changing them to actually compile makes them fail; investigating.

jbschlosser (Contributor, Author):

Okay, I tracked this down to incorrect clamping. The old logic was clamping (start, end) in the inner values-dim space, but we want to clamp in the outer batch-dim space. Fixing this resolved the data-dependent guard errors.

There's still some work to be done for non-contiguous NJTs apparently, but I've been deprioritizing those in general so I'll just land this as-is
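The distinction between the two index spaces can be made concrete: for an NJT with offsets of length B + 1, the batch dim has size B = len(offsets) - 1, while the values buffer has size offsets[-1]. A sketch of clamping in batch space, under the usual (values, offsets) representation; the helper name is hypothetical:

```python
def clamp_in_batch_space(offsets, start, length):
    """Clamp narrow()'s (start, length) against the batch size
    len(offsets) - 1, not the values-buffer size offsets[-1].
    Sketch only, not the PR's actual code."""
    batch_size = len(offsets) - 1
    if start < 0:
        start += batch_size  # negative start counts from the end
    start = max(min(start, batch_size), 0)
    end = max(min(start + length, batch_size), 0)
    return start, max(end - start, 0)

offsets = [0, 3, 5, 9]  # 3 batch items packed into a 9-element values buffer
print(clamp_in_batch_space(offsets, 1, 5))  # (1, 2): clamped to the batch size
# Clamping against the values size (9) would wrongly accept length 5 here,
# which is the values-space vs. batch-space confusion described above.
```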

github-actions (bot) commented:
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

github-actions bot added the Stale label Apr 11, 2025
Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
ghstack-source-id: 0f65635
Pull Request resolved: pytorch/pytorch#142063
github-actions bot closed this May 12, 2025
Labels
fx · module: nestedtensor · release notes: fx · release notes: nested tensor · Stale · topic: improvements
Projects: None yet

3 participants