[Nested Tensor] Support NT construction inside PT2 graph · Issue #118446 · pytorch/pytorch · GitHub

[Nested Tensor] Support NT construction inside PT2 graph #118446

Closed
davidberard98 opened this issue Jan 27, 2024 · 3 comments
Assignees: davidberard98
Labels: feature (A request for a proper, new feature.), module: nestedtensor (NestedTensor tag, see issue #25032), oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

davidberard98 (Contributor) commented Jan 27, 2024

🚀 The feature, motivation and pitch

Repro:

import torch
# Internal helpers for building a jagged NestedTensor from a dense buffer of values.
from torch.nested._internal.nested_tensor import ViewNestedFromBuffer, buffer_from_jagged

def fn(values, offsets):
    # return values.cos().sin()
    # Construct the NT *inside* the function, so torch.compile has to trace the construction.
    return ViewNestedFromBuffer.apply(values.cos(), offsets).sin()

fn_c = torch.compile(fn, backend="aot_eager")
# values = torch.rand((12, 8), requires_grad=True)
values = torch.rand((12, 8))
offsets = torch.tensor([0, 1, 2, 5, 8, 9, 12])
lengths = torch.tensor([1, 1, 3, 3, 1, 3])

# Eager construction works; compiling fn (NT constructed in-graph) is what hits the issues below.
nt = ViewNestedFromBuffer.apply(values, offsets)
fn_c(values, offsets)

  • ViewNestedFromBuffer hits skipfiles (see #118334: using autograd.Functions defined in torch/ causes graph breaks).
  • Dynamic shapes don't really work with this if any part of the values tensor has dynamic shapes in the non-batch dimensions:
    • If values is dynamic, then when we construct self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride), where stride is values.stride(): if stride[self._ragged_idx - 1] is symbolic, then ragged_size * stride[...] multiplies a singleton symint (nested int) by an ordinary symbolic symint, which is not supported (see the sketch after this list).
  • Symbolic strides don't work correctly:
    • __tensor_unflatten__ needs values to be symbolic in order to identify that we can update the _tensor_symint_registry. When an NT is an input, we know about its dynamism properties because of the mark_dynamic call in the NT constructor. However, when the NT is constructed inside the graph, we don't know about those dynamism properties until after we trace through it (and we probably need to check trace order etc. to make sure the dynamism is marked at a point early enough...).
      • Why do we actually need these to be symbolic to do the jaggedness stuff?
    • Suggestion from Jeffrey and Joel: just use values instead of offsets, and construct the symbolic int stuff before tracing starts.
      • I think we'll still hit the same issue; we need values to be partially dynamic.
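
To make the stride construction above concrete, here is a minimal sketch in plain Python (build_njt_strides and its arguments are hypothetical names, not the actual NestedTensor internals). It mirrors the formula from the bullet above with concrete ints; in the real code, ragged_size is the nested (singleton) symint, and the failure occurs when the stride entry it is multiplied with is itself a symbolic SymInt:

def build_njt_strides(ragged_size, values_stride, ragged_idx=1):
    # Mirrors self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride).
    # With plain ints this is fine; with a nested/singleton symint times a
    # symbolic SymInt, this exact multiplication is what is unsupported.
    return (ragged_size * values_stride[ragged_idx - 1], *values_stride)

# values of shape (12, 8) is contiguous, so values.stride() == (8, 1).
print(build_njt_strides(ragged_size=3, values_stride=(8, 1)))  # -> (24, 8, 1)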

Alternatives

No response

Additional context

No response

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang

@davidberard98 self-assigned this Jan 27, 2024
@davidberard98 changed the title from "[Nested Tensor] Support NT construction inside graph" to "[Nested Tensor] Support NT construction inside PT2 graph" Jan 27, 2024
@colesbury added the module: nestedtensor and oncall: pt2 labels Jan 29, 2024
jbschlosser (Contributor) commented:

Why do we actually need these to be symbolic to do the jaggedness stuff?

This is a good question. Symbolic SymInts support more operations than non-symbolic SingletonSymInts do (e.g. multiplication). It might be theoretically possible to add any support needed by PT2 downstream to non-symbolic SingletonSymInts; it's just a decent amount of work.

cc @soulitzer in case he is aware of any theoretical issues preventing us from adding this support
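
As a rough illustration of the first half of that point (ordinary symbolic SymInts support multiplication), here is a minimal sketch using torch._dynamo.mark_dynamic and the aot_eager backend; it is an assumed example for illustration, not code from this thread:

import torch
import torch._dynamo

def fn(x):
    # Once dim 0 is marked dynamic, x.size(0) is traced as an ordinary symbolic
    # SymInt (e.g. s0), and multiplying it by another size works fine.
    return x.new_zeros(x.size(0) * x.size(1))

x = torch.rand(12, 8)
torch._dynamo.mark_dynamic(x, 0)  # dim 0 becomes a symbolic SymInt under compile
out = torch.compile(fn, backend="aot_eager")(x)
print(out.shape)  # torch.Size([96])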

davidberard98 (Contributor, Author) commented Jan 29, 2024

@jbschlosser what I meant was whether we can do this: #118577 (still waiting on CI to see if anything fails...)

(edit: the PR has changed since I originally wrote this comment... originally it just removed the check for symbolic sizes, but that fails; now the PR is testing other stuff)

@mlazos added the triaged label Jan 30, 2024
@yf225 added the feature label Mar 29, 2024
soulitzer added a commit that referenced this issue Apr 22, 2024: "…ph using offsets from inputs"

Creating symbolic nested ints within the graph is difficult. Using unbacked symints should solve the most important(?) cases in the meantime.

See #118446

Known gaps:
- creating NJT from intermediate offsets (offsets created within the graph, as opposed to offsets passed in as inputs)
- when the same offsets is also passed in as an input to the graph: we are not smart enough to realize that the offsets from that input are the same, and therefore fail when the sizes are compared ("s0 cannot be compared with u0")

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang
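
For context on the pattern that commit targets (constructing the NJT in-graph from offsets passed in as inputs), here is a minimal sketch using the public torch.nested.nested_tensor_from_jagged API available in recent PyTorch releases, instead of the internal ViewNestedFromBuffer from the repro; this is illustrative only, not the test from the PR:

import torch

def fn(values, offsets):
    # The jagged NT is constructed inside the compiled region; offsets is a
    # graph input. Per the known gaps above, offsets created *within* the
    # graph are a separate, harder case.
    nt = torch.nested.nested_tensor_from_jagged(values.cos(), offsets)
    return nt.sin()

values = torch.rand(12, 8)
offsets = torch.tensor([0, 1, 2, 5, 8, 9, 12])
out = torch.compile(fn, backend="aot_eager")(values, offsets)
print(out.shape)  # e.g. torch.Size([6, j1, 8]) with a jagged middle dimension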

IvanKobzarev (Contributor) commented:

Doing issue scrubbing. The repro does not fail; closing.
