context_parallel fails for training with RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #149306
Comments
While this issue is being looked at by a team member, would you kindly provide a similar minimal example of a training script that demonstrates a working usage of `context_parallel`? This post seems too long to serve as a minimal example, and the API page doesn't provide an example either 😢
Here is a new tutorial on Context Parallel in PyTorch: https://pytorch.org/tutorials/prototype/context_parallel.html Besides, the unit test can be a minimal example: pytorch/test/distributed/tensor/test_attention.py Lines 169 to 186 in 0ba9169
You can ignore the …
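For a quick orientation, here is a condensed forward-only sketch in the spirit of that tutorial (my own rough example, not an official one; it assumes 2 GPUs launched via torchrun, and the file name and tensor shapes are illustrative):

# cp_forward_sketch.py -- run with: torchrun --nproc-per-node=2 cp_forward_sketch.py
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.nn.attention import sdpa_kernel, SDPBackend

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = int(os.environ["WORLD_SIZE"])
cp_mesh = init_device_mesh("cuda", (world_size,))

# same seed on every rank so the full (unsharded) q, k, v are identical replicas
torch.manual_seed(0)
# shape [batch, n_heads, seq_len, head_dim]
q, k, v = (torch.randn(1, 4, 128, 32, device="cuda", dtype=torch.bfloat16) for _ in range(3))

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # context_parallel shards the listed buffers in place along seq_len (dim 2)
    # and dispatches scaled_dot_product_attention to ring attention across the mesh
    with context_parallel(cp_mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# the output is still sequence-sharded; gather it back for inspection
(full_out,) = context_parallel_unshard(cp_mesh, [out], [2])
if dist.get_rank() == 0:
    print(full_out.shape)  # torch.Size([1, 4, 128, 32])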
Thank you @XilunWu. A quick look shows that the tutorial and the test don't run loss.backward, which is what I am looking for at this moment. The tutorial links another post about 1M context, which is the post I mentioned in my previous comment. I can wait though. I am just wondering if there is some simple example showing training using the CP API.
An end-to-end example is torchtitan, but that's a bit complicated and may include many details that you may not be interested in. loss.backward is just as simple as calling out.backward().
Yes, but this is exactly what fails in the minimal code snippet that I provided. At this point, it's not clear if I'm doing something wrong, or if it is indeed a general issue in context_parallel. As you mentioned, there is torchtitan. My bad: in the unittest, there is a backward call.
Hi @XilunWu, I incorporated my code snippet into the test. Here is the new code snippet, i.e. the modified test. I have to use …
With and without … So it looks like there are some differences in the way of the …

test.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental._attention import (
_AttentionContextParallel,
_CausalBehavior,
_cp_options,
_DispatchMode,
_is_causal_behavior,
_RotateMethod,
context_parallel,
context_parallel_unshard,
set_rotate_method,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
ModelArgs,
Transformer,
with_comms,
)
from torch.nn.parallel import DistributedDataParallel as DDP
c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
backends.append(SDPBackend.FLASH_ATTENTION)
# if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
# backends.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
backends.append(SDPBackend.CUDNN_ATTENTION)
rotater_enum_to_str = {
_RotateMethod.ALL_GATHER: "allgather",
_RotateMethod.ALL_TO_ALL: "alltoall",
} # mapping from _RotateMethod enum to string
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, q, k, v):
return F.scaled_dot_product_attention(q, k, v, is_causal=True)
class DummyOutput:
def __init__(self, loss, logits, attn_out):
self.loss = loss
self.logits = logits
self.attn_out = attn_out
def __str__(self):
return str({"loss": self.loss, "logits": self.logits, "attn_out": self.attn_out})
class DummyModel(torch.nn.Module):
def __init__(self, vocab_size, hidden_dim, n_heads, is_causal=True):
super().__init__()
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.head_dim = hidden_dim // n_heads
self.is_causal = is_causal
self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.hidden_dim)
self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
self.q = torch.nn.Linear(hidden_dim, hidden_dim)
self.k = torch.nn.Linear(hidden_dim, hidden_dim)
self.v = torch.nn.Linear(hidden_dim, hidden_dim)
self.atnn_out = torch.nn.Linear(hidden_dim, hidden_dim)
self.proj = torch.nn.Linear(hidden_dim, vocab_size)
# h being [batch_size, seq_len, hidden_dim]
# we convert it to q, k, v here
def forward(self, input_ids, labels=None):
embeddings = self.embedding(input_ids)
hidden_states = self.linear(embeddings)
# we need to change it to q, k, v with [batch_size, n_head, seq_len, head_dim]
# first, projection to get to [batch_size, seq_len, head_dim]
q = self.q(hidden_states)
k = self.k(hidden_states)
v = self.v(hidden_states)
batch_size = 1
# reshape to [batch_size, n_head, seq_len, head_dim]
q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
# back to [batch_size, n_head, seq_len, head_dim]
# need contiguous for training
hidden = attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.head_dim)
atnn_out = self.atnn_out(hidden)
logits = self.proj(atnn_out)
loss = None
if labels is not None:
loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels)
return DummyOutput(loss=loss, logits=logits, attn_out=attn_out)
class RingAttentionTest(DTensorTestBase):
@property
def world_size(self) -> int:
return torch.cuda.device_count()
@property
def destroy_pg_upon_exit(self) -> bool:
return False
@with_comms
def test_ring_attention_sdpa(self) -> None:
self.run_subtests(
{
"is_causal": [True],
"backend": backends,
"load_balance": [True],
"dispatch_mode": [
_DispatchMode.MONKEY_PATCH,
_DispatchMode.TORCH_FUNCTION,
],
},
self._test_ring_attention_sdpa,
)
def _test_ring_attention_sdpa(
self,
is_causal: bool,
backend: SDPBackend,
load_balance: bool,
dispatch_mode: _DispatchMode,
) -> None:
torch.distributed.tensor.experimental._attention._dispatch_mode = dispatch_mode
device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
torch.manual_seed(10)
_cp_options.enable_load_balance = load_balance
input_ids = torch.randint(low=8, high=64, size=(1, 128), device="cuda")
labels = torch.clone(input_ids).to("cuda")
dist.broadcast(input_ids, src=0)
dist.broadcast(labels, src=0)
my_model = DummyModel(vocab_size=128, hidden_dim=128, n_heads=4, is_causal=is_causal)
my_model.to(device="cuda", dtype=torch.bfloat16)
#rank = torch.distributed.get_node_local_rank()
rank = int(str(input_ids.device)[-1])
print(rank)
my_model = DDP(my_model, device_ids=[rank])
with context_parallel(
device_mesh, buffers=[input_ids, labels], buffer_seq_dims=[1, 1]
):
with sdpa_kernel(backend):
outputs = my_model(input_ids, labels=labels)
outputs.loss.backward()
if __name__ == "__main__":
run_tests()
I finally found what the issue is:

with sdpa_kernel(sdpa_backend):
    with context_parallel(cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)):
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
    loss.backward()

should be

with sdpa_kernel(sdpa_backend):
    with context_parallel(cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)):
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()

i.e. loss.backward() has to be called inside the context_parallel context manager. We can probably close the issue now. For reference, the full working example:

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.parallel import DistributedDataParallel as DDP
class DummyOutput:
def __init__(self, loss, logits, attn_out):
self.loss = loss
self.logits = logits
self.attn_out = attn_out
def __str__(self):
return str({"loss": self.loss, "logits": self.logits, "attn_out": self.attn_out})
class DummyModel(torch.nn.Module):
def __init__(self, vocab_size, hidden_dim, n_heads, is_causal=True):
super().__init__()
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.head_dim = hidden_dim // n_heads
self.is_causal = is_causal
self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.hidden_dim)
self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
self.q = torch.nn.Linear(hidden_dim, hidden_dim)
self.k = torch.nn.Linear(hidden_dim, hidden_dim)
self.v = torch.nn.Linear(hidden_dim, hidden_dim)
self.atnn_out = torch.nn.Linear(hidden_dim, hidden_dim)
self.proj = torch.nn.Linear(hidden_dim, vocab_size)
# h being [batch_size, seq_len, hidden_dim]
# we convert it to q, k, v here
def forward(self, input_ids, labels=None):
embeddings = self.embedding(input_ids)
hidden_states = self.linear(embeddings)
# we need to change it to q, k, v with [batch_size, n_head, seq_len, head_dim]
# first, projection to get to [batch_size, seq_len, head_dim]
q = self.q(hidden_states)
k = self.k(hidden_states)
v = self.v(hidden_states)
batch_size = 1
# reshape to [batch_size, n_head, seq_len, head_dim]
q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
# back to [batch_size, n_head, seq_len, head_dim]
# need contiguous for training
hidden = attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.head_dim)
atnn_out = self.atnn_out(hidden)
logits = self.proj(atnn_out)
loss = None
if labels is not None:
loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels)
return DummyOutput(loss=loss, logits=logits, attn_out=attn_out)
def check(distributed=False, use_cp=False):
device = "cuda"
dtype = torch.bfloat16
sdpa_backend = SDPBackend.FLASH_ATTENTION
is_causal = True
input_ids = torch.randint(low=8, high=64, size=(1, 64), device=device)
labels = torch.clone(input_ids)
model = DummyModel(vocab_size=128, hidden_dim=128, n_heads=4, is_causal=is_causal)
model = model.to(device, dtype=dtype)
model.eval()
if distributed:
dist.broadcast(input_ids, src=0)
dist.broadcast(labels, src=0)
rank = torch.distributed.get_node_local_rank()
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters(), lr=4e-5)
model.train()
for step in range(3):
model.zero_grad()
optimizer.zero_grad()
with sdpa_kernel(sdpa_backend):
if use_cp:
with context_parallel(
cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)
):
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
else:
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
print(f"device: {loss.device} | step: {step} | loss = {loss.detach().to('cpu').float().numpy()}")
if __name__ == '__main__':
# python3 script.py
# torchrun --nproc-per-node=2 script.py --distributed
# torchrun --nproc-per-node=2 script.py --distributed --use-cp
    import os
    import argparse

    # minimal driver matching the commands in the comments above
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--use-cp", action="store_true")
    args = parser.parse_args()

    if args.distributed:
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
        dist.init_process_group("nccl")
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        cp_mesh = init_device_mesh("cuda", (world_size,))

    check(distributed=args.distributed, use_cp=args.use_cp)
After diving into it further, I found the reason. Setting no_restore_buffers=set(buffers) also allows loss.backward() to be called outside the context_parallel context:

with context_parallel(
    cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=set(buffers),
):
    outputs = model(input_ids, **model_kwargs)

loss = outputs.loss
loss.backward()

This is mentioned in the doc, so doing it this way also avoids some potential overhead!
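To make the failure mode concrete: as I understand it, on exit context_parallel restores the sharded buffers in place, and some of those buffers (e.g. input_ids, which the embedding lookup saves for its backward) are exactly the tensors autograd still needs, hence the error. A tiny standalone illustration of the same class of error (just an analogy with made-up tensors, not the CP code path):

import torch

x = torch.randn(4, requires_grad=True)
w = torch.randn(4)

y = (x * w).sum()        # autograd saves `w` to compute the gradient w.r.t. `x`
w.copy_(torch.zeros(4))  # in-place "restore" after the forward, before backward

# raises: RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation
y.backward()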
@ydshieh Any reason why you have to put loss.backward() outside the context_parallel context?
Hi @fegin Thank you for mentioning this. I tried to observe the weight differences after one training step, between CP and without CP, and to compare these difference values between the different setups. But in both cases, they remain quite small. I will provide a code snippet.
Confirmed:

script.py

import json
import os
import torch
import torch.distributed as dist
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributions.utils import logits_to_probs
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.loss.loss_utils import ForCausalLMLoss
world_size = int(os.environ.get("WORLD_SIZE", "1"))
cp_mesh = init_device_mesh("cuda", (world_size,))
rank = torch.distributed.get_node_local_rank()
device = "cuda"
dtype = torch.float32
sdpa_backend = SDPBackend.EFFICIENT_ATTENTION
# prepare inputs
batch_size = 1
seq_len = 128
ignore_index = -100
# model and optimizer
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
# For loss
config = AutoConfig.from_pretrained(repo_id)
vocab_size = config.vocab_size
# prepare for CP
buffer_seq_dims = (1, 1, 1)
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
def create_inputs():
input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)
# When using CP, we need to use `shift_labels`
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
shift_labels = shift_labels[..., 1:].contiguous()
position_ids = torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
# sync input as they are created randomly
dist.broadcast(input_ids, src=0)
dist.broadcast(shift_labels, src=0)
dist.broadcast(position_ids, src=0)
cp_buffers = (input_ids, shift_labels, position_ids)
return cp_buffers
def create_model():
import gc;
gc.collect()
torch._dynamo.reset()
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
model.train()
model.zero_grad()
if use_cp:
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1.0)
optimizer.zero_grad()
return model, optimizer
def train(
model,
optimizer,
cp_buffers,
use_cp=False,
loss_outside_cp=False,
backward_outside_cp=False,
):
buffers = tuple(x.clone() for x in cp_buffers)
input_ids, shift_labels, position_ids = buffers
# run with CP
with sdpa_kernel(sdpa_backend):
if use_cp:
no_restore_buffers = set(buffers)
with context_parallel(
cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers,
):
outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
if not loss_outside_cp:
loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)
if not backward_outside_cp:
loss.backward()
else:
outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
if loss_outside_cp:
loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)
if backward_outside_cp:
loss.backward()
optimizer.step()
if use_cp:
(logits,) = context_parallel_unshard(cp_mesh, [outputs.logits], [1])
else:
logits = outputs.logits
values = {}
if rank == 0:
def named_data():
yield "logits", logits
for name, param in model.named_parameters():
if name.startswith("module."):
name = name[len("module."):]
if name in ["model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight", "model.layers.0.self_attn.k_proj.weight", "model.layers.0.self_attn.v_proj.weight"]:
yield name, param
for name, data in named_data():
value = data.detach().float().to("cpu").numpy()
values[name] = value
import gc;
gc.collect()
torch._dynamo.reset()
torch.cuda.empty_cache()
return values
if __name__ == '__main__':
# torchrun --nproc-per-node=2 script.py
# `kernels` sometimes gives large differences (e.g. > 1e-5) across ranks, even in `eval mode + without CP/DDP`.
os.system("pip uninstall -y kernels")
options = [
(False, True, True), # no cp
(True, False, False), # cp, loss and backward inside `context_parallel`
(True, False, True), # cp, loss inside + backward outside `context_parallel`
(True, True, True), # cp, loss and backward outside `context_parallel`
]
cp_buffers = create_inputs()
# run each configuration
values = {}
for option in options:
(use_cp, loss_outside_cp, backward_outside_cp) = option
model, optimizer = create_model()
_values = train(
model,
optimizer,
cp_buffers,
use_cp=use_cp,
loss_outside_cp=loss_outside_cp,
backward_outside_cp=backward_outside_cp,
)
if rank == 0:
values[option] = _values
if rank == 0:
diffs_over_config = {}
for i in range(1):
for j in range(len(options)):
if j <= i:
continue
cp, loss_outside_cp, backward_outside_cp = options[i]
if not cp:
option_1 = f"cp={cp}"
else:
option_1 = f"cp={cp}, loss_outside_cp={loss_outside_cp}, backward_outside_cp={backward_outside_cp}"
cp, loss_outside_cp, backward_outside_cp = options[j]
option_2 = f"cp={cp}, loss_outside_cp={loss_outside_cp}, backward_outside_cp={backward_outside_cp}"
option_pair = f"{option_1} | {option_2}"
diffs_over_config[option_pair] = {}
for name in values[options[i]]:
diff = values[options[i]][name] - values[options[j]][name]
import numpy as np
max_diff = float(np.amax(np.abs(diff)))
diffs_over_config[option_pair][name] = max_diff
print(json.dumps(diffs_over_config, indent=4))
with open(f"diff.json", "w") as fp:
json.dump(diffs_over_config, fp, indent=4)

gives

{
"cp=False | cp=True, loss_outside_cp=False, backward_outside_cp=False": {
"logits": 3.719329833984375e-05,
"model.embed_tokens.weight": 6.1588361859321594e-06,
"model.layers.0.self_attn.q_proj.weight": 3.5762786865234375e-07,
"model.layers.0.self_attn.k_proj.weight": 3.8463622331619263e-07,
"model.layers.0.self_attn.v_proj.weight": 2.9781367629766464e-06
},
"cp=False | cp=True, loss_outside_cp=False, backward_outside_cp=True": {
"logits": 3.719329833984375e-05,
"model.embed_tokens.weight": 0.0008226484060287476,
"model.layers.0.self_attn.q_proj.weight": 0.00020362436771392822,
"model.layers.0.self_attn.k_proj.weight": 9.939692972693592e-05,
"model.layers.0.self_attn.v_proj.weight": 0.0005658193840645254
},
"cp=False | cp=True, loss_outside_cp=True, backward_outside_cp=True": {
"logits": 3.719329833984375e-05,
"model.embed_tokens.weight": 0.0008226484060287476,
"model.layers.0.self_attn.q_proj.weight": 0.00020362436771392822,
"model.layers.0.self_attn.k_proj.weight": 9.939692972693592e-05,
"model.layers.0.self_attn.v_proj.weight": 0.0005658193840645254
}
}

which shows the weight differences remain quite small in all cases.
I am going to close this issue as it is resolved; we just need to put loss.backward() inside the context_parallel context manager (or pass no_restore_buffers=set(buffers) when calling it outside).
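To summarize, the two working arrangements look roughly like this (a sketch reusing the names from the snippets above, e.g. model, cp_mesh, buffers, buffer_seq_dims, optimizer, sdpa_backend; not standalone code):

# Option 1: keep the backward inside the context manager.
with sdpa_kernel(sdpa_backend):
    with context_parallel(cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims):
        outputs = model(input_ids, labels=labels)
        outputs.loss.backward()
optimizer.step()

# Option 2: call backward outside, but tell context_parallel not to restore
# the sharded buffers in place when the context exits.
with sdpa_kernel(sdpa_backend):
    with context_parallel(
        cp_mesh,
        buffers=buffers,
        buffer_seq_dims=buffer_seq_dims,
        no_restore_buffers=set(buffers),
    ):
        outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()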
And if one really wants to compute the loss outside the first context_parallel region, something like the following could work:

with context_parallel(
    cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=set(buffers),
):
    model(...)

..... (whatever happens in between may be outside the context)

# prepare the buffers used to compute the loss, like the `labels`.
# These must not contain any element of the original `buffers`.
new_buffers = ...
new_buffer_seq_dims = ...

with context_parallel(
    cp_mesh, buffers=new_buffers, buffer_seq_dims=new_buffer_seq_dims,
):
    loss = ....
🐛 Describe the bug
Hi, I am from Hugging Face and we are trying to use context_parallel (using stable and nightly torch). However, for training, it fails with

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I have created a reproducible minimal example where a very simple model DummyModel is defined in the script. The same error occurs for a real model (Qwen 2.5) too.

The same error happens for both SDPBackend.FLASH_ATTENTION and SDPBackend.EFFICIENT_ATTENTION.

To reproduce
Run the following script on a multi-GPU machine (I am using a single cloud machine with 4 A10 GPUs), as

1. python3 script.py
2. torchrun --nproc-per-node=2 script.py --distributed
3. torchrun --nproc-per-node=2 script.py --distributed --use-cp
where 1. (not using any distributed stuff) and 2. (distributed, without CP) succeed and 3. (distributed with CP) fails.
script.py
Error log
Versions
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o