Padded tensor subclass #105325
Comments
Is what you want a MaskedTensor? https://pytorch.org/docs/stable/masked.html
A MaskedTensor is more general than a Padded Tensor, and as a result uses twice as much memory, because you need to store the entire mask. And I imagine it is much slower, because you need to repeatedly read from the mask as opposed to a single value per padded dimension.
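For concreteness, a minimal sketch of the two representations being compared, using the prototype torch.masked API (the values and lengths below are arbitrary placeholders):

```python
# MaskedTensor materializes a full boolean mask of the same shape as the data,
# while a padded representation only needs one length per padded dimension.
import torch
from torch.masked import masked_tensor  # prototype API

data = torch.zeros(4, 10)                # batch of 4 sequences padded to length 10
lengths = torch.tensor([3, 10, 7, 5])    # the "single value per padded dimension"

# MaskedTensor route: build and store the whole (4, 10) mask alongside the data.
mask = torch.arange(10).unsqueeze(0) < lengths.unsqueeze(1)
mt = masked_tensor(data, mask)

# Padded route: keep just (data, lengths); the mask can be rebuilt on demand,
# but ops that honor lengths never have to read it element by element.
```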
The difference with nested tensor is (1) unlike the default nested tensor layout, you don't need to record strides for each inner tensor and (2) there is no non-uniform shapiness within an individual tensor; instead, you expect variation across several runs.
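To illustrate the contrast, a rough sketch using the existing torch.nested API:

```python
# A nested tensor holds several differently-shaped components at once and
# records per-component sizes; the proposed padded tensor is a single
# uniformly-shaped dense tensor whose size only varies from one run to the next.
import torch

nt = torch.nested.nested_tensor([torch.randn(3, 8), torch.randn(5, 8)])
padded = torch.nested.to_padded_tensor(nt, 0.0)  # dense (2, 5, 8), zeros past each row's length
print(padded.shape)
```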
I think this is helpful also for modeling expressiveness. Currently, when using padding we just pass around shapes-before-padding. This is not too bad, but it would be better if more methods supported this "mask format" (e.g. masked layer norm, masked conv1d, masked attention). One way to proceed might be as in xFormers, which created types/classes for some (structured or otherwise) mask formats first.
@vadimkantorov What you want is different, but also useful. You want an "efficient" masked tensor, where the jagged dimension is padded out (so you have a dense tensor), but you have a per-batch sequence length so you know which elements are masked out and can, e.g., do reductions correctly. You don't save FLOPs, but you get correct semantics. It's a lot of work and kind of ugly to introduce mask= to all pre-existing PyTorch ops, so if you really don't want to use the subclass we'd probably have to build out a separate namespace of operator variants with mask= arguments (which would probably be implemented by wrapping things into masked tensor and then running the subclass). Subclass is less of a deal breaker once @bdhirsh gets compilation going and there is no overhead.
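A rough sketch of that "dense storage plus per-batch lengths" idea; masked_mean is a made-up helper, not an existing PyTorch op:

```python
import torch

def masked_mean(padded, lengths):
    # padded:  (B, T, C) dense storage, with junk past each sequence's length
    # lengths: (B,) number of valid timesteps per batch element
    valid = (torch.arange(padded.shape[1], device=padded.device)
             < lengths.unsqueeze(1)).unsqueeze(-1)           # (B, T, 1)
    total = (padded * valid).sum(dim=1)                       # (B, C)
    # No FLOPs saved (the padded region is still touched), but the result
    # divides by the true lengths, so the semantics are correct.
    return total / lengths.to(padded.dtype).unsqueeze(1)

x = torch.randn(2, 4, 3)
x[0, 2:] = 99.0                                # padding junk a naive mean would include
print(masked_mean(x, torch.tensor([2, 4])))    # correct per-sequence means
print(x.mean(dim=1))                           # naive mean, polluted by the padding
```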
Yes, the case for this is mainly correct semantics... I would say proper mask processing (arguments) is needed not for all ops, but mainly for reductions (and maybe for pooling/conv/batchnorm). Indeed, for this use case correct semantics is the goal, not perf (saving time by ignoring masked-out elements). For subclasses I'm worried that many ops would fail if they don't know how to handle subclasses, or would silently and unexpectedly ignore the mask. So basically I'm not very trustful of various fallthrough mechanisms... @ezyang are you proposing a "padded nested tensor"? Btw, if you can group tensors/shapes based on storage size, just keeping the padded storage size might be a simple solution. Also, where does this JaggedTensor structure stand in relation to the other formats discussed above?
Well, the top-level proposal here is padding for non-nested tensors, so no, I'm not. But your thing is reasonable too, if someone wants to work on it. JaggedTensor doesn't pad, so it's a more compact representation / saves FLOPs.
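For reference, a minimal sketch of what a jagged (non-padded) layout looks like, in the spirit of structures like JaggedTensor (not the actual implementation):

```python
import torch

# Rows of different lengths stored back-to-back, plus offsets: no memory or
# FLOPs are spent on padding, at the cost of non-uniform per-row shapes.
rows = [torch.randn(3), torch.randn(7), torch.randn(2)]
values = torch.cat(rows)                    # (12,) contiguous storage
offsets = torch.tensor([0, 3, 10, 12])      # row i lives in values[offsets[i]:offsets[i+1]]

row_means = torch.stack([values[s:e].mean()
                         for s, e in zip(offsets[:-1], offsets[1:])])
```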
@ezyang I guess there are too many useful padding/nested/jagged formats :) so I'm a bit confused about your original proposal :)
Arguably an even bigger benefit of this work is divisibility guarantees during codegen. We've seen empirically that divisibility hints in Triton can make up to a 17x perf difference between the dynamic and static cases.
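A hedged illustration of why padding helps here: if a dynamic dimension is always rounded up to a multiple of a block size, the code generator can treat it as divisible rather than fully dynamic. The helper below is hypothetical, not an existing API:

```python
import torch

def pad_to_multiple(x, dim=0, multiple=128):
    # Hypothetical helper: round the given dimension up to the next multiple,
    # padding with zeros, so downstream kernels see a divisible size.
    size = x.shape[dim]
    padded_size = -(-size // multiple) * multiple
    if padded_size == size:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = padded_size - size
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)

x = torch.randn(1000, 512)
print(pad_to_multiple(x).shape)   # torch.Size([1024, 512]); 1024 is divisible by 128
```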
I'm going to try to take this from prototype to production. |
🐛 Describe the bug
Suppose you have a network which operates on dynamically sized inputs / has data dependent dynamism internally. Our default policy is to represent such a tensor as compactly as possible (e.g., with no padding) to minimize storage and FLOPs needed to operate on it.
However, in some situations, it could be profitable / cost free to pad out the tensor:
A padded batch tensor for a dynamic batch size is probably the easiest place to start, because you can use symbolic shape propagation rules to propagate the batch dim and ensure it is properly padded (I can't think of a good way to make vmap do this). The annoyance is avoiding wasted FLOPs by "adjusting down" to the logical size.
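A very rough sketch of the mechanics, not the actual design of this proposal: a wrapper subclass that keeps a buffer padded out along the batch dim, advertises the logical (pre-padding) size, runs ops on the padded buffer, and re-wraps results that keep the batch dim. The names are illustrative, and _make_wrapper_subclass is a private API:

```python
# Illustrative sketch only; not the PaddedTensor design proposed in this issue.
import torch
from torch.utils._pytree import tree_map

class PaddedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, padded, logical_batch):
        # Advertise the logical shape; keep the padded storage on the side.
        return torch.Tensor._make_wrapper_subclass(
            cls, (logical_batch, *padded.shape[1:]),
            dtype=padded.dtype, device=padded.device,
        )

    def __init__(self, padded, logical_batch):
        self._padded = padded
        self._logical_batch = logical_batch

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        self = next(a for a in args if isinstance(a, PaddedTensor))
        unwrap = lambda t: t._padded if isinstance(t, PaddedTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # Re-wrap results whose leading dim is still the padded batch; ops that
        # reduce over the batch dim would need the "adjusting down" handling
        # mentioned above and are out of scope for this sketch.
        if (isinstance(out, torch.Tensor) and out.dim() >= 1
                and out.shape[0] == self._padded.shape[0]):
            return PaddedTensor(out, self._logical_batch)
        return out

x = PaddedTensor(torch.zeros(16, 8), logical_batch=10)  # batch padded out to 16
y = x + 1                                               # runs on the full (16, 8) buffer
print(y.shape)                                          # torch.Size([10, 8]) logically
```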
Related: #65156
Related: nested tensor
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @msaroufim @albanD
Versions
main