[inductor] parallel compile: Create new pipes for subproc communication by masnesral · Pull Request #131194 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] parallel compile: Create new pipes for subproc communication #131194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions torch/_inductor/compile_worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import logging
import os
import sys
import typing

from torch._inductor.async_compile import pre_fork_setup
from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain
from torch._inductor.compile_worker.subproc_pool import SubprocMain
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path

Expand All @@ -27,17 +26,13 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int)
parser.add_argument("--parent", type=int)
parser.add_argument("--read-fd", type=int)
parser.add_argument("--write-fd", type=int)
args = parser.parse_args()
if os.getppid() != args.parent:
sys.exit(0)
write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb"))
read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb"))

# nobody else should read stdin
sys.stdin.close()

# redirect output of workers to stderr
os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
read_fd = os.fdopen(args.read_fd, "rb")
write_fd = os.fdopen(args.write_fd, "wb")

pre_fork_setup()

Expand Down
29 changes: 10 additions & 19 deletions torch/_inductor/compile_worker/subproc_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,6 @@
log = logging.getLogger(__name__)


class Pipe(typing.Protocol):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this was really needed. Type checking gives me no errors.

def write(self, data: bytes):
...

def read(self, n: int) -> bytes:
...

def close(self):
...

def flush(self):
...


def _pack_msg(job_id, length):
return struct.pack("nn", job_id, length)

Expand Down Expand Up @@ -103,16 +89,22 @@ class SubprocPool:

def __init__(self, nprocs: int):
entry = os.path.join(os.path.dirname(__file__), "__main__.py")

subproc_read_fd, write_fd = os.pipe()
read_fd, subproc_write_fd = os.pipe()
self.write_pipe = os.fdopen(write_fd, "wb")
self.read_pipe = os.fdopen(read_fd, "rb")

cmd = [
sys.executable,
entry,
f"--workers={nprocs}",
f"--parent={os.getpid()}",
f"--read-fd={str(subproc_read_fd)}",
f"--write-fd={str(subproc_write_fd)}",
]
self.process = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
env={
**os.environ,
# We need to set the PYTHONPATH so the subprocess can find torch.
Expand All @@ -124,10 +116,9 @@ def __init__(self, nprocs: int):
# Some internal usages need a modified LD_LIBRARY_PATH.
"LD_LIBRARY_PATH": _get_ld_library_path(),
},
pass_fds=(subproc_read_fd, subproc_write_fd),
)
self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin)
self.write_lock = threading.Lock()
self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout)
self.read_thread = threading.Thread(target=self._read_thread, daemon=True)

self.futures_lock = threading.Lock()
Expand Down Expand Up @@ -204,7 +195,7 @@ def shutdown(self):
class SubprocMain:
"""Communicates with a SubprocPool in the parent process, called by __main__.py"""

def __init__(self, nprocs: int, read_pipe: Pipe, write_pipe: Pipe):
def __init__(self, nprocs, read_pipe, write_pipe):
self.read_pipe = read_pipe
self.write_pipe = write_pipe
self.write_lock = threading.Lock()
Expand Down
Loading