Make fx.node.map_arg() and .map_aggregate() generic #146248


Closed · wants to merge 27 commits
Update
[ghstack-poisoned]
rec committed Feb 13, 2025
commit e1bf862981f69afca4a7f8ce29accffd3b1c9f9c
12 changes: 9 additions & 3 deletions torch/_tensor.py
@@ -6,7 +6,7 @@
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Optional, Union
+from typing import Any, Callable, cast, Optional, Union

 import torch
 import torch._C as _C
@@ -1104,8 +1104,14 @@ def __rdiv__(self, other):
     __rtruediv__ = __rdiv__
     __itruediv__ = _C.TensorBase.__idiv__

-    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C.TensorBase.pow
+    __pow__ = cast(
+        Callable[
+            ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],
+            "Tensor",
+        ],
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(
+            _C.TensorBase.pow
+        ),
     )
     __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
         _C.TensorBase.pow_
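For readers unfamiliar with the pattern this commit uses: typing.cast has no runtime effect; it only tells a static type checker to treat the wrapped function as having a precise signature, which the _handle_torch_function_and_wrap_type_error_to_not_implemented decorator would otherwise erase to Callable[..., Any]. Below is a minimal, self-contained sketch of the same idea, using hypothetical stand-in names (wrap_type_error_to_not_implemented, pow_impl) rather than the real PyTorch internals:

    from typing import Any, Callable, Union, cast

    def wrap_type_error_to_not_implemented(
        fn: Callable[..., Any],
    ) -> Callable[..., Any]:
        # Hypothetical stand-in for the PyTorch wrapper: returns
        # NotImplemented on TypeError so Python can fall back to the
        # reflected operator on the other operand.
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            try:
                return fn(*args, **kwargs)
            except TypeError:
                return NotImplemented

        return wrapped

    def pow_impl(base: float, exp: Union[int, float]) -> float:
        # Hypothetical stand-in for _C.TensorBase.pow.
        return base**exp

    # Without the cast, type checkers see the wrapped function as
    # Callable[..., Any]; cast() restores an exact signature at zero
    # runtime cost.
    __pow__ = cast(
        Callable[[float, Union[int, float]], float],
        wrap_type_error_to_not_implemented(pow_impl),
    )

At runtime __pow__ is exactly the wrapped function; the cast changes only what static analyzers report for its parameters and return type.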