[WIP] Dynamo CPU backend under Windows by andreigh · Pull Request #109677 · pytorch/pytorch

Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions test/dynamo/test_repros.py
@@ -39,6 +39,7 @@
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
disable_translation_validation_if_dynamic_shapes,
IS_WINDOWS,
)


@@ -1844,6 +1845,7 @@ def forward(self, x):

self.assertEqual(y, 10)

@unittest.skipIf(IS_WINDOWS, "torch.sort crashes on my Windows box")
def test_sort_out(self):
dtype = torch.float32
device = "cpu"
@@ -1860,6 +1862,7 @@ def fn():
opt_fn = torch._dynamo.optimize("eager")(fn)
opt_fn()

@unittest.skipIf(IS_WINDOWS, "torch.sort crashes on my Windows box")
def test_sort_out2(self):
class MyModule(torch.nn.Module):
def __init__(self):
5 changes: 4 additions & 1 deletion torch/_dynamo/convert_frame.py
@@ -281,7 +281,10 @@ def _convert_frame_assert(
):
return None
if code.co_name == "<genexpr>" and code.co_filename.endswith(
("transformers/file_utils.py", "transformers/utils/generic.py")
(
"transformers/file_utils.py".replace("/", os.sep),
"transformers/utils/generic.py".replace("/", os.sep),
)
):
# not needed, but cleans up torchbench error stats
return None
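On Windows, code.co_filename contains backslash separators, so the original posix-style endswith check could never match; normalizing the literals with os.sep keeps the skip working on both platforms. A minimal standalone sketch of the idea (the helper name is hypothetical, not part of the PR):

import os

def matches_transformers_genexpr(filename: str) -> bool:
    # Normalize the posix-style suffixes to the host separator before
    # comparing, since co_filename uses backslashes on Windows.
    suffixes = tuple(
        s.replace("/", os.sep)
        for s in ("transformers/file_utils.py", "transformers/utils/generic.py")
    )
    return filename.endswith(suffixes)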
9 changes: 6 additions & 3 deletions torch/_dynamo/eval_frame.py
@@ -145,7 +145,10 @@ def _set_current_backend(backend: CompilerFn):
DONT_WRAP_FILES = {
# For tracing into fx modules
inspect.getsourcefile(GraphModule),
join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
join(
dirname(dirname(__file__)),
"onnx/_internal/fx/dynamo_graph_extractor.py".replace("/", os.sep),
),
}


@@ -595,8 +598,8 @@ def __call__(self, fn):


def check_if_dynamo_supported():
if sys.platform == "win32":
raise RuntimeError("Windows not yet supported for torch.compile")
# if sys.platform == "win32":
# raise RuntimeError("Windows not yet supported for torch.compile")
if sys.version_info >= (3, 12):
raise RuntimeError("Python 3.12+ not yet supported for torch.compile")

4 changes: 3 additions & 1 deletion torch/_dynamo/skipfiles.py
@@ -66,7 +66,7 @@ def _module_dir(m: types.ModuleType):
# torch.*
_module_dir(torch),
# torchdynamo.*
os.path.dirname(__file__) + "/",
os.path.dirname(__file__) + os.sep,
"<frozen importlib",
"<__array_function__ internals>",
] + [
@@ -175,6 +175,8 @@ def _module_dir(m: types.ModuleType):
_module_dir(torch) + "distributed/_tensor/device_mesh.py",
}

FILENAME_ALLOWLIST = {file.replace("/", os.sep) for file in FILENAME_ALLOWLIST}

SKIP_DIRS_RE = None
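Normalizing FILENAME_ALLOWLIST is only half of the portability story: any regex later compiled from SKIP_DIRS (see SKIP_DIRS_RE above) has to escape the separator, since the Windows backslash is a regex metacharacter. A hedged sketch of how such a pattern can be built safely (an illustration, not the module's actual recompile logic):

import os
import re

def build_skip_dirs_re(skip_dirs):
    # re.escape is what makes this safe on Windows, where os.sep is a
    # backslash that would otherwise be read as a regex escape.
    return re.compile(f"^({'|'.join(re.escape(d) for d in skip_dirs)})")

pattern = build_skip_dirs_re([os.path.dirname(os.__file__) + os.sep])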

is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode()
3 changes: 1 addition & 2 deletions torch/_dynamo/test_case.py
@@ -5,7 +5,6 @@
import torch
import torch.testing
from torch.testing._internal.common_utils import (
IS_WINDOWS,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
TestCase as TorchTestCase,
@@ -19,7 +18,7 @@ def run_tests(needs=()):

if (
TEST_WITH_TORCHDYNAMO
or IS_WINDOWS
# or IS_WINDOWS
or TEST_WITH_CROSSREF
or sys.version_info >= (3, 12)
):
63 changes: 46 additions & 17 deletions torch/_inductor/codecache.py
@@ -104,7 +104,7 @@ def _compile_end() -> None:
def cache_dir() -> str:
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
if cache_dir is None:
cache_dir = f"{tempfile.gettempdir()}/torchinductor_{getpass.getuser()}"
cache_dir = f"{tempfile.gettempdir()}{os.sep}torchinductor_{getpass.getuser()}"
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
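Using os.sep in the f-string fixes the mixed-separator path; os.path.join would express the same intent without manual separators. An equivalent sketch:

import getpass
import os
import tempfile

def cache_dir() -> str:
    # os.path.join picks the right separator on every platform.
    path = os.environ.get("TORCHINDUCTOR_CACHE_DIR") or os.path.join(
        tempfile.gettempdir(), f"torchinductor_{getpass.getuser()}"
    )
    os.makedirs(path, exist_ok=True)
    return path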

@@ -453,7 +453,10 @@ def cpp_compiler_search(search: str) -> str:
)
with lock:
cxx = install_gcc_via_conda()
subprocess.check_output([cxx, "--version"])
if cxx == "cl":
subprocess.check_output([cxx, "/nologo", "/?"])
else:
subprocess.check_output([cxx, "--version"])
return cxx
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
continue
@@ -565,12 +568,17 @@ def __bool__(self) -> bool:
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_path = input_path[:-3] + "so"
build_cmd = shlex.split(
cpp_compile_command(
input_path, output_path, warning_all=False, vec_isa=self
)
suffix = "so"
if sys.platform == "win32":
suffix = "dll"
output_path = input_path[:-3] + suffix
command = cpp_compile_command(
input_path, output_path, warning_all=False, vec_isa=self
)
if sys.platform == "win32":
build_cmd = shlex.split(command, posix=0)
else:
build_cmd = shlex.split(command)
try:
# Check build result
compile_file(input_path, output_path, build_cmd)
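The posix=0 split matters because shlex.split in POSIX mode treats backslashes as escape characters and silently strips them from Windows paths. A quick illustration:

import shlex

cmd = r'cl /LD C:\tmp\k.cpp /Fe C:\tmp\k.dll'
print(shlex.split(cmd))               # ['cl', '/LD', 'C:tmpk.cpp', '/Fe', 'C:tmpk.dll'] - backslashes eaten
print(shlex.split(cmd, posix=False))  # ['cl', '/LD', 'C:\\tmp\\k.cpp', '/Fe', 'C:\\tmp\\k.dll'] - intact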
@@ -673,6 +681,8 @@ def get_compile_only(compile_only: bool = True) -> str:


def get_shared(shared: bool = True) -> str:
if "cl" in config.cpp.cxx:
return "/LD" if shared else ""
return "-shared -fPIC" if shared else ""


@@ -685,6 +695,8 @@ def get_glibcxx_abi_build_flags() -> str:


def cpp_flags() -> str:
if "cl" in config.cpp.cxx:
return "/std:c++17"
return "-std=c++17 -Wno-unused-variable"


@@ -693,7 +705,10 @@ def cpp_wrapper_flags() -> str:


def optimization_flags() -> str:
base_flags = "-O3 -ffast-math -fno-finite-math-only"
if "cl" in config.cpp.cxx:
base_flags = "/O2 /Ob2 /EHsc /nologo"
else:
base_flags = "-O3 -ffast-math -fno-finite-math-only"
if config.is_fbcode():
# FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
# This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths.
@@ -704,14 +719,14 @@ def optimization_flags() -> str:
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
# Also, `-march=native` is unrecognized option on M1
base_flags += " -Xclang"
else:
elif sys.platform != "win32":
if platform.machine() == "ppc64le":
base_flags += " -mcpu=native"
else:
base_flags += " -march=native"

# Internal cannot find libgomp.so
if not config.is_fbcode():
if not config.is_fbcode() and sys.platform != "win32":
base_flags += " -fopenmp"
return base_flags
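Taken together, the MSVC switches introduced in this file line up with the GCC/Clang flags roughly as follows (a summary of the changes above, not an exhaustive mapping):

GCC/Clang                  MSVC (cl)      purpose
-shared -fPIC              /LD            build a shared library (DLL)
-std=c++17                 /std:c++17     language standard
-O3 -ffast-math ...        /O2 /Ob2       optimization level
(implicit)                 /EHsc          standard C++ exception handling
-o <file>                  /Fe <file>     name of the output binary
-fopenmp                   (omitted)      OpenMP, left disabled on Windows here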

@@ -948,6 +963,9 @@ def cpp_compile_command(
inp_name = input
out_name = output
linker_paths = "" # let the compiler pick
output_param = "-o "
if "cl" in config.cpp.cxx:
output_param = "/Fe"
return re.sub(
r"[ \n]+",
" ",
@@ -961,7 +979,7 @@
{use_fb_internal_macros()}
{use_standard_sys_dir_headers()}
{get_compile_only(compile_only)}
-o {out_name}
{output_param} {out_name}
""",
).strip()
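With the output_param switch, the rendered command differs per platform roughly as follows (illustrative and heavily abbreviated):

Linux:   g++ kernel.cpp -shared -fPIC -std=c++17 -O3 ... -o kernel.so
Windows: cl kernel.cpp /LD /std:c++17 /O2 /Ob2 /EHsc /nologo ... /Fe kernel.dll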

@@ -1061,7 +1079,10 @@ def _to_bytes(t: torch.Tensor) -> bytes:
with open(output_json, "w") as f:
f.write(serialized_extern_kernel_nodes)

output_so = os.path.splitext(input_path)[0] + ".so"
suffix = ".so"
if sys.platform == "win32":
suffix = ".dll"
output_so = os.path.splitext(input_path)[0] + suffix

if not os.path.exists(output_so):
output_o = os.path.splitext(input_path)[0] + ".o"
@@ -1155,6 +1176,7 @@ def cpp_prefix_path() -> str:

def cpp_prefix() -> str:
filename = cpp_prefix_path()
filename = filename.replace("\\", "\\\\")
if config.is_fbcode():
# We need relative paths, since we bundle up
# everything that we compile into a folder for remote compilation.
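The doubled backslashes matter because cpp_prefix() splices this path into generated C++ as an #include line, and escape handling inside a quoted include is implementation-defined, so doubling each separator is the portable choice. A small sketch (the helper name is hypothetical):

def escaped_include(path: str) -> str:
    # Double every backslash so the generated C++ never depends on how
    # the preprocessor treats escapes inside #include "...".
    return '#include "{}"'.format(path.replace("\\", "\\\\"))

escaped_include(r"C:\Users\dev\cache\cpp_prefix.h")
# -> #include "C:\\Users\\dev\\cache\\cpp_prefix.h"

On Linux the path contains no backslashes, so the call is a no-op.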
@@ -1247,13 +1269,18 @@ def load(cls, source_code: str) -> CDLL:
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_path = input_path[:-3] + "so"
suffix = "so"
if sys.platform == "win32":
suffix = "dll"
output_path = input_path[:-3] + suffix
if not os.path.exists(output_path):
cmd = shlex.split(
cpp_compile_command(
input=input_path, output=output_path, vec_isa=picked_vec_isa
)
command = cpp_compile_command(
input=input_path, output=output_path, vec_isa=picked_vec_isa
)
if sys.platform == "win32":
cmd = shlex.split(command, posix=0)
else:
cmd = shlex.split(command)
compile_file(input_path, output_path, cmd)
cls.cache[key] = cls._load_library(output_path)
cls.cache[key].key = key # type: ignore[attr-defined]
@@ -1350,6 +1377,8 @@ def load(cls, source_code: str, func_name: str, key: str, cuda: bool) -> CDLL:
os.makedirs(cpp_wrapper_dir)

ext = "so"
if sys.platform == "win32":
ext = "dll"
filepath = os.path.join(cpp_wrapper_dir, f"{name}.{ext}")
log.debug("Cpp wrapper code path %s", filepath)

7 changes: 6 additions & 1 deletion torch/_inductor/codegen/cpp.py
@@ -2866,7 +2866,12 @@ def codegen_define_and_call(self, wrapper):
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
code.writeline(codecache.cpp_prefix())

code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
if sys.platform == "win32":
code.writeline(
f'extern "C" __declspec(dllexport) void {kernel_decl_name}({arg_defs})'
)
else:
code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
with code.indent():
if enable_kernel_profile:
graph_id = V.graph.graph_id
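The dllexport branch is needed because MSVC exports no symbols from a DLL by default, unlike ELF shared objects where symbols are visible unless explicitly hidden; without it, looking the kernel up after loading would fail. A minimal sketch of the loading side, assuming the compiled kernel is opened via ctypes:

import ctypes
import sys

suffix = "dll" if sys.platform == "win32" else "so"
lib = ctypes.CDLL(f"kernel.{suffix}")
# Resolves only if the symbol was actually exported: __declspec(dllexport)
# on Windows, default visibility elsewhere.
kernel = lib.kernel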
2 changes: 1 addition & 1 deletion torch/_inductor/scheduler.py
@@ -1718,7 +1718,7 @@ def create_backend(self, device: torch.device):

if device.type == "cuda" and not has_triton():
device_props = torch.cuda.get_device_properties(device)
if device_props.major < 7:
if device_props.major < 5:
raise RuntimeError(
f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950
)