[RFC] PyTorch DistributedTensor #88838

Open
wanchaol opened this issue Nov 10, 2022 · 48 comments
Labels
module: dtensor distributed tensor tag oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

@wanchaol (Collaborator) commented Nov 10, 2022

🚀 The feature, motivation and pitch

RFC: PyTorch DistributedTensor

We have been developing a DistributedTensor (a.k.a. DTensor) concept under the pytorch/tau repo over the past few months, and we are now moving the implementation over to pytorch with the stack #88180. This RFC proposes adding DistributedTensor to torch.distributed. Any early feedback is welcome!

Update:
DTensor is now available in PyTorch 2.0 and the nightly builds! You can now play around with DTensor even in a Colab notebook; see a quick e2e tutorial here: https://colab.research.google.com/drive/12Pl5fvh0eLPUrcVO7s6yY4n2_RZo8pLR#scrollTo=stYPKb9Beq4e

Introduction

We propose distributed tensor primitives to allow easier distributed computation authoring in the SPMD (Single Program, Multiple Devices) paradigm. The primitives are simple but powerful when used to express tensor distributions with both sharding and replication parallelism strategies. This could empower native Tensor Parallelism among other advanced parallelism explorations. For example, to shard a big tensor across devices in 3 lines of code:

# torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor  
  
# Create a mesh topology with the available devices.
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
big_tensor = torch.randn(100000, 88) 
# Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])
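
As a quick sanity check (a small sketch continuing the example above, assuming the 4-GPU launch shown in the comment), each rank holds only its shard while the DTensor keeps the global, logical shape:

# DTensor.shape reports the global shape; to_local() returns this rank's shard
print(my_dtensor.shape)             # torch.Size([100000, 88]) on every rank
print(my_dtensor.to_local().shape)  # torch.Size([25000, 88]) per rank on a 4-GPU mesh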

Motivation

Today there are mainly three ways to scale up distributed training: Data Parallel, Tensor Parallel and Pipeline Parallel. Each of them works on a separate dimension, and solutions have been built independently for each (i.e. PyTorch DDP, FSDP, ShardedTensor, PiPPy, etc.). When training really large models, users would like to use these technologies together (i.e. 3-D Parallelism), but the interoperability of the existing solutions is not great and they are often hard to use together (i.e. users might want arbitrary combinations of data parallel, tensor parallel and pipeline parallel). This is becoming an issue for users, and one of the biggest reasons is that there is no common abstraction that bridges the different parallelism strategies.

An ideal scenario is that users could build their models as if on a single node/device, without worrying about how to do distributed training in a cluster, and our solutions could help them run distributed training in an efficient manner. For example, researchers just need to build their big transformer model, and PyTorch Distributed automatically figures out how to split the model and run pipeline parallel across different nodes, and how to run data parallel and tensor parallel within each node. In order to achieve this, we need some common abstractions to represent data distribution and to run the distributed computation.

There are many recent works on tensor-level parallelism that provide common abstractions; see Related Works in the last section for more details. Inspired by GSPMD, OneFlow and TF’s DTensor, we introduce a DistributedTensor concept to represent generic data distributions across hosts. DistributedTensor is the next evolution of ShardedTensor and provides basic abstractions to distribute storage and compute. It serves as one of the basic building blocks for distributed program translations and describes the layout of a distributed training program. With the DistributedTensor abstraction, we can seamlessly build parallelism strategies such as tensor parallelism, DDP and FSDP.

Value Proposition

DistributedTensor primarily:

  • Offers a uniform way to save/load the state dict during checkpointing, even with complex data distribution strategies such as combining tensor parallelism with parameter sharding in FSDP.
  • Could natively offer a Tensor Parallelism solution in eager mode, just like our current ShardedTensor solution. Moreover, it gives additional flexibility for advanced users who want to mix sharding and replication.
  • Could be the entry point of an SPMD programming model for ML System Engineers, providing a good UX for mixing different types of parallelism, and could be used as a fundamental building block of compiler-based distributed training.

PyTorch DistributedTensor

DistributedTensor API

We offer both a lower-level DistributedTensor API and a module-level API to create an nn.Module with “distributed” parameters.

Basic DistributedTensor API Examples

Here are some basic DistributedTensor API examples that showcase:

  1. How to construct a DistributedTensor directly, to represent different types of sharding, replication, and sharding + replication strategies.
  2. How to create DistributedTensor from a local torch.Tensor.
  3. How to “reshard” an existing DistributedTensor to a different DistributedTensor with modified placement strategy or world size.
# torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py
import os
import torch
import torch.distributed._tensor as dtensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard, Replicate, distribute_tensor  

# construct a device mesh with available devices (multi-host or single host)
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (world_size,))
# if we want to do row-wise sharding  
rowwise_placement=[Shard(0)]  
# if we want to do col-wise sharding  
colwise_placement=[Shard(1)]  
# distributed tensor returned will be sharded across the dimension specified in placements  
rowwise_dtensor = dtensor.empty((8, 12), device_mesh=device_mesh, placements=rowwise_placement)
# shard the torch.Tensor on rank 0 and scatter the shards to all ranks
colwise_dtensor = distribute_tensor(torch.randn(8, 12), device_mesh, colwise_placement)

# if we want to do replication across a certain device list  
replica_placement = [Replicate()]  
# distributed tensor will be replicated to all four GPUs.  
dtensor.empty((8, 12), device_mesh=device_mesh, placements=replica_placement)  
  
# if we want to distribute a tensor with both replication and sharding
device_mesh_2d = init_device_mesh("cuda", (world_size // 2, 2))
# replicate across the first dimension of the device mesh, then shard on the second dimension of the device mesh
spec = [Replicate(), Shard(0)]
dtensor.empty((8, 8), device_mesh=device_mesh_2d, placements=spec)
  
# create a DistributedTensor that shards on dim 0, from a local torch.Tensor  
local_tensor = torch.randn((8, 8), requires_grad=True)  
rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)  
  
# reshard the current rowise tensor to a colwise tensor or replicate tensor  
colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)  
replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)
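
To see what these placements mean concretely, here is a small illustrative sketch (continuing the example above and assuming the 4-GPU launch from the comment at the top) that compares the logical shape with each rank's local shard:

# the logical (global) shape is the same on every rank; the local shard depends on the placement
print(rowwise_dtensor.shape)              # torch.Size([8, 12]) on every rank
print(rowwise_dtensor.to_local().shape)   # torch.Size([2, 12]): dim 0 is split across 4 ranks
print(colwise_dtensor.to_local().shape)   # torch.Size([8, 3]): dim 1 is split across 4 ranks
# from_local treats the given tensor as this rank's shard, so the logical shape grows
print(rowwise_tensor.shape)               # torch.Size([32, 8]): four (8, 8) local shards stacked on dim 0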

High level User Facing APIs

Users can use DistributedTensor tensor constructors directly to create a distributed tensor (i.e. distributed.ones/empty), but for existing modules like nn.Linear that already have torch.Tensor parameters, how do we make them distributed parameters? We offer a way to directly distribute a torch.Tensor, and a module-level API to distribute the module parameters. Below are the high-level APIs we introduce:

def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh=None, placements: List[Placement]=None):  
    '''
    distribute the tensor according to device_mesh and placements, `tensor` could be a "meta" tensor.  
    '''  
  
def distribute_module(  
    module: nn.Module,  
    device_mesh: DeviceMesh=None,  
    partition_fn: Callable[[str, nn.Module, DeviceMesh], ...]=None,
    input_fn: Callable[..., None]=None,
    output_fn: Callable[..., None]=None,
):  
    '''  
    This function converts all module parameters to distributed tensor parameters according to the `partition_fn` specified.  
    It could also control the input/output of the module by specifying the `input_fn` and `output_fn`. 
    '''

High level API examples:

import torch
import torch.nn as nn
from torch.distributed._tensor import Shard, distribute_tensor, distribute_module
from torch.distributed.device_mesh import init_device_mesh

class MyModule(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.fc1 = nn.Linear(8, 8)  
        self.fc2 = nn.Linear(8, 8)  
        self.relu = nn.ReLU()  
     
    def forward(self, input):  
        return self.relu(self.fc1(input) + self.fc2(input))  
  
mesh = init_device_mesh("cuda", (2, 2))  # a 2x2 device mesh over 4 GPUs

def shard_params(mod_name, mod, mesh):  
    rowwise_placement = [Shard(0)]
    def to_dist_tensor(t): return distribute_tensor(t, mesh, rowwise_placement)  
    mod._apply(to_dist_tensor)  

model = MyModule()
sharded_module = distribute_module(model, mesh, partition_fn=shard_params)  
  
def shard_fc(mod_name, mod, mesh):  
    rowwise_placement = [Shard(0)]
    if mod_name == "fc1":  
        mod.weight = torch.nn.Parameter(distribute_tensor(mod.weight, mesh, rowwise_placement))

sharded_module = distribute_module(model, mesh, partition_fn=shard_fc)
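
The snippet above only uses partition_fn; here is a hedged sketch of how input_fn/output_fn might be used. It assumes the prototype callback convention where input_fn receives the module inputs plus the mesh and output_fn receives the module outputs plus the mesh (the exact signatures may differ between versions), and it reuses shard_params from above:

from torch.distributed._tensor import DTensor, Replicate

def replicate_input(inputs, mesh):
    # wrap plain torch.Tensor inputs into replicated DTensors before the module runs
    return tuple(DTensor.from_local(t, mesh, [Replicate()]) for t in inputs)

def localize_output(outputs, mesh):
    # convert a DTensor output back to a plain torch.Tensor for downstream code
    return outputs.to_local() if isinstance(outputs, DTensor) else outputs

dist_module = distribute_module(
    MyModule(), mesh, partition_fn=shard_params,
    input_fn=replicate_input, output_fn=localize_output,
)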

Compiler and DistributedTensor

DistributedTensor provides efficient solutions for cases like Tensor Parallelism. But when using DTensor's replication in a data parallel fashion, it might be observably slower than our existing solutions like DDP/FSDP. This is mainly because existing solutions like DDP/FSDP have a global view of the entire model architecture and thus can optimize specifically for data parallelism, e.g. collective fusion and computation overlap. DistributedTensor itself is only a Tensor-like object and only knows its local computation; it does not know about the operations that happen afterwards.

In order to make the performance on par when using DistributedTensor directly for data parallel training, DistributedTensor also needs a global view to do things like communication optimization. We are exploring a compiler-based solution to accompany DistributedTensor so that we can run optimizations on top of it, which will be shared later.

Related Works

This work is mainly inspired by GSPMD, OneFlow and TF’s DTensor. All three works use a single “distributed tensor” concept for both replication and sharding, and these solutions enable users to build up their distributed training program in a uniform SPMD programming model. Specifically:

GSPMD:

  • GSPMD is now the fundamental component of JAX/TensorFlow distributed training and enables various optimizations with the XLA compiler to allow users to train their models efficiently in a large scale setting.
  • Fundamentally, GSPMD has three types of sharding strategy within a tensor: “tiled”, “replicated”, and “partially tiled”, to represent sharding and replication.
  • At its core, the GSPMD partitioner utilizes the XLA compiler to do advanced optimizations, i.e. sharding propagation and compiler-based fusion.
  • XLA mark_sharding API: PyTorch/XLA’s mark_sharding API uses the XLAShardedTensor abstraction (i.e. sharding specs) in PyTorch/XLA. Under the hood, XLAShardedTensor utilizes the GSPMD partitioner to enable SPMD-style training on TPU.

OneFlow GlobalTensor:

  • OneFlow is building its own solution around the “GlobalTensor” concept, which is a variant form of GSPMD sharding, allowing users to explore different parallel strategies with GlobalTensor.
  • OneFlow also has three placement types, but they are slightly different from GSPMD’s: “split”, “broadcast”, and “partial sum”. It doesn’t use “partially tiled” and instead has a “partial sum” concept to partition the values.

TensorFlow DTensor:

  • DTensor is an extension of TensorFlow synchronous distributed training, with its own sharding API, supported features, and compilation passes with MLIR.
  • DTensor also allows sharding and replication on an n-d mesh like device network.
  • DTensor implements MLIR passes to do propagation and operator implementations.

There are also several cutting-edge research systems that embed tensor sharding as part of the system, e.g. Megatron-LM for tensor parallelism on Transformer-based models, and DeepSpeed for training large-scale models with different optimization techniques on top of tensor sharding.

Alternatives

In PyTorch, we have the existing ShardedTensor work in the prototype stage, which introduces basic PyTorch sharding primitives as our Tensor Parallelism solution. But ShardedTensor only supports tensor sharding, which makes it hard for users to describe other data distribution strategies like replication or replication + sharding. For a distributed system developer who wants to explore more parallelism patterns, it’s crucial to have a basic building block that describes the data distribution in a uniform way. This DistributedTensor RFC aims at solving this and providing a fundamental abstraction for distributed training.

Additional context

We are gathering early feedback about this proposal. We have also posted this RFC to the dev-discuss forum, so please feel free to comment directly in this issue or in the forum post. To see a complete design doc with additional details about this proposal, please refer to this doc.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @fduwjj @XilunWu @gnadathur @anj-s @zdevito @ezyang @albanD

@TimZaman commented:

Wondering if there are applications here for in-memory sharded datasets, a use-case not mentioned above.
Example: such that you could randomly access a massive XTB dataset from any node (assuming a solid IB network). Example: if a satellite image covering the entire world is 8TB and you need to randomly index into it, instead of obtaining multiple jpeg crops from your disk, I could see this implemented as one huge sharded tensor, preloaded onto the nodes, and any node that needs a slice of it can transparently index into this world-tensor. This would eliminate disk access and use RDMA - good for throughput and latency.

@wanchaol (Collaborator, Author) commented:

Wondering if there are applications here for in-memory sharded datasets, a use-case not mentioned above. Example: such that you could randomly access a massive XTB dataset from any node (assuming a solid IB network). Example: if a satellite image covering the entire world is 8TB and you need to randomly index into it, instead of obtaining multiple jpeg crops from your disk, I could see this implemented as one huge sharded tensor, preloaded onto the nodes, and any node that needs a slice of it can transparently index into this world-tensor. This would eliminate disk access and use RDMA - good for throughput and latency.

@TimZaman Yeah, I think this is entirely possible. For indexing into a massive in-memory dataset where every image has the same length, it looks to me like a sharded embedding lookup on images, which can be easily implemented with DTensor; under the hood we would call into the necessary collectives to get an image slice according to the index.
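
A rough sketch of the storage side of this idea (hypothetical sizes, assuming every image flattens to the same number of elements, and glossing over the remote lookup itself, which would use the collectives mentioned above):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,))           # assumed 4-rank mesh for illustration
num_images, image_numel = 100_000, 3 * 64 * 64  # hypothetical dataset dimensions
dataset = torch.randn(num_images, image_numel)  # stand-in for the preloaded image data
# shard the dataset row-wise so each rank keeps num_images / 4 images in memory
sharded_dataset = distribute_tensor(dataset, mesh, [Shard(0)])
local_rows = sharded_dataset.to_local()         # this rank's in-memory portion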

@byronyi commented Nov 11, 2022

Nit: why post here instead of https://github.com/pytorch/rfcs?

Edit: never mind, found pytorch/rfcs#44 :)

@dblalock commented:

How much control will users have over the sharding and materialization? E.g., to get decent throughput for parameter sharding or CPU offloading, we'll need some sort of prefetching mechanism.

Also, what's the story for saving + loading large, sharded models? Torch Snapshot?

@ntyz commented Nov 13, 2022

I want to ask when APIs such as DeviceMesh, Shard, and distribute_tensor can be used? Will documentation be provided for these APIs? I want to use the tensordot function to realize large-scale tensor contraction over distributed storage, so how can I write code to realize it?

@wanchaol (Collaborator, Author) commented Nov 15, 2022

How much control will users have over the sharding and materialization? E.g., to get decent throughput for parameter sharding or CPU offloading, we'll need some sort of prefetching mechanism.

@dblalock users will have control over how to implement an operator in a distributed fashion; if you want to apply a prefetching mechanism on top, I think it would just be running this operator on DTensor with a separate CUDA stream (if it's CUDA) or doing async offloading, just like you would with torch.Tensor.
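
A minimal sketch of that idea (this is not an official DTensor prefetching API; the helper below is hypothetical and only illustrates driving the copy on a side CUDA stream against the DTensor's local shard):

import torch

copy_stream = torch.cuda.Stream()

def prefetch_param(dtensor_param, cpu_buffer):
    # asynchronously copy an offloaded (ideally pinned) CPU buffer back into this rank's local shard
    with torch.cuda.stream(copy_stream):
        dtensor_param.to_local().copy_(cpu_buffer, non_blocking=True)

# before the parameter is used, make the default stream wait for the prefetch to finish
torch.cuda.current_stream().wait_stream(copy_stream)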

Also, what's the story for saving + loading large, sharded models? Torch Snapshot?

To save/load large sharded models, we are working on releasing torch.distributed.checkpoint to beta for large-scale model save/load as part of the next release. Right now the functionality is there under torch.distributed._shard.checkpoint, but we plan to make it a dedicated subpackage under torch.distributed (#88698). cc @kumpera @wz337
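
A hedged sketch of the save/load path with that package (the entry points have shifted between releases; sharded_model is a hypothetical module whose parameters are DTensors):

import torch.distributed.checkpoint as dcp

state_dict = {"model": sharded_model.state_dict()}
# each rank writes only its own shards; on load, shards are read back into the same layout
dcp.save_state_dict(state_dict=state_dict, storage_writer=dcp.FileSystemWriter("ckpt/"))
dcp.load_state_dict(state_dict=state_dict, storage_reader=dcp.FileSystemReader("ckpt/"))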

@wanchaol (Collaborator, Author) commented:

I want to ask when APIs such as DeviceMesh, Shard, and distribute_tensor can be used? Will documentation be provided for these APIs? I want to use the tensordot function to realize large-scale tensor contraction over distributed storage, so how can I write code to realize it?

@ntyz We are working on landing it in pytorch soon; you can subscribe to this stack of PRs (#88180), and once those are merged (hopefully this week or next week), you should be able to use it immediately on master, or in the nightly build 1-2 days later.

We plan to release DTensor as a prototype feature in the next pytorch release, for which we might add more documentation in the code APIs directly, but not on the official https://docs.pytorch.org/ yet. It will be added to the doc website when it reaches beta. Note that we will release DeviceMesh as a beta feature in the next release, and will add documentation for it demonstrating how to use it :)

Feel free to try it out and submit issues to pytorch/pytorch directly, or even contribute once it's there :)

@wanchaol (Collaborator, Author) commented:

hi - lots of people use databricks and run a cluster. does DTensor also work with databricks + pyspark running a cluster with CPUs? is there even a way to run it in that way?

@Arnold1 Sorry for the late reply. It could possibly run in a cluster with CPUs but probably not together with databricks or pyspark.

@wanchaol (Collaborator, Author) commented Apr 26, 2023

Could DTensor supports treating CPU memory and GPU memory as a unified Memory pool

@fengjinyuan This might be something we want to explore in the long term, but it is not the current focus. The current focus is still around homogeneous hardware.

I expect a model which runs on a single node can auto-implement model parallel(Tensor parallel/ Pipeline parallel) and data parallel on top of DTensor + Dist nn.Module. @wanchaol

Yeah this is something we are exploring

@philippwitte commented Nov 10, 2023

Were there any updates to the API? I'm trying to run the very first example from the README, which appears not to be working (anymore). My full code is as follows:

import os

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor import Shard
from torch.distributed._tensor import distribute_tensor

# Read rank and world size from MPI environment variables
rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

# Initiate the process group
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
print("Hello from rank: {} of {}".format(rank, world_size))

# Device mesh
mesh = DeviceMesh("cuda", list(range(world_size)))

# Allocate some large tensor
big_tensor = torch.randn(100000, 88)

# Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])

print("Tensor shapes: ", big_tensor.shape, `my_dtensor.shape)

I'm running this script via mpirun -n 4 python3 hello_dtensor.py. This results in multiple ranks trying to use the same GPU:

Duplicate GPU detected : rank 3 and rank 0 both on CUDA device 100000

If I change the line that creates the tensor to big_tensor = torch.randn(1000000, 88, device='cuda:{}'.format(rank)), the code runs but doesn't appear to be doing anything. I.e., the shapes of the big and distributed tensors are the same. Is there something obvious I'm not doing correctly?

@wanchaol (Collaborator, Author) commented Jan 9, 2024

Were there any updates to the API? I'm trying to run the very first example from the README, which appears not to be working (anymore).

Hey, sorry for the late reply! We released the APIs as a prototype and are currently working on enhancing them to push to a beta release.

The issue you observed is because when you first do init_process_group and then construct the DeviceMesh, we don't automatically set the CUDA device for you: as part of init_process_group(backend="nccl"), the user is responsible for setting up the CUDA device for each process.

However, if you just construct the DeviceMesh alone without initializing a process group, DeviceMesh will automatically set up the device on each process.

I also updated the README to make everything runnable.
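
A minimal sketch of the fix described above, assuming one process per GPU on a single host so that OMPI_COMM_WORLD_RANK can double as the local rank (on multiple hosts you would use a local-rank variable instead):

import os
import torch

rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

# bind this process to its GPU *before* creating the NCCL process group
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)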

@ad8e (Contributor) commented Apr 28, 2024

How do we handle fused layers with DTensor?

For example, in SwiGLU, there are frequently two input matrices in the FF layer. These two matrices are fused into one big matrix. If we apply DTensor Tensor Parallel to shard this big matrix in the output dimension, the sharding will cross the wrong dimensions. I notice that gpt-fast and torchtitan both avoid using fused layers because of this problem, but that comes with a performance penalty.

With megatron layers, I handle the TP sharding and unsharding myself.

@wanchaol (Collaborator, Author) commented:

How do we handle fused layers with DTensor?

Hey @ad8e, thanks for bringing this up! For fused layers (i.e. fused SwiGLU or even fused QKV), DTensor should be able to represent their sharding by implementing a more complicated sharding layout like stride-aware sharding. We are currently working on this but haven't enabled it yet, and are planning to make this work in the upcoming weeks!

I notice that gpt-fast and torchtitan both avoid using fused layers because of this problem, but that comes with a performance penalty.

The fused layers could be more performant, but I feel it might not give much e2e perf gain for Transformer model training. Horizontal fusion gives good wins when the CUDA computation is not saturated, but the compute is usually already saturated for LLMs.

Either way, I think it's better if we could support this out of the box, but this does involve some work for more complicated layouts, so stay tuned :)

@ad8e (Contributor) commented May 18, 2024

On a transformer, with TP + Sequence Parallel (i.e. what torchtitan is doing), we should theoretically be able to hide all the TP comms inside the computation. Is this planned for DTensor? TP is very slow without it.

@wanchaol (Collaborator, Author) commented May 20, 2024

On a transformer, with TP + Sequence Parallel (i.e. what torchtitan is doing), we should theoretically be able to hide all the TP comms inside the computation. Is this planned for DTensor? TP is very slow without it.

@ad8e TP is expected to be slower due to the on-critical-path communications that happen in each TransformerBlock; that's why TP always needs to happen intra-host (where NVLink is available). TP's value is to enable larger model training (i.e. 70B) with hundreds to thousands of GPUs (where FSDP alone would OOM; please take a look at the tutorial we recently added). We are also working on fusing the compute and communication, but NCCL itself isn't performing well with p2p ops, so we are working on some alternative solutions (i.e. we found that with NCCL, fusing comm and compute is even slower than just exposing the TP comm). @yifuwang is making a CUDA p2p backend that would make comm/computation overlap performant; we expect to get the performance improvement with no model code changes via torch.compile, and to provide a simple eager API for users who don't want to enable torch.compile.

Wondering how much of a slowdown you are experiencing? For performance comparisons, one common thing I want to mention is that when enabling FSDP + TP and comparing the perf with FSDP only, one should bump the local batch size by a factor of the model_parallel_degree/tp_degree so that both setups have the same global batch size.
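
To make that bookkeeping concrete, a tiny sketch with hypothetical numbers:

# with TP, the data-parallel degree shrinks, so the local batch size must grow to compensate
world_size = 8
tp_degree = 4
dp_degree = world_size // tp_degree                 # 2 (vs. 8 for FSDP-only)
fsdp_only_local_bs = 1                              # global batch = 1 * 8 = 8
fsdp_tp_local_bs = fsdp_only_local_bs * tp_degree   # 4, so global batch = 4 * 2 = 8 as well
assert fsdp_only_local_bs * world_size == fsdp_tp_local_bs * dp_degree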

@ad8e (Contributor) commented May 20, 2024

With Llama3 70B: if I use full activation checkpointing with TP=1, it's 2x the MFU of TP=8 with DTensor without activation checkpointing. These reach the memory limit of a 16-HGX H100 system, which is the limitation I'm working around.

Using the TP+sequence parallel example in your tutorial: feed_forward.w2 changes the split from dim(2) to dim(1). My thought is that the communication can stream while w2 is being calculated. So when the w2 matrix finishes, the next block's RMSNorm+attention.wq starts, and let's say 50% of its input tensors have arrived, sharded on the sequence dimension. Since this RMSNorm+attention.wq's calculation is independent along the sequence dimension, it should be able to start its computation while waiting for the remaining tokens in the sequence to stream in. This way, the TP comm is not on the critical path so TP would be entirely free. This Nvidia JAX talk mentions it at 15:20: https://www.nvidia.com/en-us/on-demand/session/gtc24-s62246/

If only RMSNorm was considered individually, instead of RMSNorm+attention.wq together, then it may finish its calculations while only 70% of the communication has happened, so the TP may peek out.

@wanchaol (Collaborator, Author) commented:

With Llama3 70B: if I use full activation checkpointing with TP=1, it's 2x the MFU as TP=8 with DTensor without activation checkpointing. These reach the memory limit of an 16 HGX H100 system, which is the limitation I'm working around.

@ad8e 16 H100s is really a small number of GPUs, where applying 2D parallelism won't show benefits compared to just using FSDP alone, for example; in practice you would really want to apply TP when the world size goes beyond 256 GPUs. I also tried locally using torchtitan on Llama3 8B:

  • 8 H100, and setting 8-way FSDP, TP=1, local batchsize=1, non-fused RMSNorm, full AC would give me around 26.9% MFU
  • 8 H100, 2-way FSDP, TP=4, local batch size =4, non-fused RMSNorm, full AC would give me around 22.4% MFU

I can try to run some 70B jobs, but I would be surprised to see a 2x slowdown here given that the compute is even heavier.

Using the TP+sequence parallel example in your tutorial: feed_forward.w2 changes the split from dim(2) to dim(1). My thought is that the communication can stream while w2 is being calculated. So when the w2 matrix finishes, the next block's RMSNorm+attention.wq starts, and let's say 50% of its input tensors have arrived, sharded on the sequence dimension. Since this RMSNorm+attention.wq's calculation is independent along the sequence dimension, it should be able to start its computation while waiting for the remaining tokens in the sequence to stream in. This way, the TP comm is not on the critical path so TP would be entirely free. For example, this Nvidia JAX talk mentions it at 15:20: https://www.nvidia.com/en-us/on-demand/session/gtc24-s62246/

If only RMSNorm was considered individually, instead of RMSNorm+attention.wq together, then it may finish its calculations while only 70% of the communication has happened, so the TP may peek out.

The optimization that JAX is doing in this talk is exactly what we are approaching: we are enabling partial gemm outputs with collectives (basically gemm + reduce_scatter decomposed into gemms + collective permute). The issue still holds: NCCL is bad at this optimization, so we are currently pursuing an intra-host-only optimization, and we'll possibly land this optimization in the next few weeks.

@ad8e (Contributor) commented May 20, 2024

I mean 16xHGX, so 128 H100s. I think our 7B numbers align with yours, but I would have to check.

For 70B, it's 340 TFLOPS for full AC, 137 TFLOPS for TP=8, both are MFU numbers.

Our FLOPS graphs also make quite interesting shapes with DTensor. This one is 70B TP=8 for about a minute:

Screenshot 2024-05-20 at 2 50 58 PM

This one is a vocab layer-only run with 70B TP=4 for 45 hours (half the backwards is missing, so the FLOPS is misleadingly high):
Screenshot 2024-05-20 at 2 54 36 PM
Notice how there are spikes upward every 4000 steps, which is when a checkpoint happens. Then the run is fast for an hour and then slow after.

The optimizations that JAX is doing in this talk is exactly what we are approaching it

Awesome, I look forward to it!

@ad8e (Contributor) commented May 21, 2024

I tried your TP=4 and TP=1 tests on our internal codebase. Llama 8B, non-fused RMSNorm, FSDP1, PyTorch 2.3. 8 H100, full AC, 2048 context, 7x attention flops multiplier (rather than 12x).

TP=4 + BS=4, 150 TFLOPS
Screenshot 2024-05-20 at 5 39 06 PM

TP=1 + BS=1, 210 TFLOPS
Screenshot 2024-05-20 at 5 39 33 PM

The TP4/TP1 ratio is 150/210 = 71.4%. Yours is 22.4/26.9 = 83.3%, which is better.

I ran my TP=4 inside a profiler for 3 steps. The GPU waits for the CPU a lot. https://drive.google.com/file/d/1U3PSLETgbI4F5hTwHtkS-pxo_quRzvtp/view?usp=sharing

TP=1 profile is more reasonable: the GPU waits for comms, but at least there's no CPU waiting. https://drive.google.com/file/d/1p4AzFm_sqheVhGvGpcCFu3SKmADYO09E/view?usp=sharing

Characteristics of my benchmark:

  • Wasn't able to get torch.compile working with FSDP1; the activation checkpointing blocks it. (With DTensor off and AC off, torch.compile works.)
  • I used 2.3 because nightly works only half the time; there's an at-scale dataloader bug I have no repro for. Nothing related to DTensor.

@wanchaol (Collaborator, Author) commented:

I tried your TP=4 and TP=1 tests on our internal codebase. Llama 8B, non-fused RMSNorm, FSDP1, PyTorch 2.3. 8 H100, full AC, 2048 context, 7x attention flops multiplier (rather than 12x).

@ad8e curious why you use the 2048 seq length instead of the Llama3 8B default of 8192? We did observe some CPU overhead when the model size is small or the seq_length is small (i.e. the same issue with Llama2 7B). This is usually not an issue for the cases where TP is actually needed, with either a large model size or data that already saturates the CUDA computation (i.e. even Llama3 8B with 8192 seq length works well). We are also investing in torch.compile so that TP layers can be fully compiled and there would be almost no CPU overhead.

@ad8e (Contributor) commented May 21, 2024

Good point; I'll redo the tests at 8192. It was that way as a leftover from some vocab replacement runs, but 2048 is a bad benchmark here.

@ad8e (Contributor) commented May 21, 2024

8192 context.
Traces:
https://drive.google.com/file/d/1CxpQYFQnTNy_YOR1VSX4MnymPpIALpMX/view?usp=sharing
https://drive.google.com/file/d/1Vx4UlkhLk3B8PBoenOejk2SlFKIF19qL/view?usp=sharing
TP4:
Screenshot 2024-05-20 at 8 57 31 PM
TP1:
Screenshot 2024-05-20 at 8 50 35 PM

New ratio is 257/308, which is the same as yours at 8B, 83.4%.

@wanchaol (Collaborator, Author) commented:

New ratio is 257/308, which is the same as yours at 8B, 83.4%.

@ad8e glad to see our numbers matched! I think for models beyond 8B, with the same setting, the numbers should match or be even higher when training on clusters. I'll do some more benchmarks for 70B models on my side too.

Once we have the comm/compute overlap enabled, this should be higher :)

@ad8e (Contributor) commented May 21, 2024

My opinion from the user-friendliness side: I'm willing to do some manual DTensor labeling in my code.

There are a few types of PyTorch users in my mind.

Small users: training up to 8B, FSDP alone is enough on an HGX. It's easy, and it's performant if the user can make torch.compile work.

Medium users: at 13-34B, users will put some engineering effort in, but it's to the extent of writing a torchtitan-like codebase, or using one.

Big users: at 70B+, I'm willing to annotate anything that needs to be annotated. If it's not easy to deduce a property with torch.compile, I can specify that property. The perf is more important than user friendliness.

I'm considering this because the sequence parallel example I gave does not seem easy to deduce automatically. The comm/comp overlap should only require two splits in most cases. The computation of feed_forward.w2 is split in two in the output sequence dimension, and each half-output is consumed by RMSNorm+QKV independently. So the order of the operations is: ff.w2-1 [comm1] ff.w2-2 [comm2] rmsqkv-1 rmsqkv-2, with two sequencing events: rmsqkv-1 after comm1, and rmsqkv-2 after comm2. Each comm is for half the sequence dimension. That's simple enough that I could do the comm/comp overlap myself if I knew how to handle async streams properly, although I'd still run into the nccl issues you guys are experiencing, and I wouldn't be able to support torch.compile. Being allowed to split both the input and output sides of the communication also means the partial gemms are bigger (fewer splits). The property needed to know that this split is allowed is that RMSNorm+QKV is independent along its input sequence dimension, and ff.w2 is independent along its output sequence dimension.

But from the PyTorch side, tracing the independence of dimensions would require labeling a lot of ops. So if I can decorate the graph to say where the independent dimensions are, to make analysis easier, I would.

@wanchaol (Collaborator, Author) commented:

My opinion from the user-friendliness side: I'm willing to do some manual DTensor labeling in my code.

There are a few types of PyTorch users in my mind.

Small users: training up to 8B, FSDP alone is enough on an HGX. It's easy, and it's performant if the user can make torch.compile work.

Medium users: at 13-34B, users will put some engineering effort in, but it's to the extent of writing a torchtitan-like codebase, or using one.

Big users: at 70B+, I'm willing to annotate anything that needs to be annotated. If it's not easy to deduce a property with torch.compile, I can specify that property. The perf is more important than user friendliness.

But from the PyTorch side, tracing the independence of dimensions would require labeling a lot of ops. So if I can decorate the graph to say where the independent dimensions are, to make analysis easier, I would.

@ad8e Thanks a lot for the insightful feedback!! Yeah, I think I'm aligned with you here. That is also why we don't specialize on Tensor Parallel, and are working on getting DTensor to a better state. Models change/evolve over time and different parallelisms emerge, so having the ability to let users annotate, in every detail, how to perform sharding and communication is important for medium and big users! We'll expose enough control and flexibility to express sharding and runtime communication.

For the sequence parallel case specifically, we'll expose a fused comm + compute method that users can plug into their forward (i.e. attention or FFN) together with the TP sharding, so that we can have good performance even in eager mode.

@GuWei007 commented:

Can you please provide a document describing the application scenarios and best practices?

@guoyejun (Contributor) commented Jul 5, 2024

The section Compiler and DistributedTensor mentions that torch.compile optimization is required for DistributedTensor to get better FSDP performance, with collective fusion and computation overlap.

  • Do we expect better performance of FSDP2 once the optimization in torch.compile is tuned? (I think the torch.compile optimization for DistributedTensor is not finished yet.)

  • Which kinds of torch.compile optimizations are expected? IMO, a bucket can be used to fuse the collectives of each parameter (an extra D2D copy will be added), and the fused collective can overlap with computation. What are the remaining significant parts for torch.compile to improve?

thanks

@awgu (Collaborator) commented Jul 22, 2024

do we expect better performance of FSDP2 once we tuned the optimization in torch.compile? (I think the torch.compile optimization for DistributedTensor is not finished now)

torch.compile can still make the compute between FSDP2 communications faster in the same way as FSDP1. (FSDP2 does not use DTensor during forward/backward computation; it wraps DTensors and uses local tensors manually.)

which kinds of torch.compile optimizations are expected? IMO, a bucket can be used to fuse the collectives of each parameter (extra D2D copy will be added), and the fused collective can overlap with computation. Which are the left significant parts for torch.compile to improve?

With the current code, torch.compile graph-breaks on FSDP2 logic (like all-gathering), so, as mentioned above, the optimization will only apply to module compute. @yf225 is working on tracing through the FSDP2 logic itself, in which case we would have the collective ops in the graph, and some kind of re-bucketing could be possible. Though, we may not target that kind of optimization in the near term. Instead, communication reordering might be more interesting than communication fusion.

@guoyejun (Contributor) commented:

Thanks, could you explain a bit what communication reordering is? (Reorder all-gather and reduce-scatter? Reorder something within all-gather/reduce-scatter?)

@awgu (Collaborator) commented Jul 23, 2024

I think the default FSDP ordering of its own all-gather and reduce-scatter in backward is pretty reasonable. I was thinking more along the lines of reordering with collectives from other parts of the model, e.g. TP communications or if you have a sparse embedding part of your model, some embedding all-to-alls.

@ad8e (Contributor) commented Jul 29, 2024

If I call distribute_tensor on a tensor, does rank 0's tensor win, and all the other ranks' local tensors are discarded? Or must the tensors on each rank be equivalent? (I am trying to hack in nn.init.orthogonal_ by creating the tensor on one rank.)

distribute_tensor states, "If [placements] is not specified, we will by default replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh." Does this "from the first rank" still hold if we are sharding rather than replicating?

@wanchaol (Collaborator, Author) commented Aug 1, 2024

If I call distribute_tensor on a tensor, does rank 0's tensor win, and all the other ranks' local tensors are discarded? Or must the tensors on each rank be equivalent? (I am trying to hack in nn.init.orthogonal_ by creating the tensor on one rank.)

distribute_tensor states, "If [placements] is not specified, we will by default replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh." Does this "from the first rank" still hold if we are sharding rather than replicating?

@ad8e Yes, when calling distribute_tensor, rank 0 is by default the source of truth for the "global" unsharded tensor, for both sharding and replication. This is done to fully preserve single-device semantics, which is essential for numerical purposes. The unsharded tensors on the other ranks are discarded, and there's no requirement that the tensors on each rank be equivalent :)
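
A small sketch of the pattern this enables (assuming a 1-D device_mesh already in scope): materialize the full weight on every rank, run the ordinary single-device init, and then shard it; per the semantics above, only rank 0's values survive.

import torch
from torch.distributed._tensor import Shard, distribute_tensor

full_weight = torch.empty(1024, 1024)
torch.nn.init.orthogonal_(full_weight)   # single-device init; rank 0's result becomes the source of truth
sharded_weight = distribute_tensor(full_weight, device_mesh, [Shard(0)])  # other ranks' copies are discarded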

@fhaolinaws commented:

Hi @wanchaol, can you share whether there's a plan or ongoing work to support compilation or lazy tensors? We are looking to better support DTensor with XLA tensors so the existing parallelism utils can be reused, including the FSDP wrappers. We would like to learn if there's an opportunity for collaboration.
