From bebb1115a868de2abafa081e247726b0517446bc Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 19 Jul 2024 08:35:43 -0700 Subject: [PATCH] [inductor] parallel compile: Create new pipes for subproc communication (#131194) Summary: Rather then using stdin/stdout for IPC, we can create new pipes and pass the descriptors to the subproc via the cmd line. https://github.com/pytorch/pytorch/issues/131070 reports an issue where the combination of deepspeed and onnxruntime-training causes _something_ in the subproc to write to stdout and corrupt the IPC. The current implementation was already brittle; we can just create new pipes specifically for the IPC. Test Plan: I was able to repro the MemoryError in https://github.com/pytorch/pytorch/issues/131070 by installing deepspeed and onnxruntime-training. Verified this PR fixes. Differential Revision: [D59968362](https://our.internmc.facebook.com/intern/diff/D59968362) Pull Request resolved: https://github.com/pytorch/pytorch/pull/131194 Approved by: https://github.com/malfet, https://github.com/eellison, https://github.com/atalman (cherry picked from commit 3c43fe068f4c9d25d110106769bccab94da5f352) --- torch/_inductor/compile_worker/__main__.py | 28 ++++++++---------- .../_inductor/compile_worker/subproc_pool.py | 29 +++++++------------ 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index 7f0965415bbff6..046e1c610d4c46 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -2,10 +2,9 @@ import argparse import os import sys -import typing from torch._inductor.async_compile import pre_fork_setup -from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain +from torch._inductor.compile_worker.subproc_pool import SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path @@ -20,20 +19,17 @@ def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--workers", type=int) - parser.add_argument("--parent", type=int) - args = parser.parse_args() - if os.getppid() != args.parent: - sys.exit(0) - write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb")) - read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb")) - - # nobody else should read stdin - sys.stdin.close() - - # redirect output of workers to stderr - os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) + try: + parser = argparse.ArgumentParser() + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + read_fd = os.fdopen(args.read_fd, "rb") + write_fd = os.fdopen(args.write_fd, "wb") pre_fork_setup() diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 03bfe6c3f203b7..98e35a666aaa07 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -18,20 +18,6 @@ log = logging.getLogger(__name__) -class Pipe(typing.Protocol): - def write(self, data: bytes): - ... - - def read(self, n: int) -> bytes: - ... - - def close(self): - ... - - def flush(self): - ... - - def _pack_msg(job_id, length): return struct.pack("nn", job_id, length) @@ -67,16 +53,22 @@ class SubprocPool: def __init__(self, nprocs: int): entry = os.path.join(os.path.dirname(__file__), "__main__.py") + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + cmd = [ sys.executable, entry, f"--workers={nprocs}", f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", ] self.process = subprocess.Popen( cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, env={ **os.environ, # We need to set the PYTHONPATH so the subprocess can find torch. @@ -86,10 +78,9 @@ def __init__(self, nprocs: int): # creates the SubprocPool in the first place. "TORCH_WARM_POOL": "0", }, + pass_fds=(subproc_read_fd, subproc_write_fd), ) - self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin) self.write_lock = threading.Lock() - self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout) self.read_thread = threading.Thread(target=self._read_thread, daemon=True) self.futures_lock = threading.Lock() @@ -160,7 +151,7 @@ def shutdown(self): class SubprocMain: """Communicates with a SubprocPool in the parent process, called by __main__.py""" - def __init__(self, nprocs: int, read_pipe: Pipe, write_pipe: Pipe): + def __init__(self, nprocs, read_pipe, write_pipe): self.read_pipe = read_pipe self.write_pipe = write_pipe self.write_lock = threading.Lock()