[inductor] parallel compile: Create new pipes for subproc communicati… · pytorch/pytorch@bebb111 · GitHub
[go: up one dir, main page]

Skip to content

Commit bebb111

Browse files
masnesral authored and pytorchbot committed
[inductor] parallel compile: Create new pipes for subproc communication (#131194)
Summary: Rather than using stdin/stdout for IPC, we can create new pipes and pass the descriptors to the subproc via the cmd line. #131070 reports an issue where the combination of deepspeed and onnxruntime-training causes _something_ in the subproc to write to stdout and corrupt the IPC. The current implementation was already brittle; we can just create new pipes specifically for the IPC. Test Plan: I was able to repro the MemoryError in #131070 by installing deepspeed and onnxruntime-training. Verified this PR fixes it. Differential Revision: [D59968362](https://our.internmc.facebook.com/intern/diff/D59968362) Pull Request resolved: #131194 Approved by: https://github.com/malfet, https://github.com/eellison, https://github.com/atalman (cherry picked from commit 3c43fe0)
1 parent 58ab993 commit bebb111

File tree

2 files changed

+22
-35
lines changed

2 files changed

+22
-35
lines changed

torch/_inductor/compile_worker/__main__.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
import argparse
33
import os
44
import sys
5-
import typing
65

76
from torch._inductor.async_compile import pre_fork_setup
8-
from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain
7+
from torch._inductor.compile_worker.subproc_pool import SubprocMain
98
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
109
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path
1110

@@ -20,20 +19,17 @@
2019

2120

2221
def main():
23-
parser = argparse.ArgumentParser()
24-
parser.add_argument("--workers", type=int)
25-
parser.add_argument("--parent", type=int)
26-
args = parser.parse_args()
27-
if os.getppid() != args.parent:
28-
sys.exit(0)
29-
write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb"))
30-
read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb"))
31-
32-
# nobody else should read stdin
33-
sys.stdin.close()
34-
35-
# redirect output of workers to stderr
36-
os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
22+
try:
23+
parser = argparse.ArgumentParser()
24+
parser.add_argument("--workers", type=int)
25+
parser.add_argument("--parent", type=int)
26+
parser.add_argument("--read-fd", type=int)
27+
parser.add_argument("--write-fd", type=int)
28+
args = parser.parse_args()
29+
if os.getppid() != args.parent:
30+
sys.exit(0)
31+
read_fd = os.fdopen(args.read_fd, "rb")
32+
write_fd = os.fdopen(args.write_fd, "wb")
3733

3834
pre_fork_setup()
3935

torch/_inductor/compile_worker/subproc_pool.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,6 @@
1818
log = logging.getLogger(__name__)
1919

2020

21-
class Pipe(typing.Protocol):
22-
def write(self, data: bytes):
23-
...
24-
25-
def read(self, n: int) -> bytes:
26-
...
27-
28-
def close(self):
29-
...
30-
31-
def flush(self):
32-
...
33-
34-
3521
def _pack_msg(job_id, length):
3622
return struct.pack("nn", job_id, length)
3723

@@ -67,16 +53,22 @@ class SubprocPool:
6753

6854
def __init__(self, nprocs: int):
6955
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
56+
57+
subproc_read_fd, write_fd = os.pipe()
58+
read_fd, subproc_write_fd = os.pipe()
59+
self.write_pipe = os.fdopen(write_fd, "wb")
60+
self.read_pipe = os.fdopen(read_fd, "rb")
61+
7062
cmd = [
7163
sys.executable,
7264
entry,
7365
f"--workers={nprocs}",
7466
f"--parent={os.getpid()}",
67+
f"--read-fd={str(subproc_read_fd)}",
68+
f"--write-fd={str(subproc_write_fd)}",
7569
]
7670
self.process = subprocess.Popen(
7771
cmd,
78-
stdin=subprocess.PIPE,
79-
stdout=subprocess.PIPE,
8072
env={
8173
**os.environ,
8274
# We need to set the PYTHONPATH so the subprocess can find torch.
@@ -86,10 +78,9 @@ def __init__(self, nprocs: int):
8678
# creates the SubprocPool in the first place.
8779
"TORCH_WARM_POOL": "0",
8880
},
81+
pass_fds=(subproc_read_fd, subproc_write_fd),
8982
)
90-
self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin)
9183
self.write_lock = threading.Lock()
92-
self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout)
9384
self.read_thread = threading.Thread(target=self._read_thread, daemon=True)
9485

9586
self.futures_lock = threading.Lock()
@@ -160,7 +151,7 @@ def shutdown(self):
160151
class SubprocMain:
161152
"""Communicates with a SubprocPool in the parent process, called by __main__.py"""
162153

163-
def __init__(self, nprocs: int, read_pipe: Pipe, write_pipe: Pipe):
154+
def __init__(self, nprocs, read_pipe, write_pipe):
164155
self.read_pipe = read_pipe
165156
self.write_pipe = write_pipe
166157
self.write_lock = threading.Lock()

0 commit comments

Comments
 (0)
0