Improve torch.ops typing by benjaminglass1 · Pull Request #153558 · pytorch/pytorch · GitHub

Improve torch.ops typing #153558


Open · wants to merge 1 commit into base: main
Conversation

@benjaminglass1 (Collaborator) commented May 14, 2025

Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:

  1. torch.ops.higher_order is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the _Ops class.
  2. __getattr__ is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.

The remainder of this PR is fixing up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.

Test plan: CI
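The pattern described above can be sketched as follows. This is a hypothetical simplification for illustration, not the actual torch._ops code: the class names mirror PyTorch's, but the bodies here are invented.

```python
# Hypothetical sketch of the typing pattern described in the PR summary:
# constant special cases become class variables found by normal attribute
# lookup, so the remaining dynamic path in __getattr__ can declare a single
# return type instead of Any.

class OpOverloadPacket:
    """Stand-in for torch._ops.OpOverloadPacket (illustrative only)."""
    def __init__(self, name: str) -> None:
        self.name = name


class OpNamespace:
    # Before: __getattr__ special-cased names like this and returned Any.
    # After: constants are plain class variables, resolved by normal lookup
    # before __getattr__ is ever consulted.
    __file__ = "torch.ops"

    def __getattr__(self, op_name: str) -> OpOverloadPacket:
        # Only reached when normal attribute lookup fails, so every result
        # really is an OpOverloadPacket and the annotation can say so.
        packet = OpOverloadPacket(op_name)
        setattr(self, op_name, packet)  # cache so later lookups skip __getattr__
        return packet


ns = OpNamespace()
print(type(ns.amax).__name__)  # a checker now sees OpOverloadPacket, not Any
```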

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@benjaminglass1 benjaminglass1 requested a review from amjames May 14, 2025 17:37
@benjaminglass1 benjaminglass1 self-assigned this May 14, 2025
pytorch-bot bot commented May 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153558

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 192eac5 with merge base fa85434:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```diff
@@ -502,7 +502,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
     return torch.where(torch.isnan(other) | (other < self), self, other)


-@register_decomposition(aten.amax)
+@register_decomposition([aten.amax])
```
Collaborator Author:

This was one of the places where the new typing paid off. register_decomposition expects a list of OperatorBase/OpOverloadPacket. The previous typing of _OpNamespace returned Any from __getattr__, so the implicit call in aten.amax couldn't be determined to not be iterable. Once that __getattr__ was explicitly typed to return OpOverloadPacket, typing failed here.
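The failure mode described here can be sketched roughly as follows. The names mirror but heavily simplify the PyTorch ones; the signatures are assumptions for illustration, not the real `register_decomposition` API.

```python
# Hypothetical sketch: once __getattr__ is annotated to return
# OpOverloadPacket (instead of Any), passing a bare packet where a list is
# expected becomes a visible type error, which is why the diff wraps
# aten.amax in a list.
from typing import Callable, TypeVar


class OpOverloadPacket:
    """Stand-in for torch._ops.OpOverloadPacket."""


_F = TypeVar("_F", bound=Callable[..., object])


def register_decomposition(ops: list[OpOverloadPacket]) -> Callable[[_F], _F]:
    # Simplified decorator factory; the real one does registration work.
    def wrapper(fn: _F) -> _F:
        return fn
    return wrapper


amax = OpOverloadPacket()  # what a typed aten.amax lookup now yields

# register_decomposition(amax)    # type error: OpOverloadPacket is not a list
@register_decomposition([amax])   # OK: the packet is wrapped in a list
def amax_decomp(*args, **kwargs):
    ...
```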

@benjaminglass1 benjaminglass1 force-pushed the benjaminglass1/improve_opoverload_typing branch from 521ba30 to 5cc58cf Compare May 14, 2025 19:37
@benjaminglass1 benjaminglass1 force-pushed the benjaminglass1/improve_opoverload_typing branch 2 times, most recently from 06041cf to 029e299 Compare May 14, 2025 23:37
@rec (Collaborator) left a comment:

Well, looking back at these comments, they're all quibbles of one variety or another, so if you ignored them all, I wouldn't care.

This seems like a solid step forward for typing which will have positive ramifications all over the code. I'm a little surprised there weren't even more changes in distant files!

Good stuff.

@rec (Collaborator) commented May 15, 2025

Oh, and tangentially relevant: in another code review I have this TypeAlias in _inductor/ir.py.

_OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]

It passed review but I'm not entirely happy with the name - do you have a better one, given that you're thinking about these sorts of things?

@Skylion007 (Collaborator) left a comment:

Aside from the naked callables, this looks great.

@benjaminglass1 (Collaborator Author):

Oh, and tangentially relevant: in another code review I have this TypeAlias in _inductor/ir.py.

_OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]

It passed review but I'm not entirely happy with the name - do you have a better one, given that you're thinking about these sorts of things?

@rec You could simply use torch._ops.OperatorBase, I believe?
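The suggestion can be illustrated with a minimal sketch, assuming (as in torch._ops) that both `OpOverload` and `HigherOrderOperator` subclass a common `OperatorBase`; the class bodies and the `lower` function here are hypothetical.

```python
# Sketch of the suggestion: with a shared base class, annotations can accept
# either concrete operator kind without defining a Union alias.
class OperatorBase:
    """Stand-in for torch._ops.OperatorBase."""


class OpOverload(OperatorBase):
    """Stand-in for torch._ops.OpOverload."""


class HigherOrderOperator(OperatorBase):
    """Stand-in for torch._ops.HigherOrderOperator."""


def lower(op: OperatorBase) -> str:
    # One nominal type covers both subclasses, replacing
    # Union[OpOverload, HigherOrderOperator].
    return type(op).__name__


print(lower(OpOverload()), lower(HigherOrderOperator()))
```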

@benjaminglass1 benjaminglass1 force-pushed the benjaminglass1/improve_opoverload_typing branch 2 times, most recently from 68766f5 to d483986 Compare May 15, 2025 17:55
@benjaminglass1 benjaminglass1 marked this pull request as ready for review May 15, 2025 17:56
@benjaminglass1 benjaminglass1 requested a review from zou3519 as a code owner May 15, 2025 17:56
@benjaminglass1 benjaminglass1 added the ciflow/trunk Trigger trunk jobs on your pull request label May 15, 2025
torch/_ops.py (outdated):

```python
def __init__(
    self,
    overloadpacket: "OpOverloadPacket",
    op: Callable[..., Any],
```
Collaborator:

still naked callables here, not great.

Collaborator Author:

@Skylion007 I would prefer to switch it to something more concrete, but after a good bit of experimenting I was unable to find anything other than Callable[..., Any] that worked in all the ways this was invoked. The next best solution was Callable[_P, _T], with op_dk getting Callable[_Q, _U], which didn't seem particularly informative to anyone, since it specified no similarities between the two callables.

Collaborator Author:

Actually, I think I fundamentally misunderstood something I was trying to use to inform my typing of this. Hang on a minute, I may have a type for this.

Collaborator Author:

@Skylion007 I'm about to push something that explicitly links the types of op and op_dk here in this initializer only. I tried locally to make things even more generic, but I think we should stop here:

  1. Genericizing op in the initializer for OpOverloadPacket doesn't add any additional clarity, since the generics aren't linked to anything that couldn't be inferred otherwise.
  2. Fully genericizing the classes OpOverload and OpOverloadPacket solves problem 1, but then generics must be applied everywhere in the codebase that references either class in a type. That's a lot of boilerplate for something that will almost always be filled in with generic types, since we very rarely know a priori which argument and return types to expect from functions typed with these classes.

```python
    self,
    qualified_op_name: str,
    op_name: str,
    op: Callable[..., Any],
```
Collaborator:

same

Collaborator Author:

See above.

@benjaminglass1 benjaminglass1 force-pushed the benjaminglass1/improve_opoverload_typing branch from d483986 to 192eac5 Compare May 15, 2025 18:39
6 participants