Add device check for inputs (#151828) · pytorch/pytorch@5d316ce · GitHub
[go: up one dir, main page]

Skip to content

Commit 5d316ce

Browse files
yushangdi authored and pytorchmergebot committed
Add device check for inputs (#151828)
Summary:
Generate device checks for inputs in AOTI. Enable with AOTI_RUNTIME_CHECK_INPUTS=1.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r test_runtime_checks_device_type_failed
```

Differential Revision: D73382824

Pull Request resolved: #151828
Approved by: https://github.com/angelayi
1 parent 3804aed commit 5d316ce

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3810,6 +3810,32 @@ def forward(self, x):
38103810
with self.assertRaisesRegex(Exception, ""):
38113811
aot_inductor_module(x_casted)
38123812

3813+
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
3814+
def test_runtime_checks_device_type_failed(self):
3815+
if self.device != GPU_TYPE:
3816+
raise unittest.SkipTest("requires GPU")
3817+
3818+
class Model(torch.nn.Module):
3819+
def __init__(self) -> None:
3820+
super().__init__()
3821+
3822+
def forward(self, x):
3823+
return x + 1
3824+
3825+
x = torch.randn(1, 4, dtype=torch.float16, device="cpu")
3826+
model = Model()
3827+
with torch.no_grad():
3828+
package_path: str = AOTIRunnerUtil.compile(
3829+
model,
3830+
(x,),
3831+
)
3832+
3833+
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
3834+
aot_inductor_module(x)
3835+
x_casted = x.to("cuda")
3836+
with self.assertRaisesRegex(Exception, ""):
3837+
aot_inductor_module(x_casted)
3838+
38133839
def test_non_contiguous_output_alias(self):
38143840
# Test return x, x.contiguous() where x is non-contiguous.
38153841
class Model(torch.nn.Module):

torch/_inductor/codegen/cpp_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@
8787
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
8888
}
8989

90+
# matches c10/core/DeviceType.h
91+
DEVICE_TO_INT = {"cpu": 0, "cuda": 1}
92+
9093
_IS_WINDOWS = sys.platform == "win32"
9194

9295
INDEX_TYPE = "int64_t"

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ..virtualized import V
2424
from .aoti_hipify_utils import maybe_hipify_code_wrapper
2525
from .common import get_device_op_overrides, IndentedBuffer, Kernel
26-
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
26+
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP
2727
from .wrapper import (
2828
EnterSubgraphLine,
2929
ExitSubgraphLine,
@@ -322,9 +322,12 @@ def codegen_symbol(
322322
raise AssertionError(f"Unknown value type: {type(value)}")
323323

324324
def generate_input_output_runtime_checks(self):
325-
# In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each
326-
# real input/output tensor match ones provided at compile time via sample
327-
# input/output.
325+
"""
326+
In debug_compile mode, we generate checks to ensure the dtype/shape/stride/device of each
327+
real input/output tensor match ones provided at compile time via sample
328+
input/output.
329+
"""
330+
328331
def gen_check(handle_kind, idx, name, tensor):
329332
# Wrap AtenTensorHandle with ConstantHandle for cleaner utility function access
330333
self.prefix.writeline(
@@ -404,6 +407,27 @@ def gen_check(handle_kind, idx, name, tensor):
404407
"""
405408
)
406409

410+
# check input device type
411+
if isinstance(tensor, ir.TensorBox):
412+
tensor_device = tensor.get_device()
413+
if tensor_device is not None:
414+
expected_device_type = DEVICE_TO_INT.get(tensor_device.type)
415+
if expected_device_type is not None:
416+
self.codegen_input_device_type_var_decl(self.prefix, name)
417+
device_type_str = str(tensor_device.type)
418+
self.prefix.splice(
419+
f"""
420+
int32_t {name}_expected_device_type = {expected_device_type};
421+
if ({name}_expected_device_type != {name}_device_type) {{
422+
std::stringstream ss;
423+
ss << "{handle_kind}[{idx}]: unmatched device type, "
424+
<< "expected: " << {name}_expected_device_type << "{expected_device_type}({device_type_str}), "
425+
<< "but got: " << {name}_device_type << "\\n";
426+
throw std::runtime_error(ss.str());
427+
}}
428+
"""
429+
)
430+
407431
# Create a separate function for each input check to avoid "too big to optimize" error
408432
for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
409433
self.prefix.splice(
@@ -593,6 +617,12 @@ def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
593617
def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
594618
code.writeline(f"auto {name}_stride = {name}.strides();")
595619

620+
def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name):
621+
code.writeline(f"int32_t {name}_device_type;")
622+
code.writeline(
623+
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));"
624+
)
625+
596626
def codegen_model_kernels(self):
597627
self.prefix.writeline("namespace {")
598628

0 commit comments

Comments
 (0)
0