Scheduler Flops refactor by exclamaforte · Pull Request #152708 · pytorch/pytorch · GitHub

Scheduler Flops refactor #152708

Closed
wants to merge 4 commits from exclamaforte/scheduler-flops-refactor

Conversation

exclamaforte (Contributor) commented May 2, 2025

This refactors estimate_flops and get_estimated_runtime on scheduler nodes:

  1. New function on BaseSchedulerNode: estimate_flops. It now works with all types of IR nodes, not just ExternalKernels.
  2. Extends get_estimated_runtime to work with non-ExternalKernels.

Prelude to: #149697

Testing:
New unit tests cover functionality.
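
As a rough usage sketch of the refactored API (illustrative only: summarize_scheduler is a hypothetical helper, and the loop mirrors what the new unit tests do rather than reproducing the exact implementation):

def summarize_scheduler(scheduler):
    # estimate_flops() now lives on BaseSchedulerNode and may return None
    # when no estimate is available for a given node.
    total_flops = 0
    for node in scheduler.nodes:
        flops = node.estimate_flops()
        if flops is not None:
            total_flops += flops
    # get_estimated_runtime() is likewise no longer limited to ExternalKernels.
    total_runtime = sum(node.get_estimated_runtime() for node in scheduler.nodes)
    return total_flops, total_runtime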

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented May 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152708

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit ff251ba with merge base 61dd2a0:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison (Contributor) commented May 2, 2025

should fix #147137

eellison (Contributor) left a comment

nice, looks good! just one comment about testing

Comment on lines 68 to 80
gm = make_fx(op)(*example_inputs, **kwargs)
reference_flops = get_total_flops(mode)

graph = GraphLowering(gm)

with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
    graph.run(*example_inputs, **kwargs)
    graph.init_wrapper_code()
    graph._update_scheduler()
    # Sum the per-node estimates from the new BaseSchedulerNode.estimate_flops()
    scheduler_flops = 0
    for node in graph.scheduler.nodes:
        flops = node.estimate_flops()
        scheduler_flops += flops if flops is not None else 0
Contributor

nit: can we just make this a metric we store on counters? I would rather we just run torch.compile here.
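
A rough sketch of that suggestion, assuming a hypothetical counter key (counters["inductor"]["flops"] is illustrative; the PR does not add such a metric):

import torch
from torch._dynamo.utils import counters

# Hypothetical: if the scheduler recorded its FLOP estimate into a counter,
# the test could simply run torch.compile and compare against the eager
# FlopCounterMode reference instead of driving GraphLowering by hand.
counters.clear()
compiled = torch.compile(op)
compiled(*example_inputs, **kwargs)
scheduler_flops = counters["inductor"]["flops"]  # assumed counter name
assert scheduler_flops == reference_flops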

xmfan (Member) left a comment

Thanks for the extension!

exclamaforte requested a review from eellison on May 7, 2025 08:39
exclamaforte force-pushed the exclamaforte/scheduler-flops-refactor branch from 41d36ba to 76ffe99 on May 7, 2025 20:51
exclamaforte force-pushed the exclamaforte/scheduler-flops-refactor branch from 76ffe99 to ff251ba on May 8, 2025 00:33
eellison (Contributor) left a comment

nice!!

exclamaforte (Contributor, Author) commented

@pytorchbot merge

pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on May 9, 2025
pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

laithsakka (Contributor) commented May 10, 2025

Note: this PR adds the following ~10% regression:
{
    "mm_loop_inductor_dynamic_gpu": 9.9632850491536,
    "mm_loop_inductor_gpu": -5.0718946821206
}
cc @eellison

[Screenshot attached: 2025-05-09, 10:20 PM]

eellison (Contributor) commented

Can we avoid invoking estimate_runtime() when it's not needed? Also, can we cache the flops estimation for a particular op and input shapes?
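
One way the caching idea could look, as a sketch (the key layout and the placeholder estimator below are assumptions for illustration, not what the PR or Inductor actually does):

import functools

@functools.lru_cache(maxsize=None)
def cached_flop_estimate(op_name: str, input_shapes: tuple) -> int:
    # Placeholder estimator: real code would dispatch to per-op FLOP formulas.
    if op_name == "aten.mm":
        (m, k), (_, n) = input_shapes
        return 2 * m * k * n
    return 0

# Repeated calls with the same op and shapes hit the cache instead of recounting.
flops = cached_flop_estimate("aten.mm", ((1024, 512), (512, 256)))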

exclamaforte (Contributor, Author) commented

@eellison yeah, I think get_estimated_runtime is called for what amounts to some logging code in most cases, which probably shouldn't be happening:
https://github.com/pytorch/pytorch/blob/main/torch/_inductor/compile_fx.py#L1383
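
A minimal sketch of the "don't call it when it isn't needed" idea (the guard and the helper name maybe_log_estimated_runtime are assumptions, not the actual compile_fx change):

import logging

log = logging.getLogger(__name__)

def maybe_log_estimated_runtime(scheduler) -> None:
    # Only pay for get_estimated_runtime() when the logging that consumes
    # it is actually enabled.
    if not log.isEnabledFor(logging.DEBUG):
        return
    total = sum(node.get_estimated_runtime() for node in scheduler.nodes)
    log.debug("estimated total runtime: %s", total)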
