|
| 1 | +# mypy: allow-untyped-defs |
| 2 | + |
| 3 | +from ..cutlass_utils import try_import_cutlass |
| 4 | + |
| 5 | + |
| 6 | +if try_import_cutlass(): |
| 7 | + import ast |
| 8 | + import textwrap |
| 9 | + |
| 10 | + from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found] |
| 11 | + EpilogueFunctorVisitor, |
| 12 | + ) |
| 13 | + from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found] |
| 14 | + FusionCallbacks, |
| 15 | + ) |
| 16 | + from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found] |
| 17 | + CollectiveEpilogue, |
| 18 | + ) |
| 19 | + from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found] |
| 20 | + PythonASTFrontend, |
| 21 | + ) |
| 22 | + from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found] |
| 23 | + Tensor as CutlassTensor, |
| 24 | + ) |
| 25 | + from cutlass_library import DataType, EpilogueScheduleType, TileDescription |
| 26 | + |
def generate(
    fn_src: str,
    example_tensors: dict[str, CutlassTensor],
    accum_type: DataType,
    output_type: DataType,
    tile_description: TileDescription,
    epilogue_schedule: EpilogueScheduleType,
    **kwargs,
):
    """Emit the source text for the epilogue described by ``fn_src``.

    The epilogue source is traced with ``_trace``, wrapped in an
    ``EpilogueFunctorVisitor``, and the visitor's graph is handed to
    ``FusionCallbacks``. A ``CollectiveEpilogue`` is then built from the
    tile description, epilogue schedule, accumulator/output types and the
    fusion callbacks.

    Args:
        fn_src: Source code of the epilogue function, as a string.
        example_tensors: Example tensors keyed by name, forwarded to the
            tracer.
        accum_type: Accumulator data type passed to ``CollectiveEpilogue``.
        output_type: Output data type passed to ``CollectiveEpilogue``.
        tile_description: Tile description passed to ``CollectiveEpilogue``.
        epilogue_schedule: Epilogue schedule passed to ``CollectiveEpilogue``.
        **kwargs: Forwarded unchanged to ``_trace``.

    Returns:
        The concatenated output of ``FusionCallbacks.emit()``, a blank line,
        then the concatenated output of ``CollectiveEpilogue.emit()``.
    """
    # The visitor and the fusion callbacks are both given the same literal.
    # NOTE(review): presumably the SM90 target, matching the sm90_emitter
    # import -- confirm against the CUTLASS API.
    arch = 90

    functor = _trace(fn_src, example_tensors, **kwargs)
    graph_visitor = EpilogueFunctorVisitor(arch, functor)
    callbacks = FusionCallbacks(graph_visitor.graph, arch)
    epilogue = CollectiveEpilogue(
        tile_description,
        epilogue_schedule,
        accum_type,
        output_type,
        callbacks,
    )

    # Emit in the same order as before: fusion callbacks first, then the
    # collective epilogue, separated by one blank line.
    callbacks_src = "".join(callbacks.emit())
    epilogue_src = "".join(epilogue.emit())
    return callbacks_src + "\n\n" + epilogue_src
| 51 | + |
| 52 | + # Based off of |
| 53 | + # https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117 |
| 54 | + # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function |
| 55 | + # The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval |
def _trace(fn_src, example_tensors, **kwargs):
    """Trace epilogue source code given as a string and return the functor.

    A local ``PythonASTFrontend`` subclass is defined whose ``parse`` step
    reads the dedented ``fn_src`` text rather than the source of a live
    Python function.

    Args:
        fn_src: Source code of the epilogue function; dedented before parsing.
        example_tensors: Example inputs passed to the functor's ``trace``.
        **kwargs: Forwarded to the ``PythonASTFrontend`` constructor.

    Returns:
        The ``EpilogueFunctor`` instance after ``trace`` has been called on it.
    """

    class EpilogueFunctor(PythonASTFrontend):
        def __init__(self, **init_kwargs):
            # Keep the original ordering: the source text is attached to the
            # instance before the base-class initializer runs.
            self.source = textwrap.dedent(fn_src)
            super().__init__(**init_kwargs)

        def parse(self, example_inputs):
            # Record the inputs, build the AST from the stored source text,
            # then walk it with the frontend's visitor.
            self.example_inputs = example_inputs
            parsed = ast.parse(self.source)
            self.ast = parsed
            self.visit(self.ast)

    functor = EpilogueFunctor(**kwargs)
    functor.trace(example_tensors)
    return functor
0 commit comments