8000 [Inductor] Fix cuda_kernel typing · pytorch/pytorch@b70bd22 · GitHub
[go: up one dir, main page]

Skip to content

Commit b70bd22

Browse files
committed
[Inductor] Fix cuda_kernel typing
ghstack-source-id: 317fe01 Pull Request resolved: #150908
1 parent c7f3acb commit b70bd22

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# mypy: allow-untyped-defs
2+
import functools
23
import logging
34
from dataclasses import dataclass
45
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
@@ -504,7 +505,10 @@ def __init__(
504505
category: str,
505506
input_nodes: list[Buffer],
506507
layout: Layout,
507-
make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
508+
make_kernel_render: Callable[
509+
[CUDATemplateBuffer, Optional[list[IRNode]]],
510+
tuple[CUDATemplateKernel, functools.partial[str]],
511+
],
508512
bmreq: CUDABenchmarkRequest,
509513
template: "CUDATemplate", # type: ignore[name-defined]
510514
info_kwargs: Optional[

0 commit comments

Comments
 (0)
0