[RFC] PyTorch DistributedTensor #88838
Comments
Wondering if there are applications here for in-memory sharded datasets, a use-case not mentioned above.
@TimZaman Yeah, I think this is entirely possible. For indexing into a massive in-memory dataset where every image has the same size, this looks to me like a sharded embedding lookup on images, which can be easily implemented with DTensor; underneath we would call into the necessary collectives to fetch an image slice according to the index.
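A rough sketch of that idea, with illustrative names and sizes (the import path depends on the PyTorch version; torch.distributed._tensor in the early prototype, torch.distributed.tensor in recent releases): keep the dataset sharded row-wise as a DTensor and gather the rows you need. A real embedding-style lookup would communicate only the requested rows instead of replicating everything.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(4)))
images = torch.randn(40000, 1024)                         # fixed-size flattened samples
sharded_ds = distribute_tensor(images, mesh, [Shard(0)])  # rows split across the 4 ranks

# naive lookup: replicate (an all-gather under the hood), then index locally
idx = torch.tensor([5, 123, 999], device="cuda")
batch = sharded_ds.redistribute(mesh, [Replicate()]).to_local()[idx]
```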
Nit: why post here instead of https://github.com/pytorch/rfcs? Edit: never mind, found pytorch/rfcs#44 :)
How much control will users have over the sharding and materialization? E.g., to get decent throughput for parameter sharding or CPU offloading, we'll need some sort of prefetching mechanism. Also, what's the story for saving + loading large, sharded models? Torch Snapshot?
I want to ask when APIs such as DeviceMesh, Shard, and distribute_tensor can be used. Will documentation be provided for them? I want to use the tensordot function to perform large-scale tensor contractions over distributed storage, so how can I write code to do that?
@dblalock Users will have control over how an operator is implemented in a distributed fashion. If you want to apply a prefetching mechanism on top, I think it would just be running the operator on DTensor with a separate CUDA stream (if it's CUDA), or doing async offloading just like you would with torch.Tensor.
To save/load large sharded models, we are working on releasing support for that as well.
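For the prefetching part, a generic sketch (not DTensor-specific API) of overlapping a host-to-device copy with compute using a side CUDA stream, exactly as one would with a plain torch.Tensor:

```python
import torch

copy_stream = torch.cuda.Stream()
cpu_param = torch.randn(4096, 4096, pin_memory=True)      # offloaded parameter in pinned memory

with torch.cuda.stream(copy_stream):
    gpu_param = cpu_param.to("cuda", non_blocking=True)   # prefetch ahead of use

# ... other GPU work on the default stream overlaps with the copy ...
torch.cuda.current_stream().wait_stream(copy_stream)      # sync before consuming the parameter
out = gpu_param @ gpu_param.t()
```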
@ntyz We are working on landing it in pytorch soon; you can subscribe to this stack of PRs #88180, and once those are merged (hopefully this week or next week), you should be able to use it immediately in master, or in a nightly build 1-2 days later. We plan to release DTensor as a prototype feature in the next pytorch release; we might add more documentation in the code APIs directly, but not on the official https://docs.pytorch.org/ yet. It will be added to the doc website when it reaches beta. Feel free to try it out and submit issues to pytorch/pytorch directly, or even contribute once it's there :)
@Arnold1 Sorry for the late reply. It could possibly run in a cluster with CPUs, but probably not together with Databricks or PySpark.
@fengjinyuan This might be something we want to explore in the long term, but it is not the current focus. The current focus is still on homogeneous hardware.
Yeah, this is something we are exploring.
Were there any updates to the API? I'm trying to run the very first example from the README, which appears not to be working (anymore). My full code is as follows:
I'm running this script via
If I change the line that creates the tensor to
Hey, sorry for the late reply! We prototype-released the APIs and are currently working on enhancing them to push to a beta release. The issue you observed comes from how the process group gets initialized first. However, if you just construct a DeviceMesh alone without initializing the process group, DeviceMesh will automatically set up the devices on each process. I also updated the README to make everything runnable.
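For reference, a minimal sketch (file name and world size are illustrative; the import path depends on the PyTorch version) of a script that relies on DeviceMesh to set up the process group and devices, launched with torchrun:

```python
# dtensor_min.py -- launch with: torchrun --standalone --nproc-per-node=4 dtensor_min.py
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# constructing the DeviceMesh initializes the default process group and sets the
# device for each rank if that hasn't been done already
mesh = DeviceMesh("cuda", list(range(4)))
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
print(dt.to_local().shape)   # each of the 4 ranks holds a (2, 8) local shard
```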
How do we handle fused layers with DTensor? For example, in SwiGLU there are frequently two input matrices in the FF layer. These two matrices are fused into one big matrix. If we apply DTensor Tensor Parallel to shard this big matrix in the output dimension, the sharding will cross the wrong dimensions. I notice that gpt-fast and torchtitan both avoid using fused layers because of this problem, but that comes with a performance penalty. With Megatron layers, I handle the TP sharding and unsharding myself.
Hey @ad8e, thanks for bringing this up! For fused layers (i.e. fused SwiGLU or even fused QKV), DTensor should be able to represent the sharding by implementing a more complicated layout such as stride-aware sharding. We are currently working on this but haven't enabled it yet, and plan to make it work in the upcoming weeks!
The fused layers could be more performant, but I feel they might not give much e2e perf gain for Transformer model training. Horizontal fusions give good wins when the CUDA computations are not saturated, but for LLMs the compute is usually saturated already. Either way, I think it's better if we could support this out of the box; it does involve some work for more complicated layouts, so stay tuned :)
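To make the layout problem concrete, here is a plain-torch illustration with made-up shapes; it shows the manual re-chunking one ends up doing with Megatron-style layers, not the planned stride-aware DTensor layout.

```python
import torch

ffn, dim, tp = 8, 4, 2
w_gate, w_up = torch.randn(ffn, dim), torch.randn(ffn, dim)
fused = torch.cat([w_gate, w_up], dim=0)        # fused SwiGLU weight, shape [2*ffn, dim]

# naive Shard(0) over 2 ranks: rank 0 gets only gate rows, rank 1 only up rows,
# so the elementwise silu(gate(x)) * up(x) can no longer be computed locally
naive_shards = fused.chunk(tp, dim=0)

# manual fix: pair matching gate/up slices so each rank's shard is [gate_r; up_r]
paired_shards = [
    torch.cat([g, u], dim=0)
    for g, u in zip(w_gate.chunk(tp, dim=0), w_up.chunk(tp, dim=0))
]
```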
On a transformer, with TP + Sequence Parallel (i.e. what torchtitan is doing), we should theoretically be able to hide all the TP comms inside the computation. Is this planned for DTensor? TP is very slow without it.
@ad8e TP is expected to be slower due to the on-critical-path communications that happen in each TransformerBlock; that's why TP always needs to happen intra-host (where NVLink is available). TP's value is to enable larger model training (i.e. 70B) with hundreds to thousands of GPUs (where FSDP alone would OOM; please take a look at the tutorial we recently added). We are also working on fusing the compute and communication, but NCCL itself isn't performing well with p2p ops, so we are working on some alternative solutions (i.e. we found that with NCCL, fusing comm and compute is even slower than just exposing the TP comm). @yifuwang is building a CUDA p2p backend that would make comm/computation overlap performant; we expect to deliver that performance improvement with no model code changes via torch.compile, and to provide a simple eager API for users who don't want to enable torch.compile. Wondering what slowdown you are experiencing? For performance comparisons, one common thing I want to mention: when enabling FSDP + TP and comparing the perf with FSDP only, one should bump the local batch size, multiplying it by the TP degree.
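To make that last adjustment concrete, a tiny worked example with made-up numbers (64 GPUs, TP degree 8):

```python
gpus, local_bs, tp = 64, 2, 8
fsdp_only_global = gpus * local_bs             # 128 samples per step with FSDP alone
fsdp_tp_global = (gpus // tp) * local_bs       # 16: the data-parallel world shrank by tp
bumped_local_bs = local_bs * tp                # multiply the local batch size by the TP degree
assert (gpus // tp) * bumped_local_bs == fsdp_only_global
```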
With Llama 3 70B: if I use full activation checkpointing with TP=1, it's 2x the MFU of TP=8 with DTensor without activation checkpointing. Both configurations reach the memory limit of a 16-GPU H100 HGX setup, which is the limitation I'm working around. Using the TP + sequence parallel example in your tutorial: if only
@ad8e 16 H100s is really a small number of GPUs; applying 2D parallelism won't show benefits compared to just using FSDP alone, for example. In practice you would really want to apply TP when the world size goes beyond 256 GPUs. I also tried locally using torchtitan on Llama 3 8B:
I can try to run some 70B jobs but I would be surprised to see 2x slower here given that the compute is even heavier.
The optimization that JAX is doing in this talk is exactly what we are approaching: we are enabling partial gemm outputs with collectives (basically gemm + reduce_scatter decomposed into gemms + collective permute). The issue still holds: NCCL is bad at this optimization, so we are pursuing an intra-host-only optimization currently, and we'll possibly land this optimization in the next few weeks.
I tried your TP=4 and TP=1 tests on our internal codebase: Llama 8B, non-fused RMSNorm, FSDP1, PyTorch 2.3, 8 H100, full AC, 2048 context, 7x attention flops multiplier (rather than 12x). The TP4/TP1 ratio is 150/210 = 71.4%. Yours is 22.4/26.9 = 83.3%, which is better.

I ran my TP=4 inside a profiler for 3 steps. The GPU waits for the CPU a lot: https://drive.google.com/file/d/1U3PSLETgbI4F5hTwHtkS-pxo_quRzvtp/view?usp=sharing

The TP=1 profile is more reasonable: the GPU waits for comms, but at least there's no CPU waiting: https://drive.google.com/file/d/1p4AzFm_sqheVhGvGpcCFu3SKmADYO09E/view?usp=sharing

Characteristics of my benchmark:
@ad8e Curious why you used the 2048 sequence length instead of the Llama 3 8B default of 8192? We did observe some CPU overhead when the model size or sequence length is small (i.e. the same issue with Llama 2 7B). This is usually not an issue for the cases where TP is actually needed, with either a large model size or data that already saturates the CUDA computation (e.g. even Llama 3 8B with 8192 sequence length works well). We are also investing in torch.compile so that TP layers can be fully compiled and there would be almost no CPU overhead.
Good point; I'll redo the tests at 8192. It was that way as a leftover from some vocab replacement runs, but 2048 is a bad benchmark here.
8192 context. New ratio is 257/308 = 83.4%, which is the same as yours at 8B.
@ad8e Glad to see our numbers match! I think for models beyond 8B with the same settings, the numbers should match or be even higher when training on clusters. I'll do some more benchmarks for 70B models on my side too. Once we have the comm/compute overlap enabled, this should be higher :)
My opinion from the user-friendliness side: I'm willing to do some manual DTensor labeling in my code. There are a few types of PyTorch users in my mind.

Small users: training up to 8B, FSDP alone is enough on an HGX. It's easy, and it's performant if the user can make torch.compile work.

Medium users: at 13-34B, users will put some engineering effort in, but it's to the extent of writing a torchtitan-like codebase, or using one.

Big users: at 70B+, I'm willing to annotate anything that needs to be annotated. If it's not easy to deduce a property with torch.compile, I can specify that property. The perf is more important than user friendliness.

I'm considering this because the sequence parallel example I gave does not seem easy to deduce automatically. The comm/comp overlap should only require two splits in most cases. But from the PyTorch side, tracing the independence of dimensions would require labeling a lot of ops. So if I can decorate the graph to say where the independent dimensions are, to make analysis easier, I would.
@ad8e Thanks a lot for the insightful feedback!! Yeah, I think I'm aligned with you here. That is also why we don't specialize on Tensor Parallel and are working on getting DTensor to a better state. The model would change/evolve over time and different parallelisms would emerge, so giving users the ability to annotate in every detail how to perform sharding and communication is important for medium and big users! We'll expose enough control and flexibility to express sharding and runtime communication. For the sequence parallel case specifically, we'll expose a fused comm + compute method that users can plug in to their forward (i.e. attention or FFN) together with the sharding of TP, so that we can even get eager mode performance.
Can you please provide a document describing the application scenarios and best practices? Thanks.
With the current code,
Thanks. Could you explain a bit what communication reordering is? (Reorder all-gather and reduce-scatter? Reorder something within all-gather/reduce-scatter?)
I think the default FSDP ordering of its own all-gather and reduce-scatter in backward is pretty reasonable. I was thinking more along the lines of reordering with collectives from other parts of the model, e.g. TP communications or, if you have a sparse embedding part of your model, some embedding all-to-alls.
If I call distribute_tensor without specifying placements: the docstring states, "If [placements] is not specified, we will by default replicate the tensor across the device_mesh."
@ad8e Yes, when calling distribute_tensor without placements, it replicates the tensor across the device_mesh by default.
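A small sketch of that default (import path depends on the PyTorch version): with no placements, distribute_tensor behaves as if placements=[Replicate()] were passed, so every rank ends up holding a full copy of the tensor.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

mesh = DeviceMesh("cuda", list(range(4)))
t = torch.randn(8, 8)
dt_default = distribute_tensor(t, mesh)                   # no placements given
dt_explicit = distribute_tensor(t, mesh, [Replicate()])   # equivalent, spelled out
```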
Hi @wanchaol, can you share whether there's a plan or ongoing work to support compilation or lazy tensors? We are looking to better support DTensor with XLA tensors so the existing parallelism utils can be reused, including FSDP wrappers. Would like to learn if there's an opportunity for collaboration.
🚀 The feature, motivation and pitch
RFC: PyTorch DistributedTensor
We have been developing a DistributedTensor (a.k.a. DTensor) concept under the pytorch/tau repo in the past few months, and now we are moving the implementation over to pytorch with the stack #88180. This RFC proposes adding DistributedTensor to torch.distributed. Any early feedback is welcome!
Update:
DTensor is now available in PyTorch 2.0 and in the nightly builds! You can now play around with DTensor even in a Colab notebook; see a quick e2e tutorial here: https://colab.research.google.com/drive/12Pl5fvh0eLPUrcVO7s6yY4n2_RZo8pLR#scrollTo=stYPKb9Beq4e
Introduction
We propose distributed tensor primitives to allow easier distributed computation authoring in the SPMD (Single Program Multiple Devices) paradigm. The primitives are simple but powerful when used to express tensor distributions with both sharding and replication parallelism strategies. This could empower native Tensor parallelism among other advanced parallelism explorations. For example, to shard a big tensor across devices with 3 lines of code:
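A minimal sketch of such a three-line sharding, assuming a 4-GPU mesh; the import path has moved over time (torch.distributed._tensor in the early prototype, torch.distributed.tensor in recent releases), so adjust for your PyTorch version.

```python
# run with e.g.: torchrun --standalone --nproc-per-node=4 shard_example.py
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(4)))                       # one mesh dimension over 4 GPUs
big_tensor = torch.randn(100000, 88)                            # regular tensor on each rank
sharded = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])   # rows split across the mesh
```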
Motivation
Today there are mainly three ways to scale up distributed training: Data Parallel, Tensor Parallel and Pipeline Parallel. Each of them works on a separate dimension, where solutions have been built independently (i.e. PyTorch DDP, FSDP, ShardedTensor, PiPPy, etc.). When training really large models, users would like to use these technologies together (i.e. 3-D Parallelism), but the interoperability of the existing solutions is not great and they are often hard to use (i.e. users might want arbitrary combinations of data parallel, tensor parallel and pipeline parallel). This is becoming an issue for users, and one of the biggest reasons is that there is no common abstraction that builds the bridge between different parallelism strategies.
An ideal scenario is that users could just build their models like in a single node/device, without worrying about how to do distributed training in a cluster, and our solutions could help them run distributed training in an efficient manner. For example, researchers just need to build their big transformer model, and PyTorch Distributed automatically figures out how to split the model and run pipeline parallel across different nodes, how to run data parallel and tensor parallel within each node. In order to achieve this, we need some common abstractions to represent data distribution and run the distributed computation.
There are many recent works on tensor-level parallelism that provide common abstractions; see the Related Works section at the end for more details.

Inspired by GSPMD, OneFlow and TF's DTensor, we introduce a DistributedTensor concept to represent generic data distributions across hosts. DistributedTensor is the next evolution of ShardedTensor and provides basic abstractions to distribute storage and compute. It serves as one of the basic building blocks for distributed program translations and describes the layout of a distributed training program. With the DistributedTensor abstraction, we can seamlessly build parallelism strategies such as tensor parallelism, DDP and FSDP.

Value Proposition
DistributedTensor primarily:
PyTorch DistributedTensor
DistributedTensor API
We offer both a lower level DistributedTensor API and a module level API to create a
nn.Module
with “distributed” parameters.Basic DistributedTensor API Examples
Here are some basic DistributedTensor API examples that showcase how to construct a DTensor directly with different sharding/replication placements, and how to create one from a local torch.Tensor.
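A minimal sketch of those two flows, assuming a 4-GPU mesh (same import-path caveat as above): distributing an existing torch.Tensor with different placements, and wrapping an already-sharded local tensor with DTensor.from_local.

```python
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(4)))

# 1. distribute an existing torch.Tensor: sharded on dim 0, or fully replicated
t = torch.randn(12, 8)
sharded = distribute_tensor(t, mesh, [Shard(0)])
replicated = distribute_tensor(t, mesh, [Replicate()])

# 2. wrap a rank-local shard that already exists into a DTensor
local_shard = torch.randn(3, 8)                          # this rank's piece
dt = DTensor.from_local(local_shard, mesh, [Shard(0)])   # logical shape is (12, 8) on the 4-GPU mesh
```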
High level User Facing APIs

Users can use DistributedTensor tensor constructors directly to create a distributed tensor (i.e. distributed.ones/empty), but for existing modules like nn.Linear that already have torch.Tensor as parameters, how do we make them distributed parameters? We offer a way to directly distribute a torch.Tensor and a module-level API to distribute the module parameters. Below is the high-level API we introduce.

High level API examples:
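A sketch of the module-level flow, under the same import-path caveat; the partition function name and sharding policy below are illustrative rather than a prescribed usage. distribute_module walks the module and lets the partition_fn decide how each submodule's parameters are distributed (parameters are replicated if no partition_fn is given).

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Shard, distribute_module, distribute_tensor

mesh = DeviceMesh("cuda", list(range(4)))

def shard_linear(name, module, device_mesh):
    # hypothetical policy: shard nn.Linear parameters on dim 0, leave everything else replicated
    if isinstance(module, nn.Linear):
        for p_name, param in module.named_parameters(recurse=False):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            module.register_parameter(p_name, dist_param)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
sharded_model = distribute_module(model, mesh, partition_fn=shard_linear)
```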
Compiler and DistributedTensor
DistributedTensor provides efficient solutions for cases like Tensor Parallelism. But when using DTensor's replication in a data parallel fashion, it might become observably slower compared to our existing solutions like DDP/FSDP. This is mainly because existing solutions like DDP/FSDP have a global view of the entire model architecture, and thus can optimize for data parallel specifically, i.e. collective fusion and computation overlap, etc. DistributedTensor itself is only a Tensor-like object and only knows its local computation; it does not know the operations that happen afterwards.
In order to make performance on par when using DistributedTensor directly to do data parallel training, DistributedTensor also needs a global view to do things like communication optimization. We are exploring a compiler-based solution that accompanies DistributedTensor so that we can run optimizations on top of it, which will be shared later.
Related Works
This work is mainly inspired by GSPMD, OneFlow and TF's DTensor. All three works use a single “distributed tensor” concept for both replication and sharding, and their solutions enable users to build up their distributed training program in a uniform SPMD programming model. Specifically:
GSPMD:
OneFlow GlobalTensor:
TensorFlow DTensor:
There are also several cutting-edge research and production systems that embed tensor sharding as part of the system, e.g. Megatron-LM for tensor parallelism on Transformer-based models, and DeepSpeed for training large-scale models with different optimization techniques on top of tensor sharding.
Alternatives
In PyTorch, we have the existing ShardedTensor work in the prototype stage, which introduces basic PyTorch sharding primitives as our Tensor Parallelism solution. But ShardedTensor only supports tensor sharding, which makes it hard for users to describe other data distribution strategies like replication or replication + sharding. As a distributed system developer who wants to explore more parallelism patterns, it's crucial to have a basic building block that describes the data distribution in a uniform way. This DistributedTensor RFC aims to solve this and provide a fundamental abstraction for distributed training.
Additional context
We are gathering early feedback about this proposal. We have also posted this RFC to the dev-discuss forum; please feel free to comment directly in this issue or in the forum post. To see a complete design doc with additional details about this proposal, please refer to this doc.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @fduwjj @XilunWu @gnadathur @anj-s @zdevito @ezyang @albanD