Padded tensor subclass #105325


Open
ezyang opened this issue Jul 17, 2023 · 11 comments
Assignees
Labels
feature (A request for a proper, new feature.), module: nestedtensor (NestedTensor tag, see issue #25032), tensor subclass (Related to tensor subclasses), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ezyang (Contributor) commented Jul 17, 2023

🐛 Describe the bug

Suppose you have a network that operates on dynamically sized inputs / has data-dependent dynamism internally. Our default policy is to represent such a tensor as compactly as possible (e.g., with no padding) to minimize the storage and FLOPs needed to operate on it.

However, in some situations it could be profitable / cost-free to pad out the tensor:

  • If you are CUDA graphing with dynamic shapes and you know your maximum size, padding in the outermost (e.g., batch) dimension is effectively free, because the CUDA graph already requires you to reserve memory equivalent to the maximum memory usage across your dynamic shapes. In fact, it is better than free: allocating the same amount of memory every iteration ensures you reuse the same allocations, whereas the allocator can otherwise make bad decisions in the name of "saving" memory. For example, if you previously allocated a tensor out of a 10MB block, but this time you only need 5MB because you halved your sequence length, then instead of serving the allocation out of the 10MB block, the allocator might allocate an extra 5MB to "save" the 10MB for later, even though it will never be used!
  • Increasing the size of tensors can improve kernel performance. @Chillee has a good explainer about this phenomenon in matmuls at https://twitter.com/cHHillee/status/1630274804795445248. In fact, @msaroufim and @Chillee tried to add this optimization directly to PyTorch, but the post facto layout change was a bit hard to implement; doing the layout change "early" with a tensor subclass should be easier (albeit less automatic). These improvements generalize beyond matmuls, although mostly in the form of making sure your sizes are divisible by something nice. Fully automatic size increases are a little difficult here, because you have to know that later you're going to do a matmul, and you also have to know that you aren't losing all your gains to non-contiguous kernels. However, if you have a net where one of the input dimensions is dynamic, you can choose to bucket to reduce the number of CUDA graphs you need (see the sketch after this list). That said, if the dynamic dimension is batch size (or even sequence length, as long as you have embeddings on the inner dimensions so there's no padding problem), you aren't going to get kernel perf wins.
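
A rough sketch of the kind of padding/bucketing the two bullets above describe (the multiple of 64, the sizes, and the helper names are arbitrary illustrative choices, not part of the proposal):

```python
import torch
import torch.nn.functional as F

def round_up(n: int, multiple: int) -> int:
    # e.g. round_up(501, 64) == 512; keeps sizes divisible by something "nice"
    return ((n + multiple - 1) // multiple) * multiple

def pad_dim(x: torch.Tensor, dim: int, target: int) -> torch.Tensor:
    # Right-pad `dim` of x with zeros up to `target` elements.
    # F.pad takes (left, right) pairs ordered from the last dim backwards.
    pad = [0, 0] * x.dim()
    pad[2 * (x.dim() - 1 - dim) + 1] = target - x.size(dim)
    return F.pad(x, pad)

x = torch.randn(37, 501, 768)                     # dynamic batch and sequence length
x = pad_dim(x, dim=0, target=64)                  # pad batch up to the CUDA-graph maximum
x = pad_dim(x, dim=1, target=round_up(501, 64))   # bucket the sequence length to a multiple of 64
print(x.shape)                                    # torch.Size([64, 512, 768])
```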

A padded batch tensor for dynamic batch size is probably the easiest thing to implement to start with, because you can use symbolic shape propagation rules to propagate the batch dim and ensure it stays properly padded (I can't think of a good way to make vmap do this). The annoyance is avoiding wasted FLOPs by "adjusting down" to the logical size.
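
To make the shape of the idea concrete, here is a minimal sketch of what such a subclass could look like, assuming a __torch_dispatch__ wrapper that carries the physically padded payload plus the logical batch size. The naive "batch dim is preserved" rule below stands in for the real symbolic-shape-style propagation rules; none of this is the proposed implementation.

```python
import torch
from torch.utils._pytree import tree_map

class PaddedTensor(torch.Tensor):
    """Illustrative only: physically padded along dim 0, but remembers the
    logical batch size so callers can "adjust down" when it matters."""

    # Skip __torch_function__ so everything funnels through __torch_dispatch__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, data: torch.Tensor, logical_batch: int):
        # Advertise the *physical* (padded) shape to the outside world.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor, logical_batch: int):
        self._data = data
        self._logical_batch = logical_batch

    def __repr__(self):
        return (f"PaddedTensor(physical={tuple(self._data.shape)}, "
                f"logical_batch={self._logical_batch})")

    def unpad(self) -> torch.Tensor:
        # Slice back down to the logical batch size.
        return self._data[: self._logical_batch]

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Naive propagation rule: run the op on the padded payload and assume
        # the batch dim is preserved. A real design would use per-op
        # (symbolic-shape-style) rules to decide how the logical size moves.
        logical = next(a._logical_batch for a in args if isinstance(a, PaddedTensor))
        unwrap = lambda t: t._data if isinstance(t, PaddedTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        rewrap = lambda t: cls(t, logical) if isinstance(t, torch.Tensor) else t
        return tree_map(rewrap, out)

x = PaddedTensor(torch.zeros(64, 128), logical_batch=37)  # padded up to max batch 64
y = torch.relu(x + 1)
print(y, y.unpad().shape)  # physical (64, 128), logical torch.Size([37, 128])
```

Most of the real work would be in replacing that naive rule with proper propagation rules, plus "adjust down" points before ops where the wasted FLOPs actually matter.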

Related: #65156
Related: nested tensor

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @msaroufim @albanD

Versions

main

zou3519 (Contributor) commented Jul 17, 2023

Is what you want a MaskedTensor? https://pytorch.org/docs/stable/masked.html

eellison (Contributor) commented:

A MaskedTensor is more general than a PaddedTensor, and as a result it has twice as much memory use, because you need to store the entire mask. And I imagine it is much slower, because you need to repeatedly read from the mask, as opposed to a single value per padded dimension.
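
For intuition (a toy sketch with made-up sizes, not a benchmark): right-padded data plus one length per row is enough for correct reductions, whereas a mask-style representation carries a full boolean mask of the padded shape.

```python
import torch

B, T, C = 32, 512, 768
lengths = torch.randint(1, T + 1, (B,))              # one integer per row

# Mask-style representation: one bool per element of the padded (B, T) grid.
mask = torch.arange(T)[None, :] < lengths[:, None]   # B * T mask entries

# Padded-style representation: data is right-padded with zeros, so a correct
# mean over the time dim only needs the B per-row lengths.
data = torch.randn(B, T, C) * mask[..., None]        # zero the padding once, up front
mean_over_time = data.sum(dim=1) / lengths[:, None]  # no per-element mask reads here
```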

vadimkantorov (Contributor) commented Jul 17, 2023

Ah, I see, this is actually about a PaddedTensor (probably most varlen tensors were right-padded anyway, but the seqlens have been passed around manually in a separate variable; btw, it would be cool if the seqlen shape could be multi-dim and could select which dimensions are padded, e.g. to support BTC or TBC)


I think this is actually NestedTensor, no, @cpuhrsch? (e.g. if one could pass a certain storage size to force good allocator behavior) :)

Some of our previous discussions on NestedTensor / TensorLists in other contexts:

ezyang (Contributor, Author) commented Jul 18, 2023

The difference with nested tensor is that (1) unlike the default nested tensor layout, you don't need to record strides for each inner tensor, and (2) there is no non-uniform shapiness at any individual time; instead, you expect variation across several runs.

mikaylagawarecki added the module: nestedtensor, tensor subclass, triaged and feature labels on Jul 18, 2023
vadimkantorov (Contributor) commented:

I think this is also helpful for modeling expressiveness. Currently, when using padding, we just pass around the shapes-before-padding. This is not very bad, but it would be better if more methods supported this "mask format" (e.g. masked layer norm, masked conv1d, masked attention).

One way to proceed might be as in xFormers, which created types/classes for some (structured or otherwise) mask formats first and introduced a mask= argument for regular PyTorch methods; then one could or could not introduce a special tensor subclass for this, depending on preference (I myself don't much like subclasses being the only possible way of accessing the masked functionality).
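
As a rough illustration of that mask= style (the helper name and signature here are hypothetical, not an existing or proposed PyTorch API), a padding-aware attention softmax over right-padded keys might look like:

```python
import torch
import torch.nn.functional as F

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # scores: (B, T_q, T_k) attention logits; mask: (B, T_k), True = real key.
    # Padded keys get a very negative logit so they receive ~zero weight.
    neg = torch.finfo(scores.dtype).min
    return F.softmax(scores.masked_fill(~mask[:, None, :], neg), dim=-1)

q, k = torch.randn(2, 5, 16), torch.randn(2, 7, 16)
lengths = torch.tensor([7, 4])
mask = torch.arange(7)[None, :] < lengths[:, None]
attn = masked_softmax(q @ k.transpose(-1, -2), mask)  # rows sum to 1 over real keys only
```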

ezyang (Contributor, Author) commented Jul 26, 2023

@vadimkantorov What you want is different, but also useful. You want an "efficient" masked tensor, where the jagged dimension is padded out (so you have a dense tensor), but you have a per-batch sequence length so you know which elements are masked out and can, e.g., do reductions correctly. You don't save FLOPs, but you get correct semantics.

It's a lot of work and kind of ugly to introduce mask= to all pre-existing PyTorch ops, so if you really don't want to use the subclass, we'd probably have to build out a separate namespace of operator variants with mask= arguments (which would probably be implemented by wrapping things into a masked tensor and then running the subclass lol). The subclass is less of a deal breaker once @bdhirsh gets compilation going and there is no overhead.

vadimkantorov (Contributor) commented Jul 26, 2023

Yes, the case for this is mainly correct semantics... I would say proper mask processing (arguments) is needed not for all ops, but mainly for reductions (and maybe for pooling/conv/batchnorm). Indeed, for this use case correct semantics is the goal, not perf (saving time by ignoring masked-out elements). For subclasses, I'm worried that many ops would fail if they don't know how to handle subclasses, or would silently and unexpectedly ignore the mask. So basically I'm not very trustful of various fallthrough mechanisms...

@ezyang, are you proposing a "padded nested tensor"? Btw, if you can group tensors/shapes based on storage size, just keeping the padded storage size might be a simple solution.

Also, where does this JaggedTensor structure stand in relation to the other formats discussed above?

ezyang (Contributor, Author) commented Aug 8, 2023

Well, the top-level proposal here is padding for a non-nested tensor, so no, I'm not. But your thing is reasonable too, if someone wants to work on it. JaggedTensor doesn't pad, so it's a more compact representation / saves FLOPs.
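
For reference, the jagged-vs-padded trade-off is visible with the existing nested tensor API (exact API surface may vary across PyTorch versions):

```python
import torch

# Jagged/nested: stores only the real elements, so it is compact and saves FLOPs.
nt = torch.nested.nested_tensor([torch.randn(3, 8), torch.randn(5, 8)])

# Padding materializes a dense rectangular tensor (2, 5, 8), trading memory
# and FLOPs for regular shapes that kernels and CUDA graphs like.
padded = torch.nested.to_padded_tensor(nt, 0.0)
print(padded.shape)  # torch.Size([2, 5, 8])
```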

vadimkantorov (Contributor) commented:

@ezyang I guess there are too many useful padding/nested/jagged formats :) so I'm a bit confused about your original proposal :)

bobrenjc93 (Contributor) commented:

Arguably an even bigger benefit of this work is divisibility guarantees during codegen. We've seen empirically that divisibility hints in Triton can make up to a 17x perf difference between dynamic and static shapes.

bobrenjc93 self-assigned this May 16, 2025
bobrenjc93 (Contributor) commented:

I'm going to try to take this from prototype to production.
