[c10d] ProcessGroupNCCL cuda streams got merged in nightly #153296


Closed
weifengpy opened this issue May 9, 2025 · 4 comments
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

@weifengpy
Contributor
weifengpy commented May 9, 2025

🚀 The feature, motivation and pitch

ProcessGroupNCCL uses independent CUDA streams in PyTorch 2.7.0, but uses merged CUDA streams in the PyTorch nightly build.

(screenshot: profiler traces comparing CUDA stream usage in 2.7.0 vs. nightly)

pytorch 2.7.0: pip3 install torch torchvision torchaudio
pytorch nightly: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
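
A quick way to confirm which build is actually installed (just a sanity check, not part of the repro):

import torch

# A release wheel reports something like "2.7.0", a nightly reports a
# "2.8.0.devYYYYMMDD"-style string; torch.version.cuda shows the CUDA toolkit.
print(torch.__version__, torch.version.cuda)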

# torchrun --nproc-per-node 4 test_reshard_after_forward.py

import contextlib
import os
import pickle

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import init_device_mesh  # for the commented-out FSDP/HSDP mesh setups below


@contextlib.contextmanager
def enable_profiling(enable=False):
    if not enable:
        # Profiling disabled: yield None so the caller can skip profiler calls.
        yield None
    else:
        trace_dir = "./profilers"
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

        def trace_handler(prof):
            curr_trace_dir_name = "iteration_" + str(prof.step_num)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir, exist_ok=True)
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir, exist_ok=True)
        warmup, active = 1, 2
        wait = 1
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=trace_handler,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler

# Record CUDA caching-allocator history so a memory snapshot can be dumped at the end.
torch.cuda.memory._record_memory_history(max_entries=10000000)

class MyModel(nn.Module):
    def __init__(self, num_layer, dim):
        super().__init__()
        self.layers = nn.Sequential(
            *[nn.Linear(dim, dim, bias=False) for _ in range(num_layer)]
        )
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.weight = nn.Parameter(torch.rand(dim, dim), requires_grad=True)
    
    def forward(self, input):
        out = self.layers(input)
        return self.out_proj(out)

def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    # FSDP
    # mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),), mesh_dim_names=("dp_shard",))
    # HSDP
    # mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_replicate", "dp_shard"))
    torch.manual_seed(0)
    num_layer = 40
    batch_size = 1
    dim = 4
    with torch.device("cuda"):
        model = MyModel(num_layer, dim)
    for layer in model.layers:
        fully_shard(layer, reshard_after_forward=2)
    fully_shard(model, reshard_after_forward=2)
    with enable_profiling(True) as prof:
        for _ in range(5):  # enough iterations to cover wait=1 + warmup=1 + active=2
            x = torch.rand((batch_size, dim), device="cuda", requires_grad=True)
            loss = model(x).sum()
            loss.backward()
            prof.step()
    if torch.distributed.get_rank() == 0:
        snapshot = torch.cuda.memory._snapshot()
        with open("test_shardit.pickle", "wb") as f:
            pickle.dump(snapshot, f)
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

Alternatives

No response

Additional context

No response

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@weifengpy
Contributor Author

@kwen2501 for PG and CUDA stream usage in nightly. This is helpful for me to debug fully_shard(reshard_after_forward=int)

@kwen2501
Contributor

This is an intentional change in torch 2.8 :) -- collective kernels are now treated the same as other kernels when async_op=False, by launching them on the "current stream."
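
A minimal sketch of what that semantics means in practice, assuming the NCCL process group is already initialized and the device is set (my own example, not taken from the repro above):

import torch
import torch.distributed as dist

t = torch.ones(1024, device="cuda")
# With async_op=False the collective is launched on torch.cuda.current_stream(),
# so later kernels issued on the same stream are ordered after it by normal
# CUDA stream semantics, with no extra event/stream bookkeeping.
dist.all_reduce(t, async_op=False)
t.mul_(2)  # same stream, so it runs after the all_reduce on the device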

@kwen2501
Contributor

See #150398

Launching on "current stream" gives us some benefits:

  • easier to manage tensor lifetime;
  • flexibility for the user to control which stream a collective runs on (see the sketch after this list);
  • less overhead;
  • easier to see dependencies in the profiler.
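
For instance, the stream-control point can look like the sketch below (my own example, assuming an initialized NCCL process group; comm_stream is just a name chosen here):

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()       # user-chosen side stream
x = torch.randn(1 << 20, device="cuda")

# Because the collective now launches on the current stream, entering a
# stream context is enough to decide where it runs.
with torch.cuda.stream(comm_stream):
    dist.all_reduce(x)  # async_op defaults to False

# Order the default stream after the collective before consuming x.
torch.cuda.current_stream().wait_stream(comm_stream)
y = x.sum()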

@weifengpy
Contributor Author

thanks for explaining this!
