-
Notifications
You must be signed in to change notification settings - Fork 24.7k
[PGNCCL] Launch kernel on current stream & remove record_stream
entirely
#148467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…147820) Summary: PTD current workflow: - PTD creates its own dedicated `ncclStream` for comm operation - it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us). This diff: - async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead - async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready - pass down async from c10d down to NCCL-PG this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%** Test Plan: - # AMD > before ``` [cenzhao@devgpu039.atn3 ~/fbsource/fbcode (2265d32f0)]$ buck2 run @//mode/opt-amd-gpu -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;MSCCL_ALGO_DIR=/data/users/${USER}/fbsource/third-party/rccl/develop/tools/msccl-algorithms;RCCL_MSCCLPP_THRESHOLD=$((128*1024*1024));RCCL_MSCCLPP_ENABLE=1;TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=1;" --size-start-profiler 20M ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu039.atn3.facebook.com/rank-0.Feb_24_16_19_28.354787.pt.trace.json.gz&bucket=hpc_traces {F1975408857} - c10d::allreduce_(69us) - cudaStreamSync (23us) - nccl::all_reduce(26us) > after ``` [cenzhao@devgpu039.atn3 ~/fbsource/fbcode (2265d32f0)]$ buck2 run @//mode/opt-amd-gpu -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;MSCCL_ALGO_DIR=/data/users/${USER}/fbsource/third-party/rccl/develop/tools/msccl-algorithms;RCCL_MSCCLPP_THRESHOLD=$((128*1024*1024));RCCL_MSCCLPP_ENABLE=1;TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=1;TORCH_NCCL_USE_CURRENT_STREAM_AS_NCCL_STREAM=1" --size-start-profiler 20M ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu039.atn3.facebook.com/rank-4.Feb_24_16_22_56.534269.pt.trace.json.gz&bucket=hpc_traces {F1975408962} - c10d:allreduce_(37us) - cudaStreamSync (gone) - nccl::all_reduce(20us) # NV > before ``` [cenzhao@devgpu019.prn3 ~/fbsource/fbcode (e3f64263c)]$ buck2 run @//mode/opt -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;" --size-start-profiler 20M ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu019.prn3.facebook.com/rank-2.Feb_25_11_11_28.3328768.pt.trace.json.gz&bucket=hpc_traces {F1975437097} - c10d::allreduce_ (62us) - cudaStreamWait (0us) - nccl::all_reduce (47us) > after ``` [cenzhao@devgpu019.prn3 ~/fbsource/fbcode (e3f64263c)]$ buck2 run @//mode/opt -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;TORCH_NCCL_USE_CURRENT_STREAM_AS_NCCL_STREAM=1" --size-start-profiler 20M ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu019.prn3.facebook.com/rank-4.Feb_25_11_17_05.3469865.pt.trace.json.gz&bucket=hpc_traces {F1975437192} - c10d::allreduce_ (62us) - cudaStreamWait (gone) - nccl:all_reduce (53us) Differential Revision: D70135605
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148467
Note: Links to docs will display an error until the docs builds have been completed. ❌ 5 New FailuresAs of commit 69042db with merge base 10ffd94 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
} | ||
|
||
// Therefore, we warn and fall back to the typical recordStream logic. | ||
// TODO( kwen2501 ): revisit this when we have a better solution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few options here:
- Change the requirement to be that the user has to call
.wait()
even on isend calls. - Add some logic to the destructor that does the
synchronizeStream
or something equivalent to block the destruction on the p2p op
Moving over to #148590 as internal user needs |
This PR has multiple changes to
ProcessGroupNCCL
(which unfortunately have to be atomic):cudaStreamWaitEvent
in PGNCCL #146881work.wait()
called by distributed_c10d.py.record_stream
and use CPU-side stashing for managing tensor lifetime against recycling.record_stream
in c10d causes FSDP2 to over-allocate GPU memory #147168work.wait()
, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion here.Joint work with @cenzhaometa who wants to remove the event sync overhead.
Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o