|
23 | 23 | from ..virtualized import V
|
24 | 24 | from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
25 | 25 | from .common import get_device_op_overrides, IndentedBuffer, Kernel
|
26 |
| -from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP |
| 26 | +from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP |
27 | 27 | from .wrapper import (
|
28 | 28 | EnterSubgraphLine,
|
29 | 29 | ExitSubgraphLine,
|
@@ -322,9 +322,12 @@ def codegen_symbol(
|
322 | 322 | raise AssertionError(f"Unknown value type: {type(value)}")
|
323 | 323 |
|
324 | 324 | def generate_input_output_runtime_checks(self):
|
325 |
| - # In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each |
326 |
| - # real input/output tensor match ones provided at compile time via sample |
327 |
| - # input/output. |
| 325 | + """ |
| 326 | + In debug_compile mode, we generate checks to ensure the dtype/shape/stride/device of each |
| 327 | + real input/output tensor match ones provided at compile time via sample |
| 328 | + input/output. |
| 329 | + """ |
| 330 | + |
328 | 331 | def gen_check(handle_kind, idx, name, tensor):
|
329 | 332 | # Wrap AtenTensorHandle with ConstantHandle for cleaner utility function access
|
330 | 333 | self.prefix.writeline(
|
@@ -404,6 +407,27 @@ def gen_check(handle_kind, idx, name, tensor):
|
404 | 407 | """
|
405 | 408 | )
|
406 | 409 |
|
| 410 | + # check input device type |
| 411 | + if isinstance(tensor, ir.TensorBox): |
| 412 | + tensor_device = tensor.get_device() |
| 413 | + if tensor_device is not None: |
| 414 | + expected_device_type = DEVICE_TO_INT.get(tensor_device.type) |
| 415 | + if expected_device_type is not None: |
| 416 | + self.codegen_input_device_type_var_decl(self.prefix, name) |
| 417 | + device_type_str = str(tensor_device.type) |
| 418 | + self.prefix.splice( |
| 419 | + f""" |
| 420 | + int32_t {name}_expected_device_type = {expected_device_type}; |
| 421 | + if ({name}_expected_device_type != {name}_device_type) {{ |
| 422 | + std::stringstream ss; |
| 423 | + ss << "{handle_kind}[{idx}]: unmatched device type, " |
| 424 | + << "expected: " << {name}_expected_device_type << "{expected_device_type}({device_type_str}), " |
| 425 | + << "but got: " << {name}_device_type << "\\n"; |
| 426 | + throw std::runtime_error(ss.str()); |
| 427 | + }} |
| 428 | + """ |
| 429 | + ) |
| 430 | + |
407 | 431 | # Create a separate function for each input check to avoid "too big to optimize" error
|
408 | 432 | for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
|
409 | 433 | self.prefix.splice(
|
@@ -593,6 +617,12 @@ def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
|
def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
    """Emit a C++ declaration caching the strides of tensor `name`.

    Writes `auto <name>_stride = <name>.strides();` into `code`, giving the
    generated checker a local alias for the tensor's stride view.
    """
    stride_decl = f"auto {name}_stride = {name}.strides();"
    code.writeline(stride_decl)
|
595 | 619 |
|
def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name):
    """Emit C++ that queries the runtime device type of tensor `name`.

    Declares an `int32_t <name>_device_type` local and populates it via
    `aoti_torch_get_device_type`, with the call wrapped in
    AOTI_TORCH_ERROR_CODE_CHECK so a failed query aborts the check.
    """
    device_var = f"{name}_device_type"
    emitted = (
        f"int32_t {device_var};",
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{device_var}));",
    )
    for line in emitted:
        code.writeline(line)
| 625 | + |
596 | 626 | def codegen_model_kernels(self):
|
597 | 627 | self.prefix.writeline("namespace {")
|
598 | 628 |
|
|
0 commit comments