cpp_wrapper: Move #includes to per-device header files (#143909) · pytorch/pytorch@d62b397 · GitHub
[go: up one dir, main page]

Skip to content

Commit d62b397

Browse files
benjaminglass1 authored and pytorchmergebot committed
cpp_wrapper: Move #includes to per-device header files (#143909)
This prepares us for the next PR in the stack, where we introduce pre-compiled per-device header files to save compilation time. Differential Revision: [D67938955](https://our.internmc.facebook.com/intern/diff/D67938955) Pull Request resolved: #143909 Approved by: https://github.com/desertfire
1 parent 05095a4 commit d62b397

25 files changed

+159
-111
lines changed

.lintrunner.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ exclude_patterns = [
556556
command = [
557557
'python3',
558558
'tools/linter/adapters/grep_linter.py',
559-
'--pattern=#include <pybind11\/',
559+
'--pattern=#include <pybind11\/(^|[^(gil\.h)])',
560560
'--allowlist-pattern=#include <torch\/csrc\/utils\/pybind.h>',
561561
'--linter-name=PYBIND11_INCLUDE',
562562
'--match-first-only',

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,13 +1284,16 @@ def main():
12841284
"include/torch/csrc/distributed/autograd/rpc_messages/*.h",
12851285
"include/torch/csrc/dynamo/*.h",
12861286
"include/torch/csrc/inductor/*.h",
1287+
"include/torch/csrc/inductor/aoti_include/*.h",
12871288
"include/torch/csrc/inductor/aoti_package/*.h",
12881289
"include/torch/csrc/inductor/aoti_runner/*.h",
12891290
"include/torch/csrc/inductor/aoti_runtime/*.h",
12901291
"include/torch/csrc/inductor/aoti_torch/*.h",
12911292
"include/torch/csrc/inductor/aoti_torch/c/*.h",
12921293
"include/torch/csrc/inductor/aoti_torch/generated/*.h",
12931294
"include/torch/csrc/inductor/aoti_torch/generated/extend/*.h",
1295+
"include/torch/csrc/inductor/cpp_wrapper/*.h",
1296+
"include/torch/csrc/inductor/cpp_wrapper/device_internal/*.h",
12941297
"include/torch/csrc/jit/*.h",
12951298
"include/torch/csrc/jit/backends/*.h",
12961299
"include/torch/csrc/jit/generated/*.h",

torch/_inductor/codecache.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,6 @@ def get_code_hash(root: str) -> bytes:
688688
# a hash representing the state of the source code.
689689
extra_files = (
690690
"codegen/aoti_runtime/interface.cpp",
691-
"codegen/aoti_runtime/implementation.cpp",
692691
"codegen/cpp_prefix.h",
693692
"script.ld",
694693
)

torch/_inductor/codegen/common.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,6 @@ def kernel_header(self):
250250
def kernel_driver(self):
251251
raise NotImplementedError
252252

253-
def abi_compatible_header(self):
254-
raise NotImplementedError
255-
256253
def cpp_stream_type(self):
257254
raise NotImplementedError
258255

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.utils._sympy.symbol import symbol_is_type, SymT
1919

2020
from .. import config, ir
21-
from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name
21+
from ..utils import _align, cache_on_self, normalize_name
2222
from ..virtualized import V
2323
from .aoti_hipify_utils import maybe_hipify_code_wrapper
2424
from .common import get_device_op_overrides, IndentedBuffer, Kernel
@@ -126,85 +126,35 @@ def write_constant(self, name, hashed):
126126
# include a hash so our code cache gives different constants different files
127127
self.header.writeline(f"// {name} {hashed}")
128128

129+
def get_device_include(self):
130+
if V.graph.aot_mode:
131+
return f"#include <torch/csrc/inductor/aoti_include/{self.device}.h>"
132+
return f"#include <torch/csrc/inductor/cpp_wrapper/{self.device}.h>"
133+
129134
def write_header(self):
130135
if V.graph.is_const_graph:
131136
# We do not write header for constant graph, it will be written by main module.
132137
return
133138

134-
if V.graph.aot_mode:
135-
self.header.splice(
136-
"""
137-
#include <torch/csrc/inductor/aoti_runtime/interface.h>
138-
#include <torch/csrc/inductor/aoti_runtime/model.h>
139-
"""
140-
)
141-
with open(
142-
os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
143-
) as f:
144-
self.header.splice(f.read())
145-
else:
139+
if not V.graph.aot_mode:
146140
self.header.splice(
147141
"""
148142
import torch
149143
from torch._inductor.codecache import CppWrapperCodeCache
150144
151145
cpp_wrapper_src = (
152146
'''
153-
#include <optional>
154-
#include <Python.h>
155-
156-
#define PYBIND11_SIMPLE_GIL_MANAGEMENT
157-
#include <pybind11/gil.h>
158-
namespace py = pybind11;
159-
160-
class RAIIPyObject {
161-
public:
162-
RAIIPyObject() : obj_(nullptr) {}
163-
RAIIPyObject(PyObject* obj) : obj_(obj) {}
164-
~RAIIPyObject() {
165-
Py_XDECREF(obj_);
166-
}
167-
RAIIPyObject& operator=(const RAIIPyObject& other) {
168-
if (this != &other) {
169-
Py_XDECREF(obj_);
170-
obj_ = other.obj_;
171-
Py_XINCREF(obj_);
172-
}
173-
return *this;
174-
}
175-
operator PyObject*() {
176-
return obj_;
177-
}
178-
PyObject* get() {
179-
return obj_;
180-
}
181-
private:
182-
PyObject* obj_;
183-
};
184-
185-
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
186-
#include <torch/csrc/inductor/aoti_runtime/utils.h>
187-
using namespace torch::aot_inductor;
188147
"""
189148
)
190149

191-
self.header.splice(
192-
f"""
193-
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
194-
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
195-
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
196-
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>
197-
198-
#include <c10/util/generic_math.h>
199-
typedef at::Half half;
200-
typedef at::BFloat16 bfloat16;
201-
202-
// Round up to the nearest multiple of {ALIGN_BYTES}
203-
[[maybe_unused]] static int64_t align(int64_t nbytes) {{
204-
return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
205-
}}
206-
"""
207-
)
150+
self.header.splice(self.get_device_include())
151+
152+
if V.graph.aot_mode:
153+
with open(
154+
os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
155+
) as f:
156+
self.header.splice(f.read())
157+
208158
extend_aoti_c_shim_include = (
209159
f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
210160
)
@@ -1517,8 +1467,10 @@ def create_dtypeview_call(reinterpret_call: str) -> tuple[str, List[str]]:
15171467
return final_tmp_name
15181468

15191469
def codegen_device_copy(self, src, dst, non_blocking: bool):
1470+
"""This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to
1471+
handle cases where dst is not an AtenTensorHandle."""
15201472
self.writeline(
1521-
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
1473+
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));"
15221474
)
15231475

15241476
def codegen_multi_output(self, name, value):

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# mypy: allow-untyped-defs
2-
import os
32
from itertools import count
43
from typing import Callable, Dict, List, Optional
54

@@ -82,18 +81,11 @@ def get_input_cpp_type(input):
8281
return DTYPE_TO_CPP[dtype]
8382
return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
8483

85-
def write_header(self):
86-
if V.graph.is_const_graph:
87-
# We do not write header for constant graph, it will be written by main module.
88-
return
89-
90-
super().write_header()
91-
with open(
92-
os.path.join(
93-
os.path.dirname(__file__), "aoti_runtime", "implementation.cpp"
94-
)
95-
) as f:
96-
self.header.splice(f.read())
84+
def get_device_include(self):
85+
assert self.device == "cpu", "ArrayRef only supported on CPU!"
86+
if V.graph.aot_mode:
87+
return "#include <torch/csrc/inductor/aoti_include/array_ref.h>"
88+
return "#include <torch/csrc/inductor/cpp_wrapper/array_ref.h>"
9789

9890
def codegen_input_numel_asserts(self):
9991
for name, buf in V.graph.graph_inputs.items():

torch/_inductor/codegen/cpp_wrapper_gpu.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,6 @@ def write_header(self):
203203
return
204204

205205
super().write_header()
206-
207-
self.header.splice("#include <filesystem>")
208-
self.header.splice(self.device_codegen.abi_compatible_header())
209206
self.header.splice(
210207
maybe_hipify_code_wrapper(self.device_codegen.kernel_driver())
211208
)

torch/_inductor/codegen/cuda/device_op_overrides.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,6 @@ def tma_descriptor_helpers(self):
225225
#endif
226226
"""
227227

228-
def abi_compatible_header(self):
229-
return "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
230-
231228
def cpp_stream_type(self):
232229
return "cudaStream_t"
233230

torch/_inductor/codegen/debug_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ class DebugPrinterManager:
5353
def __init__(
5454
self,
5555
debug_printer_level,
56+
use_array_ref: bool,
5657
args_to_print_or_save: Optional[List[str]] = None,
5758
kernel_name: str = "",
5859
kernel=None,
59-
arg_signatures: Optional[List[type]] = None,
60-
kernel_type=None,
6160
):
6261
self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
62+
self.use_array_ref = use_array_ref
6363
if args_to_print_or_save is None:
6464
args_to_print_or_save = []
6565
self.args_to_print_or_save = args_to_print_or_save
@@ -155,12 +155,15 @@ def set_printer_args(
155155
]
156156
self.args_to_print_or_save = args_to_print_or_save_extern
157157
elif kernel_type == "cpp":
158-
args_to_print_or_save_cpp = [
159-
f"copy_arrayref_tensor_to_tensor({arg})"
158+
self.args_to_print_or_save = [
159+
(
160+
f"copy_arrayref_tensor_to_tensor({arg})"
161+
if self.use_array_ref
162+
else arg
163+
)
160164
for arg in args_to_print_or_save
161165
if arg.startswith(("buf", "arg"))
162166
]
163-
self.args_to_print_or_save = args_to_print_or_save_cpp
164167
else:
165168
self.args_to_print_or_save = args_to_print_or_save
166169
self.kernel_name = kernel_name

torch/_inductor/codegen/wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,8 @@ def add_import_once(line: str) -> None:
721721

722722
# intermediate tensor value printing utility
723723
self.debug_printer = DebugPrinterManager(
724-
debug_printer_level=config.aot_inductor.debug_intermediate_value_printer
724+
debug_printer_level=config.aot_inductor.debug_intermediate_value_printer,
725+
use_array_ref=config.aot_inductor.allow_stack_allocation,
725726
)
726727

727728
# Additional files that are dependent to the wrapper (ex. cubin files)

0 commit comments

Comments
(0)