[inductor] Add type annotations to _inductor/utils.py · pytorch/pytorch@c58f80c
Commit c58f80c

[inductor] Add type annotations to _inductor/utils.py
ghstack-source-id: 38f7b8a
Pull Request resolved: #144108
1 parent 7a93a58 commit c58f80c

12 files changed: +388, -368 lines

test/inductor/test_torchinductor.py

Lines changed: 20 additions & 2 deletions
@@ -50,9 +50,9 @@
 from torch._inductor.codecache import cpp_prefix_path
 from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
 from torch._inductor.fx_passes import pad_mm
+from torch._inductor.scheduler import Scheduler
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.utils import (
-    add_scheduler_init_hook,
     run_and_get_code,
     run_and_get_cpp_code,
     run_and_get_kernels,
@@ -11542,8 +11542,9 @@ def hook_fn(scheduler, nodes):
             or "i0 + i1 * s1" in mul_buf.data.inner_fn_str()
         )
 
-        with add_scheduler_init_hook(hook_fn):
+        with _add_scheduler_init_hook(hook_fn):
             actual = torch.compile(f, fullgraph=True)(x)
+
         self.assertEqual(ref, actual)
         self.assertTrue(called)
@@ -13581,6 +13582,23 @@ def fn(pytype, dtype):
         self.assertEqual(ret_opt, fn(pytype, dtype))
 
 
+def _add_scheduler_init_hook(pre_fn, post_fn=None):
+    """
+    Add hook functions to be called at the beginning and end of Scheduler.__init__.
+    Used for unit tests.
+    """
+    orig_fn = Scheduler.__init__
+
+    def wrapper(scheduler, nodes):
+        pre_fn(scheduler, nodes)
+        out = orig_fn(scheduler, nodes)
+        if post_fn:
+            post_fn(scheduler, nodes)
+        return out
+
+    return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
+
+
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 1 addition & 0 deletions
@@ -401,6 +401,7 @@ def write_wrapper_decl(self):
         if V.graph.aot_mode:
             if V.graph.const_module:
                 self.header.splice(V.graph.const_module.wrapper_code.header)
+                assert V.graph.const_code is not None
                 self.prefix.splice(V.graph.const_code)
 
             if V.graph.is_const_graph:
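
The added assert is there for the type checker as much as for runtime safety: after `assert x is not None`, mypy narrows Optional[T] to T, so the following splice call type-checks. A minimal sketch of the pattern, using a hypothetical splice function in place of the wrapper's method:

from typing import Optional


def splice(code: str) -> None:
    print(code)


def emit(const_code: Optional[str]) -> None:
    # Without this assert, mypy reports: Argument 1 to "splice" has
    # incompatible type "Optional[str]"; expected "str"  [arg-type]
    assert const_code is not None
    splice(const_code)  # const_code is narrowed to str from here on


emit("// constant-graph prologue")

The same narrowing assert appears in cpp_wrapper_cpu_array_ref.py below.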

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 1 addition & 0 deletions
@@ -182,6 +182,7 @@ def write_wrapper_decl(self):
 
         if V.graph.const_module:
             self.header.splice(V.graph.const_module.wrapper_code.header)
+            assert V.graph.const_code is not None
             self.prefix.splice(V.graph.const_code)
 
         if V.graph.is_const_graph:

torch/_inductor/codegen/halide.py

Lines changed: 1 addition & 1 deletion
@@ -1481,7 +1481,7 @@ def halide_kernel_meta(self) -> HalideMeta:
             argtypes,
             target="-".join(target),
             scheduler=schduler,
-            scheduler_flags=scheduler_flags,
+            scheduler_flags=scheduler_flags,  # type: ignore[arg-type]
             cuda_device=cuda_device,
         )
 
torch/_inductor/codegen/simd.py

Lines changed: 4 additions & 1 deletion
@@ -1342,7 +1342,10 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
         src_code = kernel.codegen_kernel()
         kernel_name = self.define_kernel(src_code, node_schedule, kernel)
         if config.trace.enabled:
-            set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name)
+            set_kernel_post_grad_provenance_tracing(
+                node_schedule,  # type: ignore[arg-type]
+                kernel_name,
+            )
         log.debug("Generating kernel code with kernel_name: %s", kernel_name)
         kernel.kernel_name = kernel_name
         kernel.code_hash = code_hash(src_code)
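
Reflowing the call onto multiple lines is what lets the `# type: ignore[arg-type]` land on just the offending argument; a single-line call would need the ignore on the whole statement, suppressing unrelated errors too. A small sketch with a hypothetical takes_list function:

from typing import List


def takes_list(xs: List[int], name: str) -> None:
    print(name, xs)


xs_maybe: object = [1, 2, 3]

# The ignore is scoped to the line it sits on, so a bad `name`
# argument below would still be reported by mypy.
takes_list(
    xs_maybe,  # type: ignore[arg-type]
    "kernel",
)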

torch/_inductor/compile_fx.py

Lines changed: 1 addition & 1 deletion
@@ -1533,7 +1533,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
 
     out_devices: OrderedSet[torch.device] = OrderedSet(
         arg.meta["val"].device
-        for arg in output_node(gm).args[0]
+        for arg in output_node(gm).args[0]  # type: ignore[union-attr]
         if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor)
     )
     cuda_devices: OrderedSet[torch.device] = OrderedSet(
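
`output_node(gm).args[0]` is typed as fx's broad Argument union, which includes non-iterable members such as None, so iterating it trips mypy's union-attr check even though the output node's first argument is a sequence at runtime. A minimal sketch of the same shape, with a hypothetical first_arg function:

from typing import List, Union

Arg = Union[None, int, List[int]]


def first_arg() -> Arg:
    # Always a list at this call site in practice, but the
    # annotation cannot promise that to mypy.
    return [1, 2, 3]


# mypy: Item "None" of "Arg" has no attribute "__iter__"  [union-attr]
values = [
    x
    for x in first_arg()  # type: ignore[union-attr]
    if isinstance(x, int)
]
print(values)  # [1, 2, 3]

The same targeted ignore appears on the analogous comprehension in output_code.py below.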

torch/_inductor/decomposition.py

Lines changed: 1 addition & 10 deletions
@@ -970,19 +970,10 @@ def max_pool2d_with_indices(
     dilation: Union[int, List[int]] = 1,
     ceil_mode: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    if dilation == 1:
-        dilation = [1, 1]
-
-    if padding == 0:
-        padding = [0, 0]
-
-    if not stride:
-        stride = kernel_size
-
     kernel_size = pad_listlike(kernel_size, 2)
     dilation = pad_listlike(dilation, 2)
     padding = pad_listlike(padding, 2)
-    stride = pad_listlike(stride, 2)
+    stride = pad_listlike(stride or kernel_size, 2)
 
     window_size = kernel_size[0] * kernel_size[1]
     # We fallback when using non-default dilation or when the window size is too large
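
The deleted branches hand-normalized scalar arguments, which pad_listlike already covers by broadcasting scalars to the requested arity; the one behavior worth keeping, defaulting an empty stride to kernel_size, folds into `stride or kernel_size`. A rough sketch of a pad_listlike-style helper (the real one lives in torch._inductor.utils and may differ):

def pad_listlike(x, size):
    # Broadcast a scalar (or singleton) to `size` copies; pass
    # sequences of the right arity through unchanged.
    if isinstance(x, (int, float)):
        return [x] * size
    if len(x) == 1:
        return list(x) * size
    return list(x)


kernel_size, stride, padding, dilation = [3, 3], [], 0, 1
stride = pad_listlike(stride or kernel_size, 2)  # [] falls back to kernel_size
assert stride == [3, 3]
assert pad_listlike(padding, 2) == [0, 0]   # replaces `if padding == 0: ...`
assert pad_listlike(dilation, 2) == [1, 1]  # replaces `if dilation == 1: ...`

fx_passes/quantization.py gets the same `stride or kernel_size` rewrite below.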

torch/_inductor/fx_passes/quantization.py

Lines changed: 1 addition & 7 deletions
@@ -1202,14 +1202,8 @@ def qmaxpool2d(match: Match, *args, **kwargs):
     dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
     ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
 
-    if padding == 0:
-        padding = [0, 0]
-    if dilation == 1:
-        dilation = [1, 1]
-    if not stride:
-        stride = kernel_size
     kernel_size = pad_listlike(kernel_size, 2)
-    stride = pad_listlike(stride, 2)
+    stride = pad_listlike(stride or kernel_size, 2)
     padding = pad_listlike(padding, 2)
     dilation = pad_listlike(dilation, 2)
 
torch/_inductor/lowering.py

Lines changed: 1 addition & 8 deletions
@@ -3737,7 +3737,7 @@ def scatter_fallback(
         op_overload,
         reduce,
         self.get_dtype(),
-        src.get_dtype() if src_is_tensor else type(src),
+        src.get_dtype() if src_is_tensor else type(src),  # type: ignore[arg-type]
         src.get_device().type if src_is_tensor else "not impl",
         src_is_tensor,
     ):
@@ -4154,13 +4154,6 @@ def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
 def max_pool2d_checks(
     x, kernel_size, stride, padding, dilation, *, assert_fallback=None
 ):
-    if padding == 0:
-        padding = [0, 0]
-    if dilation == 1:
-        dilation = [1, 1]
-    if not stride:
-        stride = kernel_size
-
     kernel_size = pad_listlike(kernel_size, 2)
     stride = pad_listlike(stride, 2)
     padding = pad_listlike(padding, 2)

torch/_inductor/output_code.py

Lines changed: 1 addition & 1 deletion
@@ -441,7 +441,7 @@ def __init__(
         assert len(output.args) == 1
         stack_traces = [
             (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-            for arg in output.args[0]
+            for arg in output.args[0]  # type: ignore[union-attr]
         ]
         cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
         placeholders = tuple(get_placeholder_info(gm.graph))

torch/_inductor/select_algorithm.py

Lines changed: 2 additions & 5 deletions
@@ -1747,11 +1747,8 @@ def no_op(*args, **kwargs):
         # different than the original values. we explicitly restore the state
         # here to avoid this issue.
 
-        initial_stdout = sys.stdout
-        initial_stderr = sys.stderr
-
         def precompile_with_captured_stdout(choice):
-            with restore_stdout_stderr(initial_stdout, initial_stderr):
+            with restore_stdout_stderr():
                 choice.precompile()
 
         def on_complete(future):
@@ -1784,7 +1781,7 @@ def on_complete(future):
             futures[future] = c
 
         @functools.lru_cache(None)
-        @restore_stdout_stderr(initial_stdout, initial_stderr)
+        @restore_stdout_stderr()
         def wait_on_futures():
             counters["inductor"]["select_algorithm_precompile"] += 1
             for future in as_completed(
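
Both call sites drop the initial_stdout/initial_stderr arguments, which suggests the annotated restore_stdout_stderr now snapshots the streams itself. A minimal sketch of a zero-argument helper with that shape (an assumption, not the actual implementation):

import contextlib
import io
import sys


@contextlib.contextmanager
def restore_streams():
    # Snapshot the streams on entry; put them back on exit no
    # matter what the wrapped code swapped in.
    initial_stdout, initial_stderr = sys.stdout, sys.stderr
    try:
        yield
    finally:
        sys.stdout, sys.stderr = initial_stdout, initial_stderr


with restore_streams():
    sys.stdout = io.StringIO()  # redirected inside the block only
print("stdout restored")

Because contextlib.contextmanager results double as decorators, the same helper also supports the `@restore_stdout_stderr()` decorator usage in the second hunk.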
