inductor change needed to update triton pin by shunting314 · Pull Request #107722 · pytorch/pytorch

inductor change needed to update triton pin #107722


Closed · wants to merge 5 commits

Conversation

@shunting314 (Contributor) commented Aug 22, 2023

@pytorch-bot (bot) commented Aug 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107722

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit f862382 with merge base 138e289:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: releng release notes category label Aug 22, 2023
@github-actions github-actions bot added ciflow/trunk Trigger trunk jobs on your pull request module: inductor ciflow/inductor labels Aug 22, 2023
shunting314 added a commit that referenced this pull request Aug 22, 2023
ghstack-source-id: 4972d83
Pull Request resolved: #107722
@shunting314 (Contributor, Author) commented:

The failed test:

python test/inductor/test_pattern_matcher.py -k test_mixed_mm

seems to be related to this upgrade. It triggers an error in triton's C++ code:

UNREACHABLE executed at /home/shunting/ws/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:253!

@shunting314 (Contributor, Author) commented:

Cut a triton issue for the failure in mixed mm: triton-lang/triton#2156

@shunting314 (Contributor, Author) commented:

With the new triton pin,

python test/test_sparse_csr.py -k test_triton_scaled_dot_product_attention_block_size_16_cuda_bfloat16

starts to fail with error:

RuntimeError: Triton Error [CUDA]: misaligned address

I guess it may be because the test uses sparse tensors and triton may have changed its alignment requirements.

@cpuhrsch I see the test was added by #102095. Do you think this is a blocking test failure?

@cpuhrsch (Contributor) commented:

@shunting314 - Given that it used to work and now fails with the moved pin, I'd consider this a blocking failure. It'd be good to figure out why this is breaking with the new version of Triton. cc @amjames @pearu

@shunting314 (Contributor, Author) commented:

> @shunting314 - Given that it used to work and now fails with the moved pin, I'd consider this a blocking failure. It'd be good to figure out why this is breaking with the new version of Triton. cc @amjames @pearu

By any chance does the test pass a view to triton, resulting in an unaligned address?
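
As an aside, a minimal sketch of how a sliced view can produce a misaligned pointer (the tensor names and sizes are made up for illustration, not taken from the test):

import torch

# A CUDA allocation starts at a well-aligned base address, but slicing
# returns a view whose data pointer is offset into that allocation.
base = torch.randn(64, device="cuda")   # base pointer: 16-byte aligned
view = base[1:]                         # offset by one float32 = 4 bytes
print(base.data_ptr() % 16)             # 0 -> aligned
print(view.data_ptr() % 16)             # 4 -> not 16-byte aligned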

@jansel (Contributor) left a comment:

failing tests?

@shunting314 (Contributor, Author) commented:

> failing tests?

There are 2 failed tests, mentioned above:

  1. test_mixed_mm: this should be a triton bug and I've put a standalone triton repro here: type conversion before tl.dot fails compilation triton-lang/triton#2156. I can dig further, but it may be much faster if the triton team can take a look. A sketch of the failing pattern follows this list.
  2. test_triton_scaled_dot_product_attention_block_size_16_cuda_bfloat16

These are the only broken tests in CI.
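
For reference, a minimal sketch of the pattern behind failure 1 (an illustrative kernel, not the exact repro attached to triton-lang/triton#2156; names and shapes are made up):

import torch
import triton
import triton.language as tl

@triton.jit
def mixed_mm_kernel(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  # fp16 operand
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  # int8 operand
    # The type conversion right before tl.dot is the pattern that hit the
    # UNREACHABLE in ElementwiseOpToLLVM.cpp on the new pin.
    acc = tl.dot(a, b.to(tl.float16))
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.float16))

a = torch.randn(16, 16, device="cuda", dtype=torch.float16)
b = torch.randint(-128, 127, (16, 16), device="cuda", dtype=torch.int8)
c = torch.empty(16, 16, device="cuda", dtype=torch.float16)
mixed_mm_kernel[(1,)](a, b, c, M=16, N=16, K=16)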

shunting314 added a commit that referenced this pull request Aug 23, 2023
ghstack-source-id: 17cad06
Pull Request resolved: #107722
@shunting314 (Contributor, Author) commented:

For the test failure in test_triton_scaled_dot_product_attention_block_size_16_cuda_bfloat16, I've found the offending kernel is _bsr_softmax_kernel in torch/sparse/_triton_ops.py. @cpuhrsch @amjames @pearu, can you help create a standalone script calling that kernel with random inputs? If it's a triton problem, we can cut a triton issue with your repro.

@cpuhrsch (Contributor) left a comment:

If this update breaks a kernel that worked previously, why is it ok to land it?

@cpuhrsch (Contributor) commented:

Regardless, we can create a standalone version of the kernel to more easily reproduce this error.

@shunting314 (Contributor, Author) commented Aug 24, 2023

There are 2 more test failures because the testing environment uses an old version of triton: https://github.com/pytorch/pytorch/actions/runs/5947833455/job/16130966172 , https://github.com/pytorch/pytorch/actions/runs/5947833455/job/16130966318 . If we want, we can still make inductor work with older versions of triton at the cost of slightly more complex code. I'm just not sure whether we should do that, or instead upgrade the triton version in those cases.

EDIT: I partially fixed the BC issue by checking whether CompiledKernel has a num_ctas attribute; see the sketch below. To fully fix it, we also need to check whether triton expects the new definition of instance_descriptor.
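
A minimal sketch of that partial fix (the helper name and placement are my illustration, not this PR's exact code; binary is a triton CompiledKernel instance):

def launcher_extra_args(binary):
    # Newer triton attaches num_ctas to compiled kernels, and the launcher
    # must pass it through; probing the attribute (instead of parsing
    # version strings) keeps the code working against older pins.
    if hasattr(binary, "num_ctas"):
        return (binary.num_ctas,)  # new triton: forward the extra field
    return ()                      # old triton: nothing extra to forward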

shunting314 added a commit that referenced this pull request Aug 24, 2023
ghstack-source-id: e8e4078
Pull Request resolved: #107722
@shunting314 (Contributor, Author) commented:

> For the test failure in test_triton_scaled_dot_product_attention_block_size_16_cuda_bfloat16, I've found the offending kernel is _bsr_softmax_kernel in torch/sparse/_triton_ops.py.

I took a further look; the root cause is actually not _bsr_softmax_kernel but an upstream kernel, _sampled_addmm_kernel, which 'corrupts' the input tensors and then causes issues in the downstream _bsr_softmax_kernel.

To repro, run

python test/test_sparse_csr.py -k test_triton_scaled_dot_product_attention_block_size_16_cuda_bfloat16

with the following breakpoints set:

diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
index 57c9ac0168a..80cf3ff6e0f 100644
--- a/torch/sparse/_triton_ops.py
+++ b/torch/sparse/_triton_ops.py
@@ -539,6 +539,7 @@ if _has_triton():
             allow_tf32 = False

         def kernel(grid, *sliced_tensors):
+            breakpoint() # TODO
             _sampled_addmm_kernel[grid](
                 alpha, beta, is_beta_zero,
                 *blocksize, k, tile_k,
@@ -548,6 +549,7 @@ if _has_triton():
                 num_stages=1,
                 num_warps=4
             )
+            breakpoint() # TODO

         launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)

At the first breakpoint we are able to print sliced_tensors, but at the second breakpoint printing the same tensors (which have been corrupted) results in:

(Pdb) sliced_tensors
*** RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

@shunting314 (Contributor, Author) commented Aug 25, 2023

The perf test looks mostly neutral (link), although:

  • torchbench default slows down from 1.19x to 1.16x, while torchbench with cudagraphs is neutral.
  • 4 torchbench models hit network issues and fail to run.

I'll rerun the perf tests.

Edit:
New perf test link

  • Same conclusion as above, except one more timm model passes. It failed previously due to 'two_eager_run_differ'. I think that's unrelated to the upgrade (even though it's a nice thing to see) and just flakiness.

shunting314 added a commit that referenced this pull request Aug 28, 2023
ghstack-source-id: 5d421f7
Pull Request resolved: #107722
@shunting314 (Contributor, Author) commented Aug 28, 2023

Split the pin update into a separate PR per @shintaro-iwasaki's request, to make FBCode-side testing easier.

@shunting314 shunting314 mentioned this pull request Aug 28, 2023
@@ -49,8 +49,25 @@ def is_aligned(x):
return V.graph.sizevars.statically_known_multiple_of(x.expr, ALIGNMENT)
raise NotImplementedError(f"unhandled {type(x)}: {x}")

def is_aligned_8(x):
Contributor:

could this share more code with is_aligned?

Contributor Author:

Yeah, I'll do that in a follow-up PR so this one can be cherry-picked earlier. One possible shape for the shared helper is sketched below.
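
(Not the actual follow-up, just an illustration of the idea; _is_aligned_to is a hypothetical name, and V and ALIGNMENT come from the surrounding inductor module. Only the branch visible in the hunk above is reproduced:)

def _is_aligned_to(x, alignment):
    # Shared core: the dispatch from is_aligned, with the alignment
    # constant lifted into a parameter.
    return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)

def is_aligned(x):
    return _is_aligned_to(x, ALIGNMENT)

def is_aligned_8(x):
    return _is_aligned_to(x, 8)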

@shunting314 (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

shunting314 added a commit that referenced this pull request Aug 29, 2023
ghstack-source-id: 5d421f7
Pull Request resolved: #107722
@shunting314 changed the title from "update triton pin with needed inductor change" to "inductor change needed to update triton pin" Aug 29, 2023
n = max(next_power_of_2(V.graph.sizevars.size_hint(n)), 16)
k = max(next_power_of_2(V.graph.sizevars.size_hint(k)), 16)

# According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424
Contributor:

BTW, it could be 16x32 if you want to try to improve perf a bit :).

Contributor Author:

So that means m can be 16 but n and k have to be at least 32 for int8, since we have tl.tensors with shapes [m, k], [k, n], and [m, n] in the triton kernel?

Contributor:

> So that means m can be 16 but n and k have to be at least 32 for int8?

k must be >= 32, but n and m can be >= 16 if both a and b in a×b are not transposed. A sketch of the resulting padding rule follows.
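
(Summarizing that constraint as code; pad_mm_block_dims is a hypothetical helper, not the exact change that landed:)

from triton import next_power_of_2

def pad_mm_block_dims(m: int, n: int, k: int, is_int8: bool):
    # Per the discussion above: tl.dot needs m, n >= 16, and for int8
    # operands (with neither a nor b transposed) k must be >= 32.
    m = max(next_power_of_2(m), 16)
    n = max(next_power_of_2(n), 16)
    k = max(next_power_of_2(k), 32 if is_int8 else 16)
    return m, n, k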

shunting314 added a commit that referenced this pull request Aug 29, 2023
Resolve comment: #107722 (comment)  


pytorchmergebot pushed a commit that referenced this pull request Aug 30, 2023
Resolve comment: #107722 (comment)

Pull Request resolved: #108135
Approved by: https://github.com/jansel
ghstack dependencies: #107722
@facebook-github-bot facebook-github-bot deleted the gh/shunting314/74/head branch September 1, 2023 14:24
atalman pushed a commit that referenced this pull request Sep 5, 2023
ghstack-source-id: 5d421f7
Pull Request resolved: #107722
pytorchmergebot pushed a commit that referenced this pull request Sep 8, 2023
Pull Request resolved: #108104
Approved by: https://github.com/desertfire
ghstack dependencies: #107722
chuanqi129 added a commit to chuanqi129/pytorch that referenced this pull request Sep 14, 2023
Labels

ciflow/inductor · ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · module: inductor · release notes: releng (release notes category)