8000 [Inductor] Fix cuda_kernel typing · pytorch/pytorch@aba636d · GitHub
[go: up one dir, main page]

Skip to content

Commit aba636d

Browse files
committed
[Inductor] Fix cuda_kernel typing
ghstack-source-id: 1853112 Pull Request resolved: #150908
1 parent 1bad49f commit aba636d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# mypy: allow-untyped-defs
2+
import functools
23
import logging
34
from dataclasses import dataclass
45
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
@@ -520,7 +521,10 @@ def __init__(
520521
category: str,
521522
input_nodes: list[Buffer],
522523
layout: Layout,
523-
make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
524+
make_kernel_render: Callable[
525+
[CUDATemplateBuffer, Optional[list[IRNode]]],
526+
tuple[CUDATemplateKernel, functools.partial[str]],
527+
],
524528
bmreq: CUDABenchmarkRequest,
525529
template: "CUDATemplate", # type: ignore[name-defined]
526530
info_kwargs: Optional[

0 commit comments

Comments
 (0)
0