Improve torch.ops typing #153558
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153558

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 unrelated failure) As of commit 192eac5 with merge base fa85434. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from `21e3328` to `521ba30`.
```diff
@@ -502,7 +502,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
     return torch.where(torch.isnan(other) | (other < self), self, other)


-@register_decomposition(aten.amax)
+@register_decomposition([aten.amax])
```
This was one of the places where the new typing paid off. `register_decomposition` expects a list of `OperatorBase`/`OpOverloadPacket`. The previous typing of `_OpNamespace` returned `Any` from `__getattr__`, so the implicit call in `aten.amax` couldn't be determined to not be iterable. Once that `__getattr__` was explicitly typed to return `OpOverloadPacket`, typing failed here.
Force-pushed from `521ba30` to `5cc58cf`.
Force-pushed from `06041cf` to `029e299`.
Well, looking back at these comments, they're all quibbles of some variety or other, so if you ignored them all, I wouldn't care.

This seems like a solid step forward for typing which will have positive ramifications all over the code. I'm a little surprised there weren't even more changes in distant files!

Good stuff.

Oh, and tangentially relevant: in another code review I have this. It passed review but I'm not entirely happy with the name - do you have a better one, given that you're thinking about these sorts of things?
Aside from the naked callables, this looks great.

@rec You could simply use …
Force-pushed from `68766f5` to `d483986`.
torch/_ops.py (outdated):

```python
def __init__(
    self,
    overloadpacket: "OpOverloadPacket",
    op: Callable[..., Any],
```
Still naked callables here, not great.
@Skylion007 I would prefer to switch it to something more concrete, but after a good bit of experimenting I was unable to find anything other than `Callable[..., Any]` that worked in all the ways this was invoked. The next best solution was `Callable[_P, _T]`, with `op_dk` getting `Callable[_Q, _U]`, which didn't seem particularly informative to anyone, since it specified no similarities between the two callables.
Actually, I think I fundamentally misunderstood something I was trying to use to inform my typing of this. Hang on a minute, I may have a type for this.
@Skylion007 I'm about to push something that explicitly links the types of `op` and `op_dk` here in this initializer only. I tried locally to make things even more generic, but I think we should stop here:

1. Genericizing `op` in the initializer for `OpOverloadPacket` doesn't add any additional clarity, since the generics aren't linked to anything that couldn't be inferred otherwise.
2. Fully genericizing the classes `OpOverload` and `OpOverloadPacket` solves problem 1, but you have to apply generics everywhere in the codebase that references either class in a type. That's a lot of boilerplate for something that will almost always be filled in with generic types (since we very rarely know a priori which arg and return types to expect in functions being typed with these classes).
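A sketch of what "linking the types of `op` and `op_dk` in the initializer only" can look like, using one shared `ParamSpec`/`TypeVar` pair. The class here is hypothetical and far simpler than the real `OpOverload`; it only illustrates the typing pattern:

```python
from typing import Callable, Generic, Optional, TypeVar

try:
    from typing import ParamSpec  # Python 3.10+
except ImportError:  # pragma: no cover
    from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")

class OverloadSketch(Generic[_P, _T]):
    """Hypothetical: op and op_dk share one (_P, _T) pair, so the
    checker knows the two callables have the same signature, without
    forcing generic parameters onto every other use of the class."""

    def __init__(
        self,
        op: Callable[_P, _T],
        op_dk: Optional[Callable[_P, _T]] = None,  # same signature as op
    ) -> None:
        self._op = op
        self._op_dk = op_dk

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        return self._op(*args, **kwargs)

add = OverloadSketch(lambda a, b: a + b)
# add("x")  # a type checker flags this: arity flows through _P
```

Contrast this with `Callable[_P, _T]` plus `Callable[_Q, _U]`, which lets the two callables drift apart entirely.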
```python
    self,
    qualified_op_name: str,
    op_name: str,
    op: Callable[..., Any],
```
same
See above.
Force-pushed from `d483986` to `192eac5`.
Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:

1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.

The remainder of this PR fixes all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.
Test plan: CI
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov