- PyTorch/XLA SPMD test code on Google TPU.
- For more details, see the SPMD user guide.
pip install torch~=2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip install -r requirements.txt
# If you use conda, point LD_LIBRARY_PATH at your environment's lib/
export LD_LIBRARY_PATH=/path/to/conda/envs/your_env/lib
# export LD_LIBRARY_PATH=$HOME/anaconda3/envs/easydel/lib/
# make the transformers library use the PyTorch backend on TPU
export USE_TORCH=True
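After installing, a quick sanity check confirms the runtime sees every TPU core; a minimal sketch, assuming the packages above installed cleanly:

```python
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Should print the TPU core count, e.g. 8 on a v3-8/v4-8 host.
print(xr.global_runtime_device_count())
print(xm.xla_device())  # the default XLA device, e.g. xla:0
```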
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # torch_xla 2.1; later releases move this to torch_xla.distributed.spmd
from torch_xla.experimental.xla_sharding import Mesh
from spmd_util import partition_module

xr.use_spmd()  # enable XLA SPMD execution mode

# 1. define mesh (this example uses fsdp)
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
...
# 2. partition module with mesh
partition_module(model, mesh)
...
# 3. partition input with mesh and forward
input_ids = torch.randint(0, 1000, (batch_size, seq_length + 1)).to(xm.xla_device())
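# shard the batch dim along 'dp' (mesh axis 0) and the sequence dim along 'fsdp' (mesh axis 1)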
xs.mark_sharding(input_ids, mesh, (0, 1))
output = model(input_ids=input_ids[:, :-1], labels=input_ids[:, 1:])
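For training, the forward pass above slots into an ordinary optimizer loop. The sketch below is an assumption about how the pieces fit together, not code taken from spmd_gpt.py; the optimizer matches the AdamW configuration listed in the benchmarks below.

```python
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-5)  # lr is illustrative

output = model(input_ids=input_ids[:, :-1], labels=input_ids[:, 1:])
output.loss.backward()
optimizer.step()
optimizer.zero_grad()
xm.mark_step()  # cut the lazy-tensor graph and dispatch the compiled step
```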
The implemented sharding rules live in spmd_util.py; a sketch of what such a rule does follows the table below.
| Model | Implemented |
|---|---|
| GPT-NeoX | ✅ |
| T5 | ✅ |
| Llama | ✅ |
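The sketch below is a guess at the general shape of such a rule, not the actual implementation in spmd_util.py: walk the model's parameters and annotate each 2D weight so it is split along the 'fsdp' mesh axis.

```python
import torch_xla.experimental.xla_sharding as xs

def partition_module_sketch(model, mesh):
    """Hypothetical rule: shard every 2D weight along the 'fsdp' axis."""
    # Assumes the model has already been moved to the XLA device.
    for name, param in model.named_parameters():
        if param.dim() == 2:
            # Tensor dim 0 -> mesh axis 1 ('fsdp'); dim 1 is replicated.
            xs.mark_sharding(param, mesh, (1, None))
```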
- Code: spmd_gpt.py
- Architecture: GPT-NeoX
- batch_size: 1
- Optimizer: AdamW
- Mesh: (1, 8, 1). Changing the mesh shape can cause OOM.
| # params | seq_length | Inference | Trainable / LoRA |
|---|---|---|---|
| 1.4B | 2048 | ✅ | ✅ / ✅ |
| 2.8B | 2048 | ✅ | 🤯 / ✅ |
| 3.8B | 2048 | ✅ | 🤯 / ✅ |
| 5.8B | 2048 | ✅(4) | 🤯 / ✅ |
| 6.9B | 2048 | ✅(2) | 🤯 / ✅(2) |
| 12.8B | 2048 | ✅ | 🤯 / ✅ |

A number in parentheses marks a batch size that ran without OOM. Entries without one used batch size 1; larger batches were simply not tested, not shown to fail.
| # params | seq_length | Inference | Trainable / LoRA |
|---|---|---|---|
| 1.4B | 2048 | ✅ | ✅ / ✅ |
| 2.8B | 2048 | ✅ | ✅ / ✅ |
| 3.8B | 2048 | ✅ | ✅ / ✅ |
| 5.8B | 2048 | ✅ | 🤯 / ✅ |
| 6.9B | 2048 | ✅(4) | 🤯 / ✅(4) |
| 12.8B | 2048 | ✅(2) | 🤯 / ✅(1) |
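The tables track LoRA separately, but the snippets above do not show how adapters are attached. A minimal sketch assuming the Hugging Face peft library (the target module name matches GPT-NeoX's fused attention projection; the rank and alpha values are illustrative, not the ones used for these runs):

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,                                 # illustrative rank
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["query_key_value"],  # GPT-NeoX attention projection
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights stay trainable
```

If LoRA is applied before partition_module, the adapter weights can presumably be covered by the same sharding pass.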
sudo docker run -it --name torch \
  -d --privileged \
  -p 7860:7860 \
  -v `pwd`:/workspace \
  us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20230901 \
  /bin/bash
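Since the container starts detached (`-d`), attach to it with `sudo docker exec -it torch /bin/bash`; the current directory is mounted at /workspace inside the container.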