torchrun in environments without DNS support · Issue #150532 · pytorch/pytorch · GitHub

torchrun in environments without DNS support #150532

Closed

kuizhiqing opened this issue Apr 2, 2025 · 3 comments · May be fixed by #150533
Labels
module: c10d · oncall: distributed · triaged

Comments

@kuizhiqing commented Apr 2, 2025

I'm managing a cluster with a large number of nodes, where each node's hostname is only resolvable locally on that node.

This causes my torchrun program to hang when using the c10d rendezvous backend:

export PET_NPROC_PER_NODE=8  
export PET_NNODES=2  
export PET_RDZV_ENDPOINT=<MASTER_IP>:36123  
export PET_RDZV_BACKEND=c10d  
torchrun demo.py  

After investigating the issue, I found that the problem originates from the local_addr being retrieved via socket.getfqdn(). The hostname this returns is not resolvable from other nodes, leading to connectivity issues during rendezvous.

Specifically, in torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:

class _NodeDescGenerator:
    def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
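        # local_addr comes from --local-addr; without it we fall back to socket.getfqdn()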
        return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id)

A potential issue also exists in torch/distributed/elastic/rendezvous/api.py

class RendezvousStoreInfo:
    def build(...):
        if rank == 0:
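            # rank 0 publishes this address to the store as MASTER_ADDR for all workers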
            addr = local_addr or socket.getfqdn()
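
To confirm a node is affected, a quick diagnostic (a sketch; run it on the node itself) can check whether the name returned by socket.getfqdn() resolves at all. Note that resolving locally, e.g. via /etc/hosts, is not enough: every other node in the job must be able to resolve the same name for rendezvous to succeed.

import socket

fqdn = socket.getfqdn()
print("socket.getfqdn() ->", fqdn)
try:
    # Resolution here only proves the name is known locally; the same lookup
    # must also succeed on every other node participating in the job.
    print("resolves to", socket.gethostbyname(fqdn))
except socket.gaierror as exc:
    print("not resolvable:", exc)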

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@drisspg added the oncall: distributed label Apr 2, 2025
@mori360 added the module: c10d label Apr 2, 2025
@kuizhiqing (Author)

I want to provide more information on why this case is worth considering.

Normally, a global DNS server is responsible for making hostnames available across the network. However, certain scenarios make this difficult to achieve:

– Managing a global DNS in very large clusters can be inefficient. In our environments, directly using IP addresses is often a more practical and reliable solution.

– Jobs may span multiple clusters, each with different networking mechanisms. For example, Kubernetes clusters may have separate internal DNS setups, making hostname resolution inconsistent across clusters.

@fduwjj added the triaged label Apr 23, 2025
@kuizhiqing changed the title from "torchrun Hangs Due to Unresolvable Hostname in c10d Rendezvous" to "torchrun with Unresolvable Hostname in c10d Rendezvous" on May 12, 2025
@kuizhiqing (Author)

After more investigation, here are workarounds for each affected version that do not require code changes, along with notes on how the rendezvous implementation differs between versions.

v2.6.0 & v2.7.0

Rendezvous implementation

# torch/distributed/elastic/rendezvous/dynamic_rendezvous.py

class DynamicRendezvousHandler(RendezvousHandler):
    def next_rendezvous(self) -> RendezvousInfo:
        # (exception handling elided from this excerpt)
        rank, world_size = self._get_world()
        store = self._get_store()

        if self._bootstrap_store_info is None:
            # To avoid race in get_free_port because we release the port after the call,
            # we want to create a TCPStore server soon afterwards.
            server_port = 0
            if rank == 0:
                self._shared_tcp_store_server = self._create_tcp_store_server(
                    self._this_node.addr, server_port
                )
                server_port = self._shared_tcp_store_server.port
            self._bootstrap_store_info = RendezvousStoreInfo.build(
                rank,
                store,
                local_addr=self._this_node.addr,
                server_port=server_port,  # For non-0 rank, this is a no-op
            )

        return RendezvousInfo(
            store,
            rank,
            world_size,
            self._bootstrap_store_info,  # type: ignore[assignment]
        )
# torch/distributed/elastic/rendezvous/api.py

@dataclass
class RendezvousStoreInfo:
    @staticmethod
    def build(
        rank: int,
        store: Store,
        local_addr: Optional[str],
        server_port: Optional[int] = None,
    ) -> "RendezvousStoreInfo":
        if rank == 0:
            addr = local_addr or socket.getfqdn()
            # When TCPStore is not shared, we fallback to get_free_port.
            port = server_port or get_free_port()
            store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))
            store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8"))

        addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
        port = int(
            store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
        )
        return RendezvousStoreInfo(master_addr=addr, master_port=port)

Solution

The --local-addr option overrides the master address assignment, so we can set it explicitly:

torchrun --nnodes=2 --nproc-per-node=8 --rdzv-endpoint=$RDZV_IP:36123 --rdzv-backend=c10d --local-addr=$LOCAL_IP demo.py

Ensure that RDZV_IP and LOCAL_IP are IP addresses rather than hostnames.
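
If LOCAL_IP is not already known from the environment, one common trick is to ask the kernel which source address would be used to reach the rendezvous endpoint. A minimal sketch, assuming RDZV_IP is exported in the environment (the helper name get_local_ip is illustrative):

import os
import socket

def get_local_ip(rdzv_ip: str, rdzv_port: int = 36123) -> str:
    # connect() on a UDP socket sends no packets; it only selects the
    # outgoing interface, whose address other nodes can reach us on.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect((rdzv_ip, rdzv_port))
        return s.getsockname()[0]

if __name__ == "__main__":
    # e.g. export LOCAL_IP=$(python get_local_ip.py) before calling torchrun
    print(get_local_ip(os.environ["RDZV_IP"]))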

v2.5.1

Rendezvous implementation

# torch/distributed/elastic/rendezvous/dynamic_rendezvous.py

class DynamicRendezvousHandler(RendezvousHandler):
    def next_rendezvous(self) -> RendezvousInfo:
        # (exception handling elided from this excerpt)
        rank, world_size = self._get_world()
        store = self._get_store()

        if self._bootstrap_store_info is None:
            if isinstance(self._store, dist.TCPStore):
                addr = self._store.host
                port = self._store.port
                self._bootstrap_store_info = RendezvousStoreInfo(
                    master_addr=addr, master_port=port
                )
                if rank == 0:
                    self._shared_tcp_store_server = self._store
            else:
                # If the store is not a TCPStore, start a TCPStore server, which requires
                # bootstrapping info across ranks
                self._bootstrap_store_info = RendezvousStoreInfo.build(
                    rank, store, local_addr=self._this_node.addr
                )
                if rank == 0:
                    self._shared_tcp_store_server = self._create_tcp_store_server(
                        self._bootstrap_store_info
                    )

        return RendezvousInfo(
            store,
            rank,
            world_size,
            self._bootstrap_store_info,  # type: ignore[assignment]
        )
# torch/distributed/elastic/rendezvous/api.py

@dataclass
class RendezvousStoreInfo:
    @staticmethod
    def build(
        rank: int, store: Store, local_addr: Optional[str]
    ) -> "RendezvousStoreInfo":
        if rank == 0:
            addr = local_addr or socket.getfqdn()
            port = _get_free_port()
            store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))
            store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8"))

        addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
        port = int(
            store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
        )
        return RendezvousStoreInfo(master_addr=addr, master_port=port)

Since DynamicRendezvousHandler builds the RendezvousStoreInfo directly from the shared TCPStore (without calling build()), all ranks share that store, and there is no issue as long as --rdzv-endpoint is an IP address.

v2.4.1

Rendezvous implementation

# torch/distributed/elastic/rendezvous/dynamic_rendezvous.py

class DynamicRendezvousHandler(RendezvousHandler):
    def next_rendezvous(self) -> RendezvousInfo:
        # (exception handling elided from this excerpt)
        rank, world_size = self._get_world()
        store = self._get_store()

        bootstrap_store_info = RendezvousStoreInfo.build(rank, store)
        return RendezvousInfo(
            store,
            rank,
            world_size,
            bootstrap_store_info,
        )
# torch/distributed/elastic/rendezvous/api.py

@dataclass
class RendezvousStoreInfo:
    @staticmethod
    def build(rank: int, store: Store) -> "RendezvousStoreInfo":
        if rank == 0:
            addr = socket.getfqdn()
            port = _get_free_port()
            store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))
            store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8"))

        addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
        port = int(store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8"))
        return RendezvousStoreInfo(master_addr=addr, master_port=port)

In this version, socket.getfqdn() cannot be avoided through the c10d rendezvous options; however, the following command, which bypasses the c10d rendezvous, can still launch the job:

torchrun --nnodes=2 --nproc-per-node=8 --master-addr=$MASTER_IP --master-port=36123 --node-rank=$RANK demo.py

However, the node rank must then be provided externally (e.g., by the scheduler).
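
A thin launcher can pass that rank through. A minimal sketch, assuming the job environment exposes MASTER_IP and NODE_RANK (both names are illustrative, e.g. derived from a scheduler-provided node index):

import os
import subprocess
import sys

# NODE_RANK and MASTER_IP are assumed to be set by the job environment.
cmd = [
    "torchrun",
    "--nnodes=2",
    "--nproc-per-node=8",
    f"--master-addr={os.environ['MASTER_IP']}",
    "--master-port=36123",
    f"--node-rank={os.environ['NODE_RANK']}",
    "demo.py",
]
sys.exit(subprocess.call(cmd))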

@kuizhiqing (Author)

@fduwjj @mori360 @d4l3k
PTAL. We may close this issue and look forward to an upstream solution for environments without DNS support.

@kuizhiqing changed the title from "torchrun with Unresolvable Hostname in c10d Rendezvous" to "torchrun in environments without DNS support" on May 15, 2025