Improve torch.ops typing by benjaminglass1 · Pull Request #153558 · pytorch/pytorch · GitHub
Improve torch.ops typing #153558

Closed
3 changes: 2 additions & 1 deletion torch/_C/__init__.pyi.in
@@ -1631,6 +1631,7 @@ class Generator:
class _DispatchOperatorHandle:
def schema(self) -> FunctionSchema: ...
def debug(self) -> str: ...
+def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Stack: ...

class _DispatchModule:
def reset(self) -> None: ...
@@ -1827,7 +1828,7 @@ class _SetExcludeDispatchKeyGuard:
# Defined in torch/csrc/utils/schema_info.h

class _SchemaInfo:
-def __init__(self, schema: _int) -> None: ...
+def __init__(self, schema: FunctionSchema) -> None: ...
@overload
def is_mutable(self) -> _bool: ...
@overload
4 changes: 2 additions & 2 deletions torch/_export/converter.py
@@ -6,7 +6,7 @@
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

import torch
import torch.export._trace
@@ -229,7 +229,7 @@ def get_dtype_as_int(tensor):
# Those operators will be automatically populated to a instance method
# of TS2FXGraphConverter with name convert_<namespace>_<opname>().
# Please check __init__ for method population implementations.
-kind_to_standard_operators = {
+kind_to_standard_operators: dict[str, Callable[..., Any]] = {
"prim::max": builtins.max,
"prim::min": builtins.min,
"prim::TupleIndex": operator.getitem,
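A minimal sketch of the population pattern the comment in the hunk above describes: each "namespace::opname" key in the table becomes a convert_<namespace>_<opname>() instance method that forwards to the mapped callable. MiniConverter and _kind_to_standard_operators below are illustrative stand-ins, not the real TS2FXGraphConverter code.

# Illustrative sketch only -- not the real TS2FXGraphConverter.
import builtins
import operator
from typing import Any, Callable

_kind_to_standard_operators: dict[str, Callable[..., Any]] = {
    "prim::max": builtins.max,
    "prim::min": builtins.min,
    "prim::TupleIndex": operator.getitem,
}


class MiniConverter:
    def __init__(self) -> None:
        for kind, target in _kind_to_standard_operators.items():
            namespace, opname = kind.split("::")
            # Bind target via a default argument so each closure keeps its own op.
            setattr(
                self,
                f"convert_{namespace}_{opname}",
                lambda *args, _target=target: _target(*args),
            )


converter = MiniConverter()
print(converter.convert_prim_max(3, 7))  # -> 7, routed through builtins.max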
@@ -10,7 +10,7 @@
aten = torch.ops.aten

_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
-aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
+aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default,
aten._assert_async.msg: aten._functional_assert_async.msg,
}

6 changes: 4 additions & 2 deletions torch/_functorch/partitioners.py
@@ -10,7 +10,7 @@
import os.path
from collections import defaultdict
from dataclasses import dataclass, replace
-from typing import Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
import torch._inductor.inductor_prims
@@ -2046,7 +2046,9 @@ def get_default_op_list() -> OpTypes:
default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
recomputable_ops = OrderedSet(default_recomputable_ops)

-random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
+random_ops = OrderedSet[Callable[..., Any]](
+    [aten.native_dropout, aten.rand_like, aten.randn_like]
+)
compute_intensive_ops = [
aten.mm,
aten.convolution,
10 changes: 5 additions & 5 deletions torch/_higher_order_ops/effects.py
@@ -24,11 +24,11 @@ class _EffectType(Enum):
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
{
torch.ops.aten._print.default: _EffectType.ORDERED,
call_torchbind: _EffectType.ORDERED,
}
SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
[
(torch.ops.aten._print.default, _EffectType.ORDERED),
(call_torchbind, _EffectType.ORDERED),
]
)


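The SIDE_EFFECTS rewrite above swaps a quoted annotation for a parametrized constructor call. A small self-contained sketch of that pattern (Op and the string value are stand-ins, not the real OpType/_EffectType): subscripting WeakKeyDictionary is valid at runtime on Python 3.9+ (PEP 585), and it lets a type checker infer the key/value types directly from the constructor.

from weakref import WeakKeyDictionary


class Op:  # stand-in key type; WeakKeyDictionary keys must be weak-referenceable
    def __init__(self, name: str) -> None:
        self.name = name


aten_print = Op("aten._print.default")

# Parametrizing the constructor gives WeakKeyDictionary[Op, str] without a quoted
# annotation; the entries are passed as (key, value) pairs, as in the diff above.
effects = WeakKeyDictionary[Op, str]([(aten_print, "ORDERED")])
assert effects[aten_print] == "ORDERED"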
56 changes: 28 additions & 28 deletions torch/_inductor/decomposition.py
@@ -107,7 +107,7 @@

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
-decomps_to_exclude = [
+decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [
aten._unsafe_index,
aten._unsafe_masked_index,
aten._unsafe_masked_index_put_accumulate,
@@ -522,7 +522,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
return torch.where(torch.isnan(other) | (other < self), self, other)


-@register_decomposition(aten.amax)
+@register_decomposition([aten.amax])
benjaminglass1 (Collaborator, Author) commented:

This was one of the places where the new typing paid off. register_decomposition expects a list of OperatorBase/OpOverloadPacket objects. The previous typing of _OpNamespace returned Any from __getattr__, so the result of the implicit call behind aten.amax could not be determined to be non-iterable. Once that __getattr__ was explicitly typed to return OpOverloadPacket, type checking failed here.
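A minimal sketch of that failure mode, using stand-in classes and a simplified register_decomposition signature rather than PyTorch's real definitions (the actual API differs): once __getattr__ is annotated to return OpOverloadPacket instead of Any, a checker such as mypy can flag the bare, non-list argument at the call site.

from collections.abc import Sequence
from typing import Any, Callable


class OpOverloadPacket:  # stand-in for torch._ops.OpOverloadPacket
    pass


class _OpNamespaceOld:
    def __getattr__(self, name: str) -> Any:  # old stub: every op lookup is Any
        return OpOverloadPacket()


class _OpNamespaceNew:
    def __getattr__(self, name: str) -> OpOverloadPacket:  # new, precise stub
        return OpOverloadPacket()


_registry: list[OpOverloadPacket] = []


def register_decomposition(
    ops: Sequence[OpOverloadPacket],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        _registry.extend(ops)  # iterates, so a bare packet is the wrong shape here
        return fn

    return wrap


aten_old, aten_new = _OpNamespaceOld(), _OpNamespaceNew()

# register_decomposition(aten_old.amax)  # Any: accepted silently by the checker
# register_decomposition(aten_new.amax)  # error: OpOverloadPacket is not a Sequence


@register_decomposition([aten_new.amax])  # OK: matches the expected list form
def amax(*args: Any, **kwargs: Any) -> Any:
    return NotImplemented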

def amax(
self: torch.Tensor,
dim: Optional[int] = None,
@@ -533,7 +533,7 @@ def amax(
return NotImplemented


-@register_decomposition(aten.amin)
+@register_decomposition([aten.amin])
def amin(
self: torch.Tensor,
dim: Optional[int] = None,
@@ -581,7 +581,7 @@ def get_like_layout(
return memory_format


-@register_decomposition(aten.rand_like)
+@register_decomposition([aten.rand_like])
def rand_like(
self: torch.Tensor,
*,
@@ -598,7 +598,7 @@ def rand_like(
).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.randn_like)
+@register_decomposition([aten.randn_like])
def randn_like(
self: torch.Tensor,
*,
@@ -615,7 +615,7 @@ def randn_like(
).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.full_like)
+@register_decomposition([aten.full_like])
def full_like(
self: torch.Tensor,
fill_value: Union[int, float],
@@ -637,7 +637,7 @@ def full_like(
).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.randint_like.default)
+@register_decomposition([aten.randint_like.default])
def randint_like(
self: torch.Tensor,
high: int,
@@ -657,7 +657,7 @@ def randint_like(
).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.randint_like.low_dtype)
+@register_decomposition([aten.randint_like.low_dtype])
def randint_like_low(
self: torch.Tensor,
low: int,
@@ -678,7 +678,7 @@ def randint_like_low(
).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.randint.default)
+@register_decomposition([aten.randint.default])
def randint(
high: int,
size: list[Union[int, torch.SymInt]],
@@ -687,7 +687,7 @@ def randint(
return aten.randint.low(0, high, size, **kwargs)


-@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
+@register_decomposition([quantized.linear_dynamic_fp16_unpacked_weight.default])
def linear_dynamic_fp16_unpacked_weight(
input: torch.Tensor,
weight: torch.Tensor,
@@ -699,7 +699,7 @@ def linear_dynamic_fp16_unpacked_weight(
)


-@register_decomposition(_quantized.wrapped_quantized_linear.default)
+@register_decomposition([_quantized.wrapped_quantized_linear.default])
def wrapped_quantized_linear(
input: torch.Tensor,
input_scale: torch.Tensor,
@@ -726,7 +726,7 @@ def wrapped_quantized_linear(
)


-@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
+@register_decomposition([torch.ops.quantized.embedding_bag_byte_unpack])
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
@@ -771,7 +771,7 @@ def grid_sampler_2d(
return output


-@register_decomposition(aten._foreach_addcmul.Scalar)
+@register_decomposition([aten._foreach_addcmul.Scalar])
def _foreach_addcmul_scalar(
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
@@ -783,7 +783,7 @@ def _foreach_addcmul_scalar(
)


-@register_decomposition(aten._foreach_addcdiv.Scalar)
+@register_decomposition([aten._foreach_addcdiv.Scalar])
def _foreach_addcdiv_scalar(
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
@@ -795,7 +795,7 @@ def _foreach_addcdiv_scalar(
)


-@register_decomposition(aten._foreach_lerp.Scalar)
+@register_decomposition([aten._foreach_lerp.Scalar])
def _foreach_lerp_scalar(
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
@@ -809,7 +809,7 @@ def _foreach_lerp_scalar(
)


-@register_decomposition(aten._foreach_lerp.ScalarList)
+@register_decomposition([aten._foreach_lerp.ScalarList])
def _foreach_lerp_scalarlist(
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
@@ -824,7 +824,7 @@ def _foreach_lerp_scalarlist(


@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
-@register_decomposition(aten.miopen_batch_norm)
+@register_decomposition([aten.miopen_batch_norm])
def miopen_batch_norm(
input: torch.Tensor,
weight: torch.Tensor,
@@ -869,7 +869,7 @@ def select_decomp_table() -> dict[Any, Callable[..., Any]]:
return fast_random_decomps()


-@register_decomposition(aten.masked_scatter)
+@register_decomposition([aten.masked_scatter])
def masked_scatter(
self: torch.Tensor,
mask: torch.Tensor,
@@ -888,7 +888,7 @@ def masked_scatter(
return NotImplemented


-@register_decomposition(quantized_decomposed.choose_qparams.tensor)
+@register_decomposition([quantized_decomposed.choose_qparams.tensor])
def choose_qparams_tensor(
input: torch.Tensor,
quant_min: int,
@@ -904,7 +904,7 @@ def choose_qparams_tensor(
return scale.to(torch.float64), zero_point.to(torch.int64)


-@register_decomposition(aten.put)
+@register_decomposition([aten.put])
def put(
self: torch.Tensor,
index: torch.Tensor,
@@ -918,7 +918,7 @@ def put(
return flattened.reshape(self.shape)


-@register_decomposition(aten.put_)
+@register_decomposition([aten.put_])
def put_(
self: torch.Tensor,
index: torch.Tensor,
@@ -929,7 +929,7 @@ def put_(
return self.copy_(out)


-@register_decomposition(aten._softmax_backward_data.default)
+@register_decomposition([aten._softmax_backward_data.default])
@pw_cast_for_opmath
def _softmax_backward_data(
grad_output: torch.Tensor,
@@ -951,7 +951,7 @@ def _softmax_backward_data(
return grad_input.contiguous()


-@register_decomposition(aten.index_reduce)
+@register_decomposition([aten.index_reduce])
def index_reduce(
self: torch.Tensor,
dim: int,
@@ -1056,7 +1056,7 @@ def _max_pool_with_indices(
return vals, indices


-@register_decomposition(aten.max_pool2d_with_indices)
+@register_decomposition([aten.max_pool2d_with_indices])
def max_pool2d_with_indices(
x: torch.Tensor,
kernel_size: list[int],
@@ -1070,7 +1070,7 @@ def max_pool2d_with_indices(
)


-@register_decomposition(aten.max_pool3d_with_indices)
+@register_decomposition([aten.max_pool3d_with_indices])
def max_pool3d_with_indices(
x: torch.Tensor,
kernel_size: list[int],
@@ -1084,7 +1084,7 @@ def max_pool3d_with_indices(
)


-@register_decomposition(aten.adaptive_max_pool2d)
+@register_decomposition([aten.adaptive_max_pool2d])
def adaptive_max_pool2d(
x: torch.Tensor, output_size: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1102,7 +1102,7 @@ def adaptive_max_pool2d(
return NotImplemented


-@register_decomposition(aten.searchsorted.Scalar)
+@register_decomposition([aten.searchsorted.Scalar])
def searchsorted_scalar(
sorted_sequence: torch.Tensor,
self: torch.types.Number,
@@ -1122,7 +1122,7 @@ def searchsorted_scalar(
)[0]


-@register_decomposition(aten.rrelu_with_noise_functional)
+@register_decomposition([aten.rrelu_with_noise_functional])
def rrelu_with_noise_functional(
self: torch.Tensor,
noise: torch.Tensor,
6 changes: 4 additions & 2 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -2,7 +2,7 @@
import functools
import operator
from functools import reduce
-from typing import Any
+from typing import Any, Callable

import torch
from torch._dynamo.utils import counters
@@ -1345,7 +1345,9 @@ def linear(match, *args, **kwargs):
or V.aot_compilation
):
packed_linear_inputs += (bias, "none", [], "")
-packed_linear_op = mkldnn._linear_pointwise.default
+packed_linear_op: Callable[..., Any] = (
+    mkldnn._linear_pointwise.default
+)
else:
packed_linear_inputs += (transpose_weight_node, bias, batch_size)
packed_linear_op = torch.ops.mkl._mkl_linear
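A small sketch of why the annotation above is added at the first assignment (stand-in classes; presumably mkldnn._linear_pointwise.default and torch.ops.mkl._mkl_linear have different static types): without the broader Callable[..., Any] annotation, a checker would infer the first branch's type for packed_linear_op and reject the reassignment in the else branch.

from typing import Any, Callable


class OpOverload:  # stand-in for an overload such as mkldnn._linear_pointwise.default
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


class OpOverloadPacket:  # stand-in for a packet such as torch.ops.mkl._mkl_linear
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


def pick_linear_op(use_mkldnn: bool) -> Callable[..., Any]:
    if use_mkldnn:
        # Annotating here widens the variable so the else branch type-checks too.
        packed_linear_op: Callable[..., Any] = OpOverload()
    else:
        packed_linear_op = OpOverloadPacket()
    return packed_linear_op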