[DCP][Draft] Checkpoint daemon process fixes · pytorch/pytorch@86081b8 · GitHub

Commit 86081b8

meetv18 authored and facebook-github-bot committed
[DCP][Draft] Checkpoint daemon process fixes
Differential Revision: D71336180
1 parent 6c7d841 commit 86081b8

File tree: 1 file changed, +5 −4 lines

torch/distributed/checkpoint/_async_process_executor.py

Lines changed: 5 additions & 4 deletions
@@ -8,6 +8,8 @@
 from typing import Any, Optional, Union
 from uuid import uuid4
 
+import torch
+
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
@@ -55,7 +57,7 @@ class _ProcessGroupInitInfo:
     tcp_store_master_port: int
 
     def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
-        self.local_rank = dist.get_node_local_rank(fallback_rank=0)
+        self.local_rank = dist.get_node_local_rank(fallback_rank=dist.get_rank(process_group)%8)
         self.global_rank = dist.get_rank(process_group)
         self.world_size = dist.get_world_size(process_group)
 
@@ -176,13 +178,12 @@ def _checkpointing_subprocess(
         os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank)
         os.environ["RANK"] = str(pg_init_info.global_rank)
         os.environ["WORLD_SIZE"] = str(pg_init_info.world_size)
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
         logger.info(
             "Initializing dist.ProcessGroup in checkpoint background process"
         )
-        # NOTE: GLOO backend is enforced here.
-        dist.init_process_group(backend=dist.Backend.GLOO)
-        dist.barrier()
+        dist.init_process_group()
 
         logger.info("Checkpoint background process is running...")
         send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
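In summary, the diff makes three changes to the checkpoint daemon (background) process: the local rank now falls back to the global rank modulo 8 (i.e., assuming 8 GPUs per node) instead of 0 when it cannot be derived from the environment, the subprocess pins its CUDA device before initializing the process group, and init_process_group() is called without an explicit backend (dropping the enforced GLOO backend and the trailing barrier) so PyTorch can pick the default backend for the available devices. A minimal sketch of that setup sequence, written as a standalone function rather than the actual _checkpointing_subprocess code, with a hypothetical gpus_per_node parameter in place of the hard-coded 8:

import os

import torch
import torch.distributed as dist


def _setup_checkpoint_daemon(global_rank: int, world_size: int, gpus_per_node: int = 8) -> None:
    # Sketch only: assumes MASTER_ADDR/MASTER_PORT are already set by the
    # parent process (the real code passes TCPStore host/port through
    # _ProcessGroupInitInfo).
    #
    # Fall back to rank % gpus_per_node when LOCAL_RANK is absent, mirroring
    # fallback_rank=dist.get_rank(process_group) % 8 in the diff.
    local_rank = int(os.environ.get("LOCAL_RANK", str(global_rank % gpus_per_node)))

    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["RANK"] = str(global_rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    # Pin the CUDA device before creating the process group so that a
    # device-aware backend (e.g. NCCL) binds to the correct GPU.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # No explicit backend: PyTorch selects the default for the environment
    # instead of the previously enforced GLOO backend.
    dist.init_process_group()

Because the explicit dist.barrier() after initialization is removed, synchronization with the trainer ranks now relies on process group setup itself and the existing queue handshake (send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)).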

0 commit comments
