`has_triton`: Use the device interface for detecting Triton availability by galexite · Pull Request #139171 · pytorch/pytorch

Open · wants to merge 4 commits into main

Conversation

@galexite galexite requested a review from zou3519 as a code owner October 29, 2024 09:08
pytorch-bot bot commented Oct 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139171

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit d4b2963 with merge base 8904ba6:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@galexite
Contributor Author

@pytorchbot label 'topic: not user facing'

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Oct 29, 2024
@zou3519 zou3519 removed their request for review October 29, 2024 20:52
@cpuhrsch cpuhrsch requested review from jansel, Chillee and eellison and removed request for Chillee October 31, 2024 22:16
@cpuhrsch cpuhrsch added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed release notes: sparse release notes category labels Oct 31, 2024
jansel previously approved these changes Oct 31, 2024
@galexite
Contributor Author
galexite commented Nov 1, 2024

@jansel sorry about that, looks like I messed up the typing, should be ready for another CI run now!

@jansel
Contributor
jansel commented Nov 1, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 1, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

@galexite
Contributor Author

@jansel @tugsbayasgalan I've fixed the tests and added the HOPs from Inductor that appear when I call init_backends_registration (which then eventually imports torch._inductor.ir) in inductor_utils.py to the list of HOPs without op info here: https://github.com/pytorch/pytorch/pull/139171/files#diff-2a1edf2b2655350ac32e1d24b4cfaf5b65d90056c4fcb2f77600245e428d5131R78-R84. I hope this is okay.

@eellison eellison removed their request for review November 20, 2024 17:09
@galexite
Contributor Author

Hey @jansel, could I trouble you for a re-review please? Thanks!

@pytorch-bot pytorch-bot bot removed ciflow/inductor ciflow/xpu Run XPU CI tasks labels May 6, 2025
@galexite
Contributor Author
galexite commented May 6, 2025

Rebased since #152529 was merged.

I have removed the changes to the Inductor scheduler checks, since I think that is where the problem may lie; this PR now only includes the has_triton component (rough sketch below). I'll submit the Inductor scheduler check changes as a separate PR.

I'm hoping this PR will now pass when workflows are rerun.
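
For anyone skimming, the rough shape of the has_triton change is something like this (a simplified sketch, not the exact diff — treat the default device argument and the exact helper names as approximate):

from torch._dynamo.device_interface import get_interface_for_device
from torch.utils._triton import has_triton_package

def has_triton(device="cuda"):
    # The triton package itself has to be importable at all.
    if not has_triton_package():
        return False
    # Delegate the per-backend capability check to the device interface,
    # instead of hard-coding CUDA-specific logic inside has_triton.
    try:
        get_interface_for_device(device).raise_if_triton_unavailable()
    except Exception:
        return False
    return True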

@galexite galexite changed the title Use the device interface for detecting Triton availability has_triton: Use the device interface for detecting Triton availability May 6, 2025
@galexite
Contributor Author
galexite commented May 7, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 7, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request May 7, 2025
…ity (#139171)

Summary:
This PR replaces the `has_triton()` global method which was previously used for this task.

X-link: pytorch/pytorch#139171
Approved by: https://github.com/jansel, https://github.com/shink

Reviewed By: huydhn

Differential Revision: D74338720

fbshipit-source-id: 27106df937bbdea2da1f4911ffffcfae056f844d
@masnesral
Contributor

Looks like this causes a pretty big perf drop on some huggingface models. For example, on an H100 from 2.6x -> 1.3x for the following:
python benchmarks/dynamo/huggingface.py --performance --inference --bfloat16 --backend inductor --device cuda --cold-start-latency --only BlenderbotSmallForCausalLM

@galexite
Contributor Author
galexite commented May 10, 2025

Hi @masnesral, if this PR is causing those performance drops, it might be because hand-written Triton kernels aren't being used, i.e. has_triton is returning False. What happens if you run the following in a Python REPL?

from torch.utils._triton import has_triton
print(f"{has_triton()=}")
print(f"{has_triton("cuda")=}")

from torch._dynamo.device_interface import get_interface_for_device
get_interface_for_device("cuda").raise_if_triton_unavailable()  # shouldn't throw if all is okay

Unfortunately, I don't have access to an H100, but this should work the same. These give True for me on a g5.8xlarge AWS instance and the last doesn't throw.

Also, to help with debugging this, could you also tell me what the result of these are for you?

import triton, torch
print(f'{"nvidia" in triton.backends.backends=}')
print(f'{torch.cuda.get_device_properties("cuda")=}')

@galexite
Contributor Author
galexite commented May 10, 2025

The other thing is, I did change the check in Inductor which enables the pad_mm pass so that it tests whether the tensor is on a device actively using the TritonScheduler, rather than using has_triton. Maybe that incorrectly causes an early exit of that Inductor pass?
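
Roughly, the change there was along these lines (a simplified illustration, not the actual Inductor code — should_pad_old/should_pad_new and device_uses_triton_scheduler are placeholder names for the real gating check and scheduler query):

from torch.utils._triton import has_triton

def device_uses_triton_scheduler(device) -> bool:
    # Placeholder: stands in for asking Inductor whether this device's
    # kernels are scheduled through the TritonScheduler.
    ...

def should_pad_old(tensor) -> bool:
    # Before: the pad_mm pass was gated on the global Triton check.
    return has_triton()

def should_pad_new(tensor) -> bool:
    # After: gate it on the tensor's device actually using the Triton
    # scheduler, which is the part that might exit the pass too early.
    return device_uses_triton_scheduler(tensor.device)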

@masnesral
Contributor

Sorry, gonna try to revert this while we investigate further.

@masnesral
Contributor

@pytorchbot revert -m="Performance regression for huggingface" -c=nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 10, 2025
…vailability (#139171)"

This reverts commit 48bfe9a.

Reverted #139171 on behalf of https://github.com/masnesral due to Performance regression for huggingface ([comment](#139171 (comment)))
@pytorchmergebot
Collaborator

@galexite your PR has been successfully reverted.

@pytorch-bot pytorch-bot bot dismissed stale reviews from jansel and shink May 10, 2025 14:46

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@masnesral
Contributor

@galexite my having tested on H100 doesn't seem to be related (see links to A100 above), but here's the output you requested anyway. I think it's what you were expecting :/

>>> import triton, torch
>>> print(f'{"nvidia" in triton.backends.backends=}')
"nvidia" in triton.backends.backends=True
>>> print(f'{torch.cuda.get_device_properties("cuda")=}')
torch.cuda.get_device_properties("cuda")=_CudaDeviceProperties(name='NVIDIA H100', major=9, minor=0, total_memory=97285MB, multi_processor_count=132, uuid=70fa10a6-2939-4471-959c-6da3b40decb6, pci_bus_id=6, pci_device_id=0, pci_domain_id=0, L2_cache_size=60MB)

>>> from torch.utils._triton import has_triton
>>> print(f"{has_triton()=}")
has_triton()=True
>>> print(f"{has_triton('cuda')=}")
has_triton('cuda')=True
>>> from torch._dynamo.device_interface import get_interface_for_device
>>> get_interface_for_device("cuda").raise_if_triton_unavailable()

@galexite
Contributor Author

Hmm, okay. I'll have a look!

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request May 11, 2025
…vailability (#139171)"

Summary:
This reverts commit 48bfe9afc70a98addd5aa738bf501c029e4a9285.

Reverted pytorch/pytorch#139171 on behalf of https://github.com/masnesral due to Performance regression for huggingface ([comment](pytorch/pytorch#139171 (comment)))

Reviewed By: huydhn

Differential Revision: D74531472

fbshipit-source-id: 751398ae3c03cdd1d1d7c75a5088207a3a1784cb
Labels
ci-no-td (Do not run TD on this PR)
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: cpu (CPU specific problem (e.g., perf, algorithm))
module: dynamo
module: inductor
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
open source
release notes: quantization (release notes category)
Reverted
topic: not user facing (topic category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

10 participants