`context_parallel` fails for training with `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation` · Issue #149306 · pytorch/pytorch · GitHub

context_parallel fails for training with RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #149306

Closed
ydshieh opened this issue Mar 17, 2025 · 14 comments
Assignees
Labels
module: context parallel PyTorch Context Parallel oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ydshieh
Copy link
ydshieh commented Mar 17, 2025

🐛 Describe the bug

Hi, I am from Hugging Face and we are trying to use context_parallel (using stable and nightly torch). However, for training, it fails with

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I have created a minimal reproducible example where a very simple model DummyModel is defined in the script. The same error occurs for a real model (Qwen 2.5) too.

The same error happens for both SDPBackend.FLASH_ATTENTION and SDPBackend.EFFICIENT_ATTENTION.

To reproduce

Run the following script on a multi-GPU machine (I am using a single cloud machine with 4 A10 GPUs) as

  1. python script.py
  2. torchrun --nproc-per-node=2 script.py --distributed
  3. torchrun --nproc-per-node=2 script.py --distributed --use-cp

where 1. (not using any distributed stuff) and 2. (distributed, without CP) succeed and 3. (distributed with CP) fails.

script.py

import torch
torch.autograd.set_detect_anomaly(True)


class DummyOutput:
    def __init__(self, loss, logits, attn_out):
        self.loss = loss
        self.logits = logits
        self.attn_out = attn_out

    def __str__(self):
        return str({"loss": self.loss, "logits": self.logits, "attn_out": self.attn_out})


class DummyModel(torch.nn.Module):

    def __init__(self, vocab_size, hidden_dim, n_heads, is_causal=True):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads

        self.is_causal = is_causal

        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.hidden_dim)

        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.q = torch.nn.Linear(hidden_dim, hidden_dim)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim)
        self.atnn_out = torch.nn.Linear(hidden_dim, hidden_dim)

        self.proj = torch.nn.Linear(hidden_dim, vocab_size)

    # h being [batch_size, seq_len, hidden_dim]
    # we convert it to q, k, v here
    def forward(self, input_ids, labels=None):

        embeddings = self.embedding(input_ids)
        hidden_states = self.linear(embeddings)

        # we need to change it to q, k, v with [batch_size, n_head, seq_len, head_dim]
        # first, projection to get to [batch_size, seq_len, head_dim]
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)

        batch_size = 1

        # reshape to [batch_size, n_head, seq_len, head_dim]
        q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)

        # back to [batch_size, n_head, seq_len, head_dim]
        # need contiguous for training
        hidden = attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.head_dim)

        atnn_out = self.atnn_out(hidden)
        logits = self.proj(atnn_out)

        loss = None
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels)

        return DummyOutput(loss=loss, logits=logits, attn_out=attn_out)


def check(distributed=False, use_cp=False):

    device = "cuda"
    dtype = torch.bfloat16
    sdpa_backend = SDPBackend.FLASH_ATTENTION

    is_causal = True

    input_ids = torch.randint(low=8, high=64, size=(1, 64), device=device)
    labels = torch.clone(input_ids)

    model = DummyModel(vocab_size=128, hidden_dim=128, n_heads=4, is_causal=is_causal)
    model = model.to(device, dtype=dtype)
    model.eval()

    if distributed:
        dist.broadcast(input_ids, src=0)
        dist.broadcast(labels, src=0)

        rank = torch.distributed.get_node_local_rank()
        model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    model.train()
    for step in range(3):

        model.zero_grad()
        optimizer.zero_grad()

        with sdpa_kernel(sdpa_backend):
            if use_cp:
                with context_parallel(
                    cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)
                ):
                    outputs = model(input_ids, labels=labels)
            else:
                outputs = model(input_ids=input_ids, labels=labels)

        loss = outputs.loss
        print(f"device: {loss.device} | step: {step} | loss = {loss.detach().to('cpu').float().numpy()}")

        loss.backward()
        optimizer.step()


if __name__ == '__main__':

    # python3 script.py
    # torchrun --nproc-per-node=2 script.py --distributed
    # torchrun --nproc-per-node=2 script.py --distributed --use-cp

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", action="store_true", default=False)
    parser.add_argument("--use-cp", action="store_true", default=False)
    parser.add_argument("--nproc-per-node", type=int, default=1)
    args = parser.parse_args()

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    distributed = args.distributed
    use_cp = args.use_cp

    if distributed:
        from torch.distributed.device_mesh import init_device_mesh
        from torch.nn.parallel import DistributedDataParallel as DDP
        import torch.distributed as dist
        if use_cp:
            from torch.distributed.tensor.experimental import context_parallel

        world_size = args.nproc_per_node
        cp_mesh = init_device_mesh("cuda", (world_size,))

    check(distributed=distributed, use_cp=use_cp)

Error log

root@dff7b35823a9:/transformers# torchrun --nproc-per-node=2 script.py --distributed --use-cp
W0317 08:57:27.892000 1659 torch/distributed/run.py:766] 
W0317 08:57:27.892000 1659 torch/distributed/run.py:766] *****************************************
W0317 08:57:27.892000 1659 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0317 08:57:27.892000 1659 torch/distributed/run.py:766] *****************************************
[rank1]: Traceback (most recent call last):
[rank1]:   File "/transformers/script.py", line 149, in <module>
[rank1]:     check(distributed=distributed, use_cp=use_cp)
[rank1]:   File "/transformers/script.py", line 105, in check
[rank1]:     with context_parallel(
[rank1]:   File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
[rank1]:     return next(self.gen)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 36, in generator_context
[rank1]:     response = gen.send(None)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/experimental/_attention.py", line 1345, in context_parallel
[rank1]:     chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/experimental/_attention.py", line 1287, in _context_parallel_buffers
[rank1]:     new_buffers.append(sharder.shard(buffer, mesh, seq_dim))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/experimental/_attention.py", line 1244, in shard
[rank1]:     cp_rank = mesh.get_local_rank()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 946, in get_local_rank
[rank1]:     mesh_dim_group = not_none(self.get_group(mesh_dim))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 781, in get_group
[rank1]:     _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2])  # type: ignore[index]
[rank1]: IndexError: list index out of range
device: cuda:0 | step: 0 | loss = 4.84375
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:824: UserWarning: Error detected in NllLoss2DBackward0. Traceback of forward call that caused the error:
  File "/transformers/script.py", line 149, in <module>
    check(distributed=distributed, use_cp=use_cp)
  File "/transformers/script.py", line 108, in check
    outputs = model(input_ids, labels=labels)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1637, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1464, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/transformers/script.py", line 68, in forward
    loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 3494, in cross_entropy
    return torch._C._nn.cross_entropy_loss(
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: Traceback (most recent call last):
[rank0]:   File "/transformers/script.py", line 149, in <module>
[rank0]:     check(distributed=distributed, use_cp=use_cp)
[rank0]:   File "/transformers/script.py", line 115, in check
[rank0]:     loss.backward()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 648, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 353, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 1, 64]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
[rank0]:[W317 08:57:31.906052155 ProcessGroupNCCL.cpp:1497] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0317 08:57:31.821000 1659 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 1708 closing signal SIGTERM
E0317 08:57:31.985000 1659 torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 1 (pid: 1709) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
script.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-03-17_08:57:31
  host      : dff7b35823a9
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 1709)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Versions

PyTorch version: 2.8.0.dev20250315+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.20.3
Libc version: glibc-2.35

Python version: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.10.234-225.895.amzn2.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G

Nvidia driver version: 550.144.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               48
On-line CPU(s) list:                  0-47
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 7R32
CPU family:                           23
Model:                                49
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            1
Stepping:                             0
BogoMIPS:                             5599.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            768 KiB (24 instances)
L1i cache:                            768 KiB (24 instances)
L2 cache:                             12 MiB (24 instances)
L3 cache:                             96 MiB (6 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-47
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow:   Mitigation; safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.3.0
[pip3] mypy-extensions==1.0.0
[pip3] natten==0.15.1+torch220cu121
[pip3] numpy==1.24.3
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.25.1
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.17.0
[pip3] onnxconverter-common==1.13.0
[pip3] onnxruntime==1.21.0
[pip3] onnxruntime-tools==1.7.0
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] tf2onnx==1.16.1
[pip3] torch==2.8.0.dev20250315+cu126
[pip3] torchaudio==2.6.0.dev20250315+cu126
[pip3] torchvision==0.22.0.dev20250315+cu126
[pip3] triton==3.2.0

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@ydshieh
Copy link
Author
ydshieh commented Mar 17, 2025

cc: @XilunWu , @fegin as I saw you are pinged in some context_parallel issues 🙏

@zou3519 zou3519 added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Mar 17, 2025
@XilunWu XilunWu self-assigned this Mar 17, 2025
@XilunWu XilunWu added the module: context parallel PyTorch Context Parallel label Mar 17, 2025
@fduwjj fduwjj added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 23, 2025
@ydshieh
Copy link
Author
ydshieh commented May 1, 2025

Hi @XilunWu @fegin @fduwjj

While this issue is being looked at by a team member, would you kindly provide a similar minimal example of a training script that demonstrates a working usage of context_parallel 🙏?

This post seems like too much for a minimal example.

The API page doesn't provide an example either 😢

@XilunWu
Copy link
Contributor
XilunWu commented May 1, 2025

@ydshieh

Here is a new tutorial on Context Parallel in PyTorch: https://pytorch.org/tutorials/prototype/context_parallel.html

Besides, the unit test can be a minimal example:

with context_parallel(
    device_mesh, buffers=[cp_q, cp_k, cp_v], buffer_seq_dims=[2, 2, 2]
):
    cp_q.requires_grad = True
    cp_k.requires_grad = True
    cp_v.requires_grad = True
    with CommDebugMode() as comm_mode:
        with sdpa_kernel(backend):
            if compiled:
                fn = torch.compile(
                    F.scaled_dot_product_attention,
                    fullgraph=True,
                    backend="aot_eager",
                )
            else:
                fn = F.scaled_dot_product_attention
            cp_out = fn_eval(fn, cp_q, cp_k, cp_v, is_causal=is_causal)

You can ignore the with CommDebugMode part. Let me know if this is not ideal and I can add an example to torch/distributed/tensor/examples.

@ydshieh
Copy link
Author
ydshieh commented May 1, 2025

Thank you @XilunWu .

A quick look shows that the tutorial and the test don't run loss.backward, which is what I am looking for at this moment. The tutorial links another post about 1M context, which is the post I mentioned in my previous comment.

I can wait, though. I am just wondering if there is some simple example showing training with the CP API.


@XilunWu
Copy link
Contributor
XilunWu commented May 2, 2025

An end-to-end example is torchtitan, but that's a bit complicated and may include many details that you may not be interested in.

Loss.backward is just as simple as calling out.backward().
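
For illustration, a minimal sketch of adding a backward call to the unit-test snippet quoted above (this is not the exact test code; it assumes cp_q, cp_k, cp_v, device_mesh, and backend are prepared as in that test, and uses is_causal=True):

with context_parallel(
    device_mesh, buffers=[cp_q, cp_k, cp_v], buffer_seq_dims=[2, 2, 2]
):
    cp_q.requires_grad = True
    cp_k.requires_grad = True
    cp_v.requires_grad = True
    with sdpa_kernel(backend):
        cp_out = F.scaled_dot_product_attention(cp_q, cp_k, cp_v, is_causal=True)
        # the backward call stays inside the context_parallel context
        cp_out.sum().backward()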

@ydshieh
Copy link
Author
ydshieh commented May 2, 2025

Loss.backward is just as simple as calling out.backward().

Yes, but this is exactly what fails in the minimal code snippet that I provided. At this point, it's not clear if I'm doing something wrong, or if it is indeed a general issue in context_parallel.

As you mentioned, there is torchtitan, so I assume training with context_parallel is possible, but the necessary configuration to make it work is unclear (i.e. how to avoid RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation). As this error only happens (in the provided code snippet) when context_parallel is used, I believe the modeling code is not the source of the issue. But I don't know whether the distributed configuration I specified causes the issue, for example whether using model = DDP(model, device_ids=[rank]) makes sense together with context_parallel.

My bad: in the unit test, there is out.sum().backward(). I will try to run the unit test and add backward to it to see how it goes.

@ydshieh
Copy link
Author
ydshieh commented May 2, 2025

Hi @XilunWu

I incorporated my code snippet into the test test_ring_attention_sdpa, and somehow it works 🤔.

Here is the new code snippet, i.e. I modified RingAttentionTest in test/distributed/tensor/test_attention.py to use the DummyModel class. This works by running python3 test.py.

I have to use rank = int(str(input_ids.device)[-1]) otherwise rank = torch.distributed.get_node_local_rank() gives

RuntimeError: LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow get_node_local_rank to work, assuming you are not running in a multi-device context and want the code to run locally instead.
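
As a side note, the error message itself suggests a fallback. A minimal sketch of two possible alternatives when LOCAL_RANK is not set in the environment, assuming a single-node run where the global rank equals the local rank:

# Sketch only: fallback_rank is the option named in the error message above.
rank = torch.distributed.get_node_local_rank(fallback_rank=0)
# or, on a single node where global rank equals local rank:
rank = torch.distributed.get_rank()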

With and without my_model = DDP(my_model, device_ids=[rank]), the test passes.

So it looks like there are some differences in distributed configuration between the way I launch my previous code snippet and the way the test is run.

test.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental._attention import (
    _AttentionContextParallel,
    _CausalBehavior,
    _cp_options,
    _DispatchMode,
    _is_causal_behavior,
    _RotateMethod,
    context_parallel,
    context_parallel_unshard,
    set_rotate_method,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    ModelArgs,
    Transformer,
    with_comms,
)

from torch.nn.parallel import DistributedDataParallel as DDP

c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
    backends.append(SDPBackend.FLASH_ATTENTION)
# if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
#     backends.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
    backends.append(SDPBackend.CUDNN_ATTENTION)

rotater_enum_to_str = {
    _RotateMethod.ALL_GATHER: "allgather",
    _RotateMethod.ALL_TO_ALL: "alltoall",
}  # mapping from _RotateMethod enum to string


class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


class DummyOutput:
    def __init__(self, loss, logits, attn_out):
        self.loss = loss
        self.logits = logits
        self.attn_out = attn_out

    def __str__(self):
        return str({"loss": self.loss, "logits": self.logits, "attn_out": self.attn_out})


class DummyModel(torch.nn.Module):

    def __init__(self, vocab_size, hidden_dim, n_heads, is_causal=True):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads

        self.is_causal = is_causal

        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.hidden_dim)

        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.q = torch.nn.Linear(hidden_dim, hidden_dim)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim)
        self.atnn_out = torch.nn.Linear(hidden_dim, hidden_dim)

        self.proj = torch.nn.Linear(hidden_dim, vocab_size)

    # h being [batch_size, seq_len, hidden_dim]
    # we convert it to q, k, v here
    def forward(self, input_ids, labels=None):

        embeddings = self.embedding(input_ids)
        hidden_states = self.linear(embeddings)

        # we need to change it to q, k, v with [batch_size, n_head, seq_len, head_dim]
        # first, projection to get to [batch_size, seq_len, head_dim]
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)

        batch_size = 1

        # reshape to [batch_size, n_head, seq_len, head_dim]
        q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)

        # back to [batch_size, n_head, seq_len, head_dim]
        # need contiguous for training
        hidden = attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.head_dim)

        atnn_out = self.atnn_out(hidden)
        logits = self.proj(atnn_out)

        loss = None
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels)

        return DummyOutput(loss=loss, logits=logits, attn_out=attn_out)

class RingAttentionTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @property
    def destroy_pg_upon_exit(self) -> bool:
        return False

    @with_comms
    def test_ring_attention_sdpa(self) -> None:
        self.run_subtests(
            {
                "is_causal": [True],
                "backend": backends,
                "load_balance": [True],
                "dispatch_mode": [
                    _DispatchMode.MONKEY_PATCH,
                    _DispatchMode.TORCH_FUNCTION,
                ],
            },
            self._test_ring_attention_sdpa,
        )

    def _test_ring_attention_sdpa(
        self,
        is_causal: bool,
        backend: SDPBackend,
        load_balance: bool,
        dispatch_mode: _DispatchMode,
    ) -> None:
        torch.distributed.tensor.experimental._attention._dispatch_mode = dispatch_mode

        device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
        torch.manual_seed(10)

        _cp_options.enable_load_balance = load_balance

        input_ids = torch.randint(low=8, high=64, size=(1, 128), device="cuda")
        labels = torch.clone(input_ids).to("cuda")

        dist.broadcast(input_ids, src=0)
        dist.broadcast(labels, src=0)

        my_model = DummyModel(vocab_size=128, hidden_dim=128, n_heads=4, is_causal=is_causal)
        my_model.to(device="cuda", dtype=torch.bfloat16)

        #rank = torch.distributed.get_node_local_rank()
        rank = int(str(input_ids.device)[-1])
        print(rank)
        my_model = DDP(my_model, device_ids=[rank])

        with context_parallel(
            device_mesh, buffers=[input_ids, labels], buffer_seq_dims=[1, 1]
        ):
            with sdpa_kernel(backend):
                outputs = my_model(input_ids, labels=labels)
                outputs.loss.backward()

if __name__ == "__main__":
    run_tests()

@ydshieh
Copy link
Author
ydshieh commented May 2, 2025

I finally found what the issue is:

        with sdpa_kernel(sdpa_backend):
            with context_parallel(cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)):
                outputs = model(input_ids, labels=labels)

        loss = outputs.loss
        loss.backward()

should be

        with sdpa_kernel(sdpa_backend):
            with context_parallel(cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)):
                outputs = model(input_ids, labels=labels)

                loss = outputs.loss
                loss.backward()

i.e. loss.backward has to be within with context_parallel. 😢 😿 🎉 💯

We can probably close the issue now.

For reference: the full working example

import torch


class DummyOutput:
    def __init__(self, loss, logits, attn_out):
        self.loss = loss
        self.logits = logits
        self.attn_out = attn_out

    def __str__(self):
        return str({"loss": self.loss, "logits": self.logits, "attn_out": self.attn_out})


class DummyModel(torch.nn.Module):

    def __init__(self, vocab_size, hidden_dim, n_heads, is_causal=True):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads

        self.is_causal = is_causal

        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.hidden_dim)

        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.q = torch.nn.Linear(hidden_dim, hidden_dim)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim)
        self.atnn_out = torch.nn.Linear(hidden_dim, hidden_dim)

        self.proj = torch.nn.Linear(hidden_dim, vocab_size)

    # h being [batch_size, seq_len, hidden_dim]
    # we convert it to q, k, v here
    def forward(self, input_ids, labels=None):

        embeddings = self.embedding(input_ids)
        hidden_states = self.linear(embeddings)

        # we need to change it to q, k, v with [batch_size, n_head, seq_len, head_dim]
        # first, projection to get to [batch_size, seq_len, head_dim]
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)

        batch_size = 1

        # reshape to [batch_size, n_head, seq_len, head_dim]
        q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)

        # back to [batch_size, n_head, seq_len, head_dim]
        # need contiguous for training
        hidden = attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.head_dim)

        atnn_out = self.atnn_out(hidden)
        logits = self.proj(atnn_out)

        loss = None
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels)

        return DummyOutput(loss=loss, logits=logits, attn_out=attn_out)


def check(distributed=False, use_cp=False):

    device = "cuda"
    dtype = torch.bfloat16
    sdpa_backend = SDPBackend.FLASH_ATTENTION

    is_causal = True

    input_ids = torch.randint(low=8, high=64, size=(1, 64), device=device)
    labels = torch.clone(input_ids)

    model = DummyModel(vocab_size=128, hidden_dim=128, n_heads=4, is_causal=is_causal)
    model = model.to(device, dtype=dtype)
    model.eval()

    if distributed:
        dist.broadcast(input_ids, src=0)
        dist.broadcast(labels, src=0)

        rank = torch.distributed.get_node_local_rank()
        model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(model.parameters(), lr=4e-5)

    model.train()
    for step in range(3):

        model.zero_grad()
        optimizer.zero_grad()

        with sdpa_kernel(sdpa_backend):
            if use_cp:
                with context_parallel(
                    cp_mesh, buffers=(input_ids, labels), buffer_seq_dims=(1, 1)
                ):
                    outputs = model(input_ids, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
            else:
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

        print(f"device: {loss.device} | step: {step} | loss = {loss.detach().to('cpu').float().numpy()}")


if __name__ == '__main__':

    # python3 script.py
    # torchrun --nproc-per-node=2 script.py --distributed
    # torchrun --nproc-per-node=2 script.py --distributed --use-cp

    import os

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", action="store_true", default=False)
    parser.add_argument("--use-cp", action="store_true", default=False)
    parser.add_argument("--nproc-per-node", type=int, default=1)
    args = parser.parse_args()

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    distributed = args.distributed
    use_cp = args.use_cp

    if distributed:
        from torch.distributed.device_mesh import init_device_mesh
        from torch.nn.parallel import DistributedDataParallel as DDP
        import torch.distributed as dist
        if use_cp:
            from torch.distributed.tensor.experimental import context_parallel

        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        cp_mesh = init_device_mesh("cuda", (world_size,))

    check(distributed=distributed, use_cp=use_cp)

@ydshieh
Copy link
Author
ydshieh commented May 7, 2025

After diving into it further, I found the reason.

If no_restore_buffers is set to the set of buffers, the buffers won't be restored on exit, and we can call backward outside the context! This is what I did:

                with context_parallel(
                    cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=set(buffers),
                ):
                    outputs = model(input_ids, **model_kwargs)
                loss = outputs.loss
                loss.backward()

As mentioned in the doc:

If the buffers won't be used after the context exits,
these buffers can be put in this list to avoid extra restore time.

so doing it this way also avoids some potential overhead!

@fegin
Copy link
Contributor
fegin commented May 8, 2025

@ydshieh Any reason why you have to put backward() outside of the context? Without the context, the backward() may not perform correct attention (no correct q, k, v gathering).

@ydshieh
Copy link
Author
ydshieh commented May 8, 2025

Hi @fegin, thank you for mentioning this. I tried to observe the weight differences after one training step, between CP and without CP, and compared these differences with loss.backward/optimizer.step done inside and outside context_parallel. Indeed, the difference is higher when loss.backward/optimizer.step is done outside the context than when done inside it.

But in both cases, the differences remain quite small. I will provide a code snippet.

@ydshieh
Copy link
Author
ydshieh commented May 14, 2025

Confirmed that loss.backward needs to be within with context_parallel, otherwise we get wrong outputs.

script.py
import json
import os
import torch
import torch.distributed as dist

from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributions.utils import logits_to_probs

from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoModelForCausalLM, AutoConfig
from transformers.loss.loss_utils import ForCausalLMLoss


world_size = int(os.environ.get("WORLD_SIZE", "1"))
cp_mesh = init_device_mesh("cuda", (world_size,))
rank = torch.distributed.get_node_local_rank()

device = "cuda"
dtype = torch.float32

sdpa_backend = SDPBackend.EFFICIENT_ATTENTION


# prepare inputs
batch_size = 1
seq_len = 128

ignore_index = -100

# model and optimizer
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

# For loss
config = AutoConfig.from_pretrained(repo_id)
vocab_size = config.vocab_size

# prepare for CP

buffer_seq_dims = (1, 1, 1)
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.

def create_inputs():

    input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)

    # When using CP, we need to use `shift_labels`
    shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
    shift_labels = shift_labels[..., 1:].contiguous()

    position_ids = torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1

    # sync input as they are created randomly
    dist.broadcast(input_ids, src=0)
    dist.broadcast(shift_labels, src=0)
    dist.broadcast(position_ids, src=0)

    cp_buffers = (input_ids, shift_labels, position_ids)

    return cp_buffers

def create_model():

    import gc;
    gc.collect()
    torch._dynamo.reset()
    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
    model.train()
    model.zero_grad()

    if use_cp:
        model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1.0)
    optimizer.zero_grad()

    return model, optimizer


def train(
    model,
    optimizer,
    cp_buffers,
    use_cp=False,
    loss_outside_cp=False,
    backward_outside_cp=False,
):
    buffers = tuple(x.clone() for x in cp_buffers)
    input_ids, shift_labels, position_ids = buffers

    # run with CP
    with sdpa_kernel(sdpa_backend):

        if use_cp:

            no_restore_buffers = set(buffers)

            with context_parallel(
                cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers,
            ):
                outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)

                if not loss_outside_cp:
                    loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)

                if not backward_outside_cp:
                    loss.backward()

        else:

            outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)

        if loss_outside_cp:
            loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)

        if backward_outside_cp:
            loss.backward()

    optimizer.step()

    if use_cp:
        (logits,) = context_parallel_unshard(cp_mesh, [outputs.logits], [1])
    else:
        logits = outputs.logits

    values = {}
    if rank == 0:

        def named_data():
            yield "logits", logits

            for name, param in model.named_parameters():
                if name.startswith("module."):
                    name = name[len("module."):]

                if name in ["model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight", "model.layers.0.self_attn.k_proj.weight", "model.layers.0.self_attn.v_proj.weight"]:
                    yield name, param

        for name, data in named_data():
            value = data.detach().float().to("cpu").numpy()
            values[name] = value

    import gc;
    gc.collect()
    torch._dynamo.reset()
    torch.cuda.empty_cache()

    return values


if __name__ == '__main__':

    # torchrun --nproc-per-node=2 script.py

    # `kernels` sometimes gives large differences (e.g. > 1e-5) across ranks, even in `eval mode + without CP/DDP`.
    os.system("pip uninstall -y kernels")

    options = [
        (False, True, True),  # no cp
        (True, False, False), # cp, loss and backward inside `context_parallel`
        (True, False, True),  # cp, loss inside + backward outside `context_parallel`
        (True, True, True),   # cp, loss and backward outside `context_parallel`
    ]

    cp_buffers = create_inputs()

    # run each configuration
    values = {}
    for option in options:

        (use_cp, loss_outside_cp, backward_outside_cp) = option
        model, optimizer = create_model()

        _values = train(
            model,
            optimizer,
            cp_buffers,
            use_cp=use_cp,
            loss_outside_cp=loss_outside_cp,
            backward_outside_cp=backward_outside_cp,
        )

        if rank == 0:
            values[option] = _values

    if rank == 0:
        diffs_over_config = {}
        for i in range(1):
            for j in range(len(options)):
                if j <= i:
                    continue

                cp, loss_outside_cp, backward_outside_cp = options[i]
                if not cp:
                    option_1 = f"cp={cp}"
                else:
                    option_1 = f"cp={cp}, loss_outside_cp={loss_outside_cp}, backward_outside_cp={backward_outside_cp}"

                cp, loss_outside_cp, backward_outside_cp = options[j]
                option_2 = f"cp={cp}, loss_outside_cp={loss_outside_cp}, backward_outside_cp={backward_outside_cp}"

                option_pair = f"{option_1} | {option_2}"
                diffs_over_config[option_pair] = {}

                for name in values[options[i]]:
                    diff = values[options[i]][name] - values[options[j]][name]
                    import numpy as np
                    max_diff = float(np.amax(np.abs(diff)))

                    diffs_over_config[option_pair][name] = max_diff

        print(json.dumps(diffs_over_config, indent=4))
        with open(f"diff.json", "w") as fp:
            json.dump(diffs_over_config, fp, indent=4)

gives

{
    "cp=False | cp=True, loss_outside_cp=False, backward_outside_cp=False": {
        "logits": 3.719329833984375e-05,
        "model.embed_tokens.weight": 6.1588361859321594e-06,
        "model.layers.0.self_attn.q_proj.weight": 3.5762786865234375e-07,
        "model.layers.0.self_attn.k_proj.weight": 3.8463622331619263e-07,
        "model.layers.0.self_attn.v_proj.weight": 2.9781367629766464e-06
    },
    "cp=False | cp=True, loss_outside_cp=False, backward_outside_cp=True": {
        "logits": 3.719329833984375e-05,
        "model.embed_tokens.weight": 0.0008226484060287476,
        "model.layers.0.self_attn.q_proj.weight": 0.00020362436771392822,
        "model.layers.0.self_attn.k_proj.weight": 9.939692972693592e-05,
        "model.layers.0.self_attn.v_proj.weight": 0.0005658193840645254
    },
    "cp=False | cp=True, loss_outside_cp=True, backward_outside_cp=True": {
        "logits": 3.719329833984375e-05,
        "model.embed_tokens.weight": 0.0008226484060287476,
        "model.layers.0.self_attn.q_proj.weight": 0.00020362436771392822,
        "model.layers.0.self_attn.k_proj.weight": 9.939692972693592e-05,
        "model.layers.0.self_attn.v_proj.weight": 0.0005658193840645254
    }
}

which shows that loss and/or backward computed outside context_parallel gives a larger difference (compared to no context_parallel) than when computed inside context_parallel.

@ydshieh
Copy link
Author
ydshieh commented May 14, 2025

I am going to close this issue as it is resolved: we just need to put loss.backward within context_parallel, and this is necessary to get the correct results.

@ydshieh ydshieh closed this as completed May 14, 2025
@ydshieh
Copy link
Author
ydshieh commented May 14, 2025

If one really wants to do loss.backward outside the context_parallel that wraps model(...), it could be done this way:

with context_parallel(
    cp_mesh, buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=set(buffers),
):
    model(...)

..... (whatever between that might be outside)

# prepare the buffers used to compute the loss, like the `labels`.
# These must not contain any element from the original `buffers`.
new_buffers = ... 
new_buffer_seq_dims = ...

with context_parallel(
    cp_mesh, buffers=new_buffers , buffer_seq_dims=new_buffer_seq_dims ,
):
    loss = ....
