8000 [Cutlass] Integrate EVT into CUDACPPScheduling by mlazos · Pull Request #150906 · pytorch/pytorch · GitHub

Status: Closed · wants to merge 37 commits
Changes from all commits

Commits (37)
23ef06e  [Cutlass] Integrate EVT into CUDACPPScheduling (mlazos, Apr 9, 2025)
144b406  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 9, 2025)
4202438  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 9, 2025)
19ecaf4  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 10, 2025)
094448c  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 10, 2025)
b0ed2ed  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 10, 2025)
22b9474  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 11, 2025)
91ad3db  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 11, 2025)
8c29ec2  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 11, 2025)
88590c2  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 11, 2025)
5e5f9de  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 11, 2025)
a1eeb5a  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 14, 2025)
9f6dd17  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
2a8a6d4  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
b79a270  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
ca30523  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
69ed263  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
1ce6bf2  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
cdb4ad5  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 15, 2025)
1cd84f8  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 16, 2025)
d0edb48  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 17, 2025)
32c4cc4  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 17, 2025)
c40aab9  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 17, 2025)
c167541  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 17, 2025)
7e05f66  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 17, 2025)
31c7b31  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 18, 2025)
ad22819  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 22, 2025)
b13b2d9  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 25, 2025)
dc8c0e6  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 28, 2025)
98f42f8  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 28, 2025)
9149411  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 28, 2025)
10f931b  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 29, 2025)
fa6a8c5  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, Apr 29, 2025)
b1708be  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, May 1, 2025)
6bb62cf  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, May 2, 2025)
df5cd91  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, May 3, 2025)
f20f30b  Update on "[Cutlass] Integrate EVT into CUDACPPScheduling" (mlazos, May 5, 2025)
151 changes: 146 additions & 5 deletions torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
@@ -8,8 +8,14 @@
 from ...._dynamo.utils import counters
 from ... import config
 from ...codecache import code_hash, get_path
-from ...ir import CUDATemplateBuffer
-from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode
+from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
+from ...scheduler import (
+    BaseSchedulerNode,
+    BaseScheduling,
+    FusedSchedulerNode,
+    SchedulerNode,
+    WhyNoFuse,
+)
 from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
 from ...virtualized import V
 from ..common import BackendFeature, IndentedBuffer
@@ -18,6 +24,12 @@
 log = logging.getLogger(__name__)
 
 
+class WhyNoFuseNames(WhyNoFuse):
+    def __init__(self, name1: str, name2: str) -> None:
+        self.name1 = name1
+        self.name2 = name2
+
+
 class CUDACPPScheduling(BaseScheduling):
     """
     Partial Scheduling implementation for CUDA C++ Kernels.
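Editor's note: the base WhyNoFuse (see the scheduler.py change below) derives names from scheduler nodes via get_name(); the subclass above accepts plain strings, so the epilogue-fusion checks can report against the CUDA template's IR buffer, which is not itself a scheduler node. A minimal usage sketch, with hypothetical buffer names:

    from torch._inductor.codegen.cuda.cuda_cpp_scheduling import WhyNoFuseNames

    # Logs a structured "cannot fuse X with Y: reason" message at debug level
    # via the base class's __call__/__str__.
    why = WhyNoFuseNames("cutlass_gemm_buf", "relu_buf")
    why("size mismatch: %s vs %s", [128, 128], [128, 64])
    # -> "cannot fuse cutlass_gemm_buf with relu_buf: size mismatch: [128, 128] vs [128, 64]"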
@@ -40,9 +52,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
             node.node, CUDATemplateBuffer
         )
 
+    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
+        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)
+
+    def can_fuse_vertical(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+    ) -> bool:
+        if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
+            assert node1.node, "node1.node should not be None"
+            assert node2.node, "node2.node should not be None"
+            return self._can_fuse_epilogue_impl(
+                cast(CUDATemplateBuffer, node1.node),
+                [],
+                node2,  # type: ignore[arg-type]
+            )
+        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
+            node2, SchedulerNode
+        ):
+            assert node1.node, "node1.node should not be None"
+            assert node2.node, "node2.node should not be None"
+            fnode1 = cast(FusedSchedulerNode, node1)
+            return self._can_fuse_epilogue_impl(
+                fnode1.get_template_node(),  # type: ignore[arg-type]
+                self._unwrap_epilogue_nodes(fnode1),
+                node2,  # type: ignore[arg-type]
+            )
+
+        return False
 
     def define_kernel(self, src_code: str, node_schedule) -> str:

Review comment (Contributor), on the elif branch above: Typically we fuse all of the pointwise nodes first, then fuse into templates. So it's more likely that the second node is a FusedSchedulerNode and the cpp template side is still just a bare cuda_cpp_template.
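To make the reviewer's point concrete, here is a toy sketch (hypothetical stand-ins, not this PR's code) of the branch structure above; the comment argues that a third shape, where node2 is itself a FusedSchedulerNode of pointwise ops, is the common case in practice:

    def classify(node1_is_template: bool, node1_is_fused_template: bool) -> str:
        # Mirrors can_fuse_vertical's dispatch: a bare template has no
        # epilogues yet; a fused template group already carries some.
        if node1_is_template:
            return "case 1: bare template, epilogues so far: []"
        if node1_is_fused_template:
            return "case 2: fused template group, unwrap its existing epilogues"
        return "no fusion"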
@@ -94,13 +129,19 @@ def codegen_template(
         _, (_numel, rnumel) = template_node.group
         assert rnumel == 1
         ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
-        kernel, render = ctb.make_kernel_render(ctb)
+        epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes]  # type: ignore[misc]
+        assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
+            "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        )
+        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes)
 
         with kernel:
-            template_node.mark_run()
+            for node in [template_node, *epilogue_nodes]:
+                node.mark_run()
             src_code = render()
 
         with V.set_kernel_handler(kernel):
-            node_schedule = [template_node]
+            node_schedule = [template_node, *epilogue_nodes]
             kernel_name = self.define_kernel(src_code, node_schedule)
 
         # debug printing values of intermediate tensors
@@ -114,3 +155,103 @@

         V.graph.removed_buffers |= kernel.removed_buffers
         self.free_buffers_in_scheduler()
+
+    @staticmethod
+    def _unwrap_epilogue_nodes(
+        fused_node: FusedSchedulerNode,
+    ) -> list[BaseSchedulerNode]:
+        nodes = fused_node.get_nodes()
+        template_node = fused_node.get_template_node()
+        assert all(n.node is not None for n in nodes), (
+            "All epilogue nodes should have an IRNode"
+        )
+        return cast(
+            list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
+        )
+
+    def _can_fuse_epilogue_impl(
+        self,
+        cuda_template_buffer: CUDATemplateBuffer,
+        existing_epilogue_nodes: list[BaseSchedulerNode],
+        node_to_fuse: BaseSchedulerNode,
+    ) -> bool:
+        """
+        Check if the given node can be fused with the epilogue. At the moment, kernels
+        support fusion with Pointwise operations wrapped in (named) ComputedBuffer nodes.
+
+        Args:
+            cuda_template_buffer: A CUDATemplateBuffer object representing the CUDA template and its result buffer.
+            existing_epilogue_nodes: The list of already fused epilogue nodes.
+            node_to_fuse: The SchedulerNode to check for fusibility with the epilogue.
+        Returns:
+            True if the given node can be fused with the epilogue, False otherwise.
+        """
+
+        why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
+
+        ir_node_to_fuse = node_to_fuse.node
+        # for typing
+        assert ir_node_to_fuse
+
+        assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
+        if not isinstance(ir_node_to_fuse, ComputedBuffer):
+            return False
+        if not isinstance(ir_node_to_fuse.data, Pointwise):
+            return False
+        # We can fuse a Pointwise op that depends on the last fused epilogue node,
+        # if any. If there is no epilogue node yet, it needs to depend on the
+        # template node.
+        node_name = ir_node_to_fuse.get_computed_buffer_name()  # type: ignore[attr-defined]
+        if node_name is None:
+            return False
+
+        assert (
+            len(existing_epilogue_nodes)
+            or cuda_template_buffer.get_name() in ir_node_to_fuse.get_read_names()
+        ), "First epilogue node must read from cuda template buffer"
+
+        # dtype can differ, and strides can differ as long as they are broadcastable
+        if ir_node_to_fuse.get_size() != cuda_template_buffer.get_size():
+            why(
+                f"{cuda_template_buffer.get_name()}'s size: {cuda_template_buffer.get_size()} \
+differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
+            )
+            return False
+        elif node_to_fuse.has_aliasing_or_mutation():
+            why(f"{node_name} has aliasing or mutation")
+            return False
+        elif node_to_fuse.is_reduction():
+            why(f"{node_name} is a reduction, which is not yet supported by EVT")
+            return False
+        elif not config.epilogue_fusion:
+            why("epilogue fusion is not enabled")
+            return False
+
+        try:
+            from torch._inductor.codegen.cuda.cutlass_python_evt import (
+                CutlassEVTCodegen,
+            )
+
+            CutlassEVTCodegen.ir_to_evt_python_code(
+                cuda_template_buffer.get_name(),
+                existing_epilogue_nodes + [node_to_fuse],
+            )
+        except NotImplementedError as e:
+            not_implemented_op = str(e)
+            if not_implemented_op.startswith("_op_"):
+                not_implemented_op = not_implemented_op[4:]
+                why(
+                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
+likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
+                )
+                return False
+            else:  # Likely due to unsupported dtype.
+                why(
+                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
+Reason: {not_implemented_op}"  # noqa: G004, B950
+                )
+                return False
+
+        return True
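For orientation, a hedged sketch of how this path is exercised end to end: when Inductor autotunes a GEMM to a Cutlass template and a pointwise op consumes its output, that op becomes a candidate for _can_fuse_epilogue_impl. The config keys below are real Inductor settings; the example itself is illustrative (not from this PR) and requires a CUTLASS-capable CUDA build:

    import torch

    @torch.compile(
        options={
            "max_autotune": True,
            "max_autotune_gemm_backends": "CUTLASS",
        }
    )
    def mm_relu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # relu lowers to a Pointwise wrapped in a ComputedBuffer, the only
        # node kind the fusion check above accepts.
        return torch.relu(a @ b)

    x = torch.randn(128, 64, device="cuda", dtype=torch.float16)
    y = torch.randn(64, 128, device="cuda", dtype=torch.float16)
    out = mm_relu(x, y)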
8 changes: 7 additions & 1 deletion torch/_inductor/codegen/cuda/cuda_template.py
@@ -2,7 +2,7 @@
 import functools
 import itertools
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Optional, TYPE_CHECKING
 from typing_extensions import override
 from unittest.mock import patch
 
@@ -19,6 +19,12 @@
 from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
 
 
+if TYPE_CHECKING:
+    from ...scheduler import BaseSchedulerNode  # noqa: TC004
+else:
+    BaseSchedulerNode = Any
+
+
 autotuning_log = getArtifactLogger(__name__, "autotuning")
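The TYPE_CHECKING/else: Any pattern above gives static checkers the real BaseSchedulerNode type while avoiding a runtime import (and any import cycle). A generic sketch of the idiom, with a hypothetical module name:

    from typing import Any, TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by type checkers; never imported at runtime.
        from heavy_module import RealNode  # hypothetical
    else:
        RealNode = Any  # annotations still resolve at runtime

    def visit(node: RealNode) -> None:
        ...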
1 change: 0 additions & 1 deletion torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -84,7 +84,6 @@ def codegen_template(
         prologue_nodes: Sequence[BaseSchedulerNode],
     ) -> Optional[str]:
         if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
-            assert not epilogue_nodes
             assert not prologue_nodes
             return self._cuda_cpp_scheduling.codegen_template(
                 template_node, epilogue_nodes, prologue_nodes
17 changes: 7 additions & 10 deletions torch/_inductor/codegen/simd.py
@@ -34,7 +34,11 @@
 from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask
 from ..codecache import code_hash
 from ..dependencies import MemoryDep, StarDep, WeakDep
-from ..ir import IRNode, TritonTemplateBuffer
+
+
+if TYPE_CHECKING:
+    from ..ir import IRNode
+
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..runtime.runtime_utils import green_text, yellow_text
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
@@ -1155,16 +1159,9 @@ def can_fuse(self, node1, node2):
             )
             return False
 
-        for n, node_name in zip((node1, node2), ("node1", "node2")):
+        for n in (node1, node2):
             if n.is_template():
-                # Only allow fusion for TritonTemplates for now.
-                # Fusion for CUDATemplates are not supported.
-                is_triton_template = isinstance(
-                    n.get_template_node(), TritonTemplateBuffer
-                )
-                if not is_triton_template:
-                    why(f"{node_name} is not TritonTemplateBuffer")
-                return is_triton_template
+                return True
 
         # check for a bad combined tiling
         tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
8 changes: 4 additions & 4 deletions torch/_inductor/scheduler.py
@@ -901,21 +901,21 @@ def get_prologue_template_epilogue(
 class WhyNoFuse:
     # TODO when we drop support for Python < 3.10, we can use
     # @dataclass(slots=True) instead of manually specifying __slots__.
-    __slots__ = ["node1", "node2", "reason", "args"]
+    __slots__ = ["name1", "name2", "reason", "args"]
     reason: str
     args: tuple[Any, ...]
 
     def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
-        self.node1 = node1
-        self.node2 = node2
+        self.name1 = node1.get_name()
+        self.name2 = node2.get_name()
 
     def __call__(self, reason: str, *args: Any) -> None:
         self.reason = reason
         self.args = args
         fusion_log.debug(self)
 
     def __str__(self) -> str:
-        return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + (
+        return f"cannot fuse {self.name1} with {self.name2}: " + (
             self.reason % self.args
         )