-
Notifications
You must be signed in to change notification settings - Fork 24.3k
[import][inductor] Simplify grid handling #147583
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147583
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 1bcc2d2 with merge base ce2f680 ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@jansel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@jansel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
01158b0
to
bb7c593
Compare
@jansel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
2 similar comments
@jansel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@jansel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
This pull request was exported from Phabricator. Differential Revision: D69965761 |
bb7c593
to
e5e9274
Compare
Summary: Copy of pytorch#147251 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 amjames chauhang aakhundov Pull Request resolved: pytorch#147583 Differential Revision: D69965761 Pulled By: jansel
This pull request was exported from Phabricator. Differential Revision: D69965761 |
e5e9274
to
20ec5c1
Compare
Summary: Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 amjames chauhang aakhundov Pull Request resolved: pytorch#147583 Differential Revision: D69965761 Pulled By: jansel
This pull request was exported from Phabricator. Differential Revision: D69965761 |
Summary: Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 amjames chauhang aakhundov Pull Request resolved: pytorch#147583 Differential Revision: D69965761 Pulled By: jansel
20ec5c1
to
1403b3c
Compare
This pull request was exported from Phabricator. Differential Revision: D69965761 |
1403b3c
to
7e9e424
Compare
Summary: Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 amjames chauhang aakhundov Pull Request resolved: pytorch#147583 Differential Revision: D69965761 Pulled By: jansel
This pull request was exported from Phabricator. Differential Revision: D69965761 |
@pytorchbot revert -m="Diff reverted internally" -c="ghfirst" This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).) |
@pytorchbot successfully started a revert job. Check the current status here. |
@jansel your PR has been successfully reverted. |
This reverts commit b59776d. Reverted #147583 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](#147583 (comment)))
Summary: Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Differential Revision: D70471332
Summary: Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Differential Revision: D70471332
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Differential Revision: D70471332
Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Note the attached diff contains some minor fbcode-only changes. Pull Request resolved: pytorch#147583 Approved by: https://github.com/eellison, https://github.com/shunting314
This reverts commit b59776d. Reverted pytorch#147583 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#147583 (comment)))
Summary: Pull Request resolved: #148305 Relands D69965761 / #147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: shunting314, jianyuh Differential Revision: D70471332
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332
Summary: X-link: pytorch/pytorch#148305 Relands D69965761 / pytorch/pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332 fbshipit-source-id: ffebb07b3f44b4af65f0f9f1c9219c71d802969a
Summary: Relands D69965761 / #147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Differential Revision: D70471332 Pull Request resolved: #148305 Approved by: https://github.com/shunting314, https://github.com/eellison
Summary: Pull Request resolved: pytorch#148305 Relands D69965761 / pytorch#147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Reviewed By: eellison, shunting314, jianyuh Differential Revision: D70471332
Summary: Relands D69965761 / #147583 Before this PR, calling a triton kernel would look like: ```py kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0) ``` where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg: ```py kernel.run(a, b, xnumel, stream=stream0) ``` instead now the grid computation is included in the kernel launcher, with something like: ```py def launcher(in_ptr0, out_ptr0, xnumel, stream): grid_0 = ((xnumel + 1023) >> 10) grid_1 = 1 grid_2 = 1 runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel) ``` This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code. Differential [disconnected] Revision: D70471332 Pull Request resolved: #148305 Approved by: https://github.com/shunting314, https://github.com/eellison
…ed CUDA kernels (#148561) Putting this up for a first pass review, though I will likely make a bunch of changes before landing to add more features, etc. This diff implements a first version of a static CUDA kernel launcher in `torch._C`. The goal here is to take a cubin file and some metadata from a CompiledKernel from `triton`, and launch the cubin file directly. Background doc: https://docs.google.com/document/d/1rjRcHl6MfauHG30nCoQX-9UKvKyIs4WWMy_GsGyqb9g/edit?tab=t.0#heading=h.ut5lf39lzq66 Normally, using triton's CompiledKernel.make_launcher(), we would pay the cost of codegenning C++ and running it at compile time. With this new approach, we can use one statically compiled library to launch the kernel. The tradeoff here is that this new kernel launcher will not be able to use codegen to deal with different lengths/types of arguments. So we use templating to handle up to 10 arguments for now. We also allocate 8 bytes on the stack per argument no matter the argument type, which can take more memory than codegenning. On the other hand, we improve compile time on cold and warm start by not having to call the C++ compiler at all. This diff does not add the launcher to torch, but introduces a basic test suite. A list of TODOs that are not yet complete, will do in separate diff: - Handle `nvTmaDesc` and `cuTensorMap`, which triton handles - Embed the grid logic instead of passing in gridX,Y,Z. With #147583, we should be able to handle all of the grid logic directly in _StaticCudaLauncher.launch_kernel, and get rid of the python evaluation. - Handle launch_enter and exit hooks? (Not sure if inductor has these) - Benchmarking to see if there's runtime performance loss - Hooking it up with a config to inductor - Testing harness to test against torch generated triton kernels Differential Revision: [D69926783](https://our.internmc.facebook.com/intern/diff/D69926783/) Pull Request resolved: #148561 Approved by: https://github.com/aorenste, https://github.com/syed-ahmed
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Before this PR, calling a triton kernel would look like:
where the
grid=
was passed as a callable (function closure) arg. This PR removes the grid arg:instead now the grid computation is included in the kernel launcher, with something like:
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each
triton.Config
.It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Note the attached diff contains some minor fbcode-only changes.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov