@@ -8,6 +8,8 @@
 from typing import Any, Optional, Union
 from uuid import uuid4
 
+import torch
+
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
@@ -55,7 +57,7 @@ class _ProcessGroupInitInfo:
     tcp_store_master_port: int
 
     def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
-        self.local_rank = dist.get_node_local_rank(fallback_rank=0)
+        self.local_rank = dist.get_node_local_rank(fallback_rank=dist.get_rank(process_group) % 8)
         self.global_rank = dist.get_rank(process_group)
         self.world_size = dist.get_world_size(process_group)
 
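The changed fallback means that when the LOCAL_RANK environment variable is absent, the local rank is inferred as global rank modulo 8 instead of defaulting to 0, which bakes in an assumption of 8 accelerators per node with ranks assigned contiguously. A minimal sketch of that arithmetic (the helper name and the 8-GPU figure are illustrative assumptions, not part of the diff):

# Sketch: inferring a local rank from the global rank when LOCAL_RANK
# is unset. Assumes homogeneous nodes with contiguously assigned ranks;
# gpus_per_node=8 mirrors the hardcoded `% 8` in the diff.
def infer_local_rank(global_rank: int, gpus_per_node: int = 8) -> int:
    return global_rank % gpus_per_node

assert infer_local_rank(0) == 0    # rank 0  -> first device on node 0
assert infer_local_rank(8) == 0    # rank 8  -> first device on node 1
assert infer_local_rank(11) == 3   # rank 11 -> fourth device on node 1

Compared with the old fallback_rank=0, this presumably keeps each background checkpoint process on a distinct device instead of stacking every rank of a node onto device 0.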
@@ -176,13 +178,12 @@ def _checkpointing_subprocess(
         os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank)
         os.environ["RANK"] = str(pg_init_info.global_rank)
         os.environ["WORLD_SIZE"] = str(pg_init_info.world_size)
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
         logger.info(
             "Initializing dist.ProcessGroup in checkpoint background process"
         )
-        # NOTE: GLOO backend is enforced here.
-        dist.init_process_group(backend=dist.Backend.GLOO)
-        dist.barrier()
+        dist.init_process_group()
 
         logger.info("Checkpoint background process is running...")
         send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
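The subprocess now pins its CUDA device from LOCAL_RANK before creating the process group, and no longer forces the GLOO backend (nor issues the post-init barrier): with no backend argument, dist.init_process_group() lets PyTorch pick the defaults for the environment, NCCL for CUDA tensors plus GLOO for CPU on a GPU machine. A minimal standalone sketch of the same init pattern, assuming torchrun-style environment variables are set:

import os

import torch
import torch.distributed as dist

# Pin this process to its GPU *before* initializing the process group,
# so NCCL communicators bind to the correct device; without this, every
# process on a node can default to cuda:0.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# No explicit backend: PyTorch selects defaults by device availability
# (NCCL for CUDA, GLOO for CPU) rather than being pinned to GLOO.
dist.init_process_group()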