[MPS] Make torch.mps.compile_shader public #148972
Conversation
It was a private method in 2.6, but nothing changes in its API for 2.7 and it will likely remain the same in 2.8, so it is time to remove the underscore from its name. This allows one to author and invoke shaders directly from PyTorch; for example, the code below implements an increment by thread index:
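The example code from the PR body is not preserved in this page capture, so what follows is a minimal sketch of such a kernel. It assumes the behavior shown in the `compile_shader` docstring, where compiled kernels are exposed as attributes of the returned library object; the kernel name `add_index` is illustrative.

```python
import torch

# Compile a Metal shader; the returned library exposes each kernel
# as a callable attribute.
lib = torch.mps.compile_shader("""
kernel void add_index(device float* data,
                      uint idx [[thread_position_in_grid]]) {
    data[idx] += idx;
}
""")

t = torch.zeros(16, device="mps")
lib.add_index(t)  # each thread adds its own index to one element
# t is now tensor([0., 1., 2., ..., 15.], device='mps:0')
```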
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148972
Note: links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 54 Pending as of commit 1eb532b with merge base 0902901.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
lgtm
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
According to https://github.com/search?q=%2Fmps%5C._compile_shader%2F+-path%3A%2Ftorch%5C%2F_inductor%5C%2F*%2F+-path%3A%2Ftorch%5C%2Fmps%5C%2F%2F&type=code the only other project using this is yours, so it sounds fine to make this a hard change!
```diff
@@ -140,13 +140,13 @@ def recommended_max_memory() -> int:
     return torch._C._mps_recommendedMaxMemory()


 def _compile_shader(source: str):
```
Should we have more detailed docs on the argument conversion, since it is not the same as our usual Python bindings?
Will do, in a follow-up PR (I want to land this one before the branch cut).
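For readers following this thread: the argument conversion being discussed appears to map tensors to `device` buffers and Python scalars to `constant` references, based on the public `compile_shader` docstring example. A minimal sketch assuming that mapping (the kernel name `scale` is illustrative):

```python
import torch

# Assumed argument mapping (not confirmed in this thread): tensors are
# bound as `device` buffers, Python scalars as `constant` references.
lib = torch.mps.compile_shader("""
kernel void scale(device float* data,
                  constant float& factor,
                  uint idx [[thread_position_in_grid]]) {
    data[idx] *= factor;
}
""")

t = torch.ones(8, device="mps")
lib.scale(t, 2.0)  # tensor -> device buffer, float -> constant reference
```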
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -f "All relevant signals are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Followup work on top: #149480

Wrapper on top of nvrtc, inspired by https://gist.github.com/malfet/2c9a25976dd7396430c38af603f791da from @malfet.

Compiling toy kernels with this setup takes 0.01s vs 90s using `load_inline()` on my local H100. This was primarily motivated by the timeouts I was seeing in the popcorn leaderboard, but it would also be useful to integrate into KernelBench.

This PR is in the same spirit as #148972, which was a similar UX for Metal.

For now we are planning on landing this as a private function, because we expect to iterate on both the user-facing API and the internal implementation; a separate issue will be opened to discuss the path towards making this work public and to give a broader overview of the state of custom CUDA kernel authoring in PyTorch.

Future work, as a prerequisite to making this public:
* divup primitive
* support multiple kernels
* Expose _get_nvrtc_version from native code
* interop with torch.compile
* AMD support

Pull Request resolved: #151484
Approved by: https://github.com/malfet
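The landed API of that nvrtc helper is not shown above, so the following is a rough usage sketch under stated assumptions: the helper name `torch.cuda._compile_kernel`, its parameters, and the `grid`/`block` launch keywords are all illustrative guesses, not the confirmed interface.

```python
import torch

# Assumed interface: a private helper that compiles a CUDA kernel string
# via nvrtc and returns a launchable object. Name, signature, and launch
# keywords are illustrative assumptions, not the confirmed landed API.
add_kernel = torch.cuda._compile_kernel(
    kernel_source="""
    __global__ void add_tensors(const float* a, const float* b,
                                float* out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) out[i] = a[i] + b[i];
    }
    """,
    kernel_name="add_tensors",
)

n = 1024
a = torch.rand(n, device="cuda")
b = torch.rand(n, device="cuda")
out = torch.empty_like(a)
add_kernel(grid=((n + 255) // 256, 1, 1), block=(256, 1, 1),
           args=[a, b, out, n])
```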