8000 [inductor] parallel compile: Create new pipes for subproc communication by pytorchbot · Pull Request #133590 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] parallel compile: Create new pipes for subproc communication #133590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions torch/_inductor/compile_worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import argparse
import os
import sys
import typing

from torch._inductor.async_compile import pre_fork_setup
from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain
from torch._inductor.compile_worker.subproc_pool import SubprocMain
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path

Expand All @@ -20,22 +19,19 @@


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int)
parser.add_argument("--parent", type=int)
args = parser.parse_args()
if os.getppid() != args.parent:
sys.exit(0)
write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb"))
read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb"))

# nobody else should read stdin
sys.stdin.close()

# redirect output of workers to stderr
os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
try:
parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int)
parser.add_argument("--parent", type=int)
parser.add_argument("--read-fd", type=int)
parser.add_argument("--write-fd", type=int)
args = parser.parse_args()
if os.getppid() != args.parent:
sys.exit(0)
read_fd = os.fdopen(args.read_fd, "rb")
write_fd = os.fdopen(args.write_fd, "wb")

pre_fork_setup()

Check failure on line 34 in torch/_inductor/compile_worker/__main__.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [syntax]

invalid syntax; you likely need to run mypy using Python 3.11 or newer

_async_compile_initializer(args.parent)
SubprocMain(args.workers, read_fd, write_fd).main()
Expand Down
29 changes: 10 additions & 19 deletions torch/_inductor/compile_worker/subproc_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,6 @@
log = logging.getLogger(__name__)


class Pipe(typing.Protocol):
def write(self, data: bytes):
...

def read(self, n: int) -> bytes:
...

def close(self):
...

def flush(self):
...


def _pack_msg(job_id, length):
return struct.pack("nn", job_id, length)

Expand Down Expand Up @@ -67,16 +53,22 @@ class SubprocPool:

def __init__(self, nprocs: int):
entry = os.path.join(os.path.dirname(__file__), "__main__.py")

subproc_read_fd, write_fd = os.pipe()
read_fd, subproc_write_fd = os.pipe()
self.write_pipe = os.fdopen(write_fd, "wb")
self.read_pipe = os.fdopen(read_fd, "rb")

cmd = [
sys.executable,
entry,
f"--workers={nprocs}",
f"--parent={os.getpid()}",
f"--read-fd={str(subproc_read_fd)}",
f"--write-fd={str(subproc_write_fd)}",
]
self.process = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
env={
**os.environ,
# We need to set the PYTHONPATH so the subprocess can find torch.
Expand All @@ -86,10 +78,9 @@ def __init__(self, nprocs: int):
# creates the SubprocPool in the first place.
"TORCH_WARM_POOL": "0",
},
pass_fds=(subproc_read_fd, subproc_write_fd),
)
self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin)
self.write_lock = threading.Lock()
self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout)
self.read_thread = threading.Thread(target=self._read_thread, daemon=True)

self.futures_lock = threading.Lock()
Expand Down Expand Up @@ -160,7 +151,7 @@ def shutdown(self):
class SubprocMain:
"""Communicates with a SubprocPool in the parent process, called by __main__.py"""

def __init__(self, nprocs: int, read_pipe: Pipe, write_pipe: Pipe):
def __init__(self, nprocs, read_pipe, write_pipe):
self.read_pipe = read_pipe
self.write_pipe = write_pipe
self.write_lock = threading.Lock()
Expand Down
Loading
0