[AOTI][reland] Add an option to specify custom op C shim (#153968) · pytorch/pytorch@72a3c8d
Commit 72a3c8d

desertfire authored and pytorchmergebot committed
[AOTI][reland] Add an option to specify custom op C shim (#153968)
Summary: Reland #153851 after fixing a fuzzer test issue. Add an option that tells AOTInductor codegen to generate C shim functions for certain custom ops instead of relying on ProxyExecutor. The library that defines those custom ops needs to implement the corresponding C shim functions.

Pull Request resolved: #153968
Approved by: https://github.com/hl475
1 parent b7d08de · commit 72a3c8d
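Taken together, the changes below add two aot_inductor config options: custom_ops_to_c_shims (an op overload → C shim declaration dict) and custom_op_libs (libraries implementing those shims). A minimal usage sketch based on the test added in this commit; the Model class, example inputs, and the final export/package calls are illustrative placeholders, not part of the diff:

    import torch
    from torch._inductor import config

    class Model(torch.nn.Module):
        def forward(self, x):
            # Custom op registered by the aoti_custom_ops library
            return torch.ops.aoti_custom_ops.fn_square(x)

    m = Model()
    example_inputs = (torch.randn(2, 3),)

    with config.patch(
        # Tell codegen which C shim declarations exist for this op, so it
        # emits direct calls instead of routing through ProxyExecutor.
        "aot_inductor.custom_ops_to_c_shims",
        {
            torch.ops.aoti_custom_ops.fn_square.default: [
                """
                AOTITorchError
                aoti_torch_cpu_fn_square(
                    AtenTensorHandle input,
                    AtenTensorHandle* ret)""",
            ],
        },
    ), config.patch(
        # Link the library that implements the shim into the generated .so
        "aot_inductor.custom_op_libs",
        ["aoti_custom_ops"],
    ):
        ep = torch.export.export(m, example_inputs)
        torch._inductor.aoti_compile_and_package(ep)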

File tree

7 files changed: +108 -8 lines

  • test/inductor/custom_ops.cpp
  • test/inductor/test_aot_inductor_custom_ops.py
  • torch/_inductor/codegen/cpp_wrapper_cpu.py
  • torch/_inductor/config.py
  • torch/_inductor/cpp_builder.py
  • torch/_inductor/fuzzer.py
  • torch/_inductor/ir.py

test/inductor/custom_ops.cpp

Lines changed: 38 additions & 0 deletions

@@ -1,5 +1,8 @@
 #include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch
 
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_torch/utils.h>
+
 #include <cstdint>
 #include <iostream>
 #include <string>
@@ -310,8 +313,40 @@ void fn_out_variant_without_return_meta(
     Tensor& out) {
 }
 
+Tensor fn_square_impl(const Tensor& tensor) {
+  return tensor * tensor;
+}
+
+Tensor fn_square_meta(const Tensor& tensor) {
+  return at::empty_like(tensor);
+}
 } // namespace at
 
+
+extern "C" {
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_cpu_fn_square(
+    AtenTensorHandle input,
+    AtenTensorHandle* ret) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = at::fn_square_impl(
+        torch::aot_inductor::resolve_tensor_dispatch_flags(input));
+    *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
+  });
+}
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_cuda_fn_square(
+    AtenTensorHandle input,
+    AtenTensorHandle* ret) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = at::fn_square_impl(
+        torch::aot_inductor::resolve_tensor_dispatch_flags(input));
+    *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
+  });
+}
+}
+
 TORCH_LIBRARY(aoti_custom_ops, m) {
   m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
   m.def(
@@ -354,6 +389,7 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
       "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
 
   m.def("fn_out_variant_without_return(Tensor x, Tensor(a!) out) -> ()");
+  m.def("fn_square(Tensor x) -> Tensor");
 }
 
 TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
@@ -365,6 +401,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
   m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
   m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
   m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_impl);
+  m.impl("fn_square", at::fn_square_impl);
 }
 
 TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
@@ -375,4 +412,5 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
   m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
   m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
   m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_meta);
+  m.impl("fn_square", at::fn_square_meta);
 }

test/inductor/test_aot_inductor_custom_ops.py

Lines changed: 33 additions & 0 deletions

@@ -20,6 +20,8 @@
     IS_MACOS,
     IS_SANDCASTLE,
     IS_WINDOWS,
+    skipIfRocm,
+    skipIfXpu,
 )
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
 from torch.testing._internal.triton_utils import HAS_CUDA
@@ -356,6 +358,37 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         self.assertEqual(len(inps), 0)
         self.assertTrue(sentinel_seen)
 
+    @skipIfXpu
+    @skipIfRocm
+    def test_custom_op_square(self) -> None:
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aoti_custom_ops.fn_square(x)
+
+        m = Model().to(device=self.device)
+        args = (torch.randn(2, 3, device=self.device),)
+        with config.patch(
+            "aot_inductor.custom_ops_to_c_shims",
+            {
+                torch.ops.aoti_custom_ops.fn_square.default: [
+                    """
+                    AOTITorchError
+                    aoti_torch_cpu_fn_square(
+                        AtenTensorHandle input,
+                        AtenTensorHandle* ret)""",
+                    """
+                    AOTITorchError
+                    aoti_torch_cuda_fn_square(
+                        AtenTensorHandle input,
+                        AtenTensorHandle* ret)""",
+                ],
+            },
+        ), config.patch(
+            "aot_inductor.custom_op_libs",
+            ["aoti_custom_ops"],
+        ):
+            self.check_model(m, args)
+
 
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 17 additions & 1 deletion

@@ -6,7 +6,7 @@
 import os
 import sys
 import textwrap
-from itertools import count
+from itertools import chain, count
 from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
 
 import sympy
@@ -237,6 +237,22 @@ def write_prefix(self):
         if V.graph.is_const_graph:
             # We do not write prefix for constant graph, it will be written by main module.
             return
+        if config.aot_inductor.custom_ops_to_c_shims:
+            # custom_ops_to_c_shims contains declaration of custom ops with C shim.
+            # TODO: this could be auto-generated from a passed-in custom op schema
+            custom_c_shims = list(
+                chain(*config.aot_inductor.custom_ops_to_c_shims.values())
+            )
+            declarations = "\n".join(
+                [f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims]
+            )
+            self.prefix.splice(
+                f"""
+                extern "C" {{
+                    {declarations}
+                }}
+                """
+            )
         if V.graph.aot_mode:
             self.prefix.writeline("namespace torch::aot_inductor {")
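With the two shim declarations from the test above patched in, the prefix this branch splices into the generated cpp wrapper would look roughly like the following (a sketch of the emitted C++; exact whitespace depends on textwrap.dedent and splice handling):

    extern "C" {
        // Forward declarations of the shims; the implementations live in
        // the custom-op library linked in via custom_op_libs.
        extern AOTITorchError
        aoti_torch_cpu_fn_square(
            AtenTensorHandle input,
            AtenTensorHandle* ret);
        extern AOTITorchError
        aoti_torch_cuda_fn_square(
            AtenTensorHandle input,
            AtenTensorHandle* ret);
    }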

torch/_inductor/config.py

Lines changed: 5 additions & 0 deletions

@@ -1317,6 +1317,11 @@ class aot_inductor:
     # Embed generated .cubin files into the .so
     embed_cubin: bool = False
 
+    # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
+    custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
+    # custom op libs that have implemented C shim wrappers
+    custom_op_libs: Optional[list[str]] = None
+
 
 class cuda:
     """Settings for cuda backend, today this consists of cutlass"""

torch/_inductor/cpp_builder.py

Lines changed: 3 additions & 0 deletions

@@ -1323,6 +1323,9 @@ def get_cpp_torch_device_options(
             # Only add link args, when compile_only is false.
             passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"]
 
+    if config.aot_inductor.custom_op_libs:
+        libraries += config.aot_inductor.custom_op_libs
+
     return (
         definitions,
         include_dirs,

torch/_inductor/fuzzer.py

Lines changed: 8 additions & 6 deletions

@@ -220,9 +220,9 @@ def _generate_value_for_type(
             elem_type = getattr(
                 type_hint,
                 "__args__",
-                [type(default[0])] if len(default) else [type(None)],
+                [type(default[0])] if default and len(default) else [type(None)],
             )[0]
-            new_default = default[0] if len(default) > 0 else None
+            new_default = default[0] if default and len(default) > 0 else None
             return [
                 SamplingMethod._generate_value_for_type(
                     random_sample, field_name, elem_type, new_default
@@ -234,9 +234,9 @@ def _generate_value_for_type(
             elem_type = getattr(
                 type_hint,
                 "__args__",
-                [type(indexable[0])] if len(default) else [type(None)],
+                [type(indexable[0])] if default and len(default) else [type(None)],
             )[0]
-            new_default = indexable[0] if len(default) > 0 else None
+            new_default = indexable[0] if default and len(default) > 0 else None
             return {  # noqa: set_linter
                 SamplingMethod._generate_value_for_type(
                     random_sample, field_name, elem_type, new_default
@@ -248,9 +248,9 @@ def _generate_value_for_type(
             elem_type = getattr(
                 type_hint,
                 "__args__",
-                [type(indexable[0])] if len(default) else [type(None)],
+                [type(indexable[0])] if default and len(default) else [type(None)],
             )[0]
-            new_default = indexable[0] if len(default) > 0 else None
+            new_default = indexable[0] if default and len(default) > 0 else None
             return OrderedSet(
                 [
                     SamplingMethod._generate_value_for_type(
@@ -363,6 +363,8 @@ def dummy_function(*args, **kwargs):  # type: ignore[no-untyped-def]
             )
 
             return dummy_function
+        elif type_hint == torch._ops.OpOverload:
+            return torch.ops.aten.add.default
         elif TypeExemplars.contains(type_hint):
            return TypeExemplars.example(type_hint)
         elif type_hint == Any:
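The first three hunks guard against default being None, which is presumably the fuzzer failure the summary mentions (the new custom_op_libs option defaults to None), and the last hunk gives the fuzzer an exemplar value for OpOverload-typed hints. A minimal standalone illustration of the crash the guard avoids, assuming a None default reaches the old code path:

    default = None  # e.g. the default value of aot_inductor.custom_op_libs

    # Old guard: len(None) raises TypeError
    try:
        elem_types = [type(default[0])] if len(default) else [type(None)]
    except TypeError as e:
        print(e)  # object of type 'NoneType' has no len()

    # New guard: `default and len(default)` short-circuits to the fallback
    elem_types = [type(default[0])] if default and len(default) else [type(None)]
    print(elem_types)  # [<class 'NoneType'>]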

torch/_inductor/ir.py

Lines changed: 4 additions & 1 deletion

@@ -7032,7 +7032,10 @@ def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
             assert isinstance(kernel, torch._ops.OpOverload)
         elif V.graph.cpp_wrapper:
             # For non-aten OpOverload, i.e. custom ops
-            self.use_runtime_dispatch = True
+            # If the op is in custom_ops_to_c_shims, generate direct function call
+            self.use_runtime_dispatch = (
+                kernel not in config.aot_inductor.custom_ops_to_c_shims
+            )
 
         def do_runtime_dispatch() -> None:
             args = None
