Inductor logging + analysis of torch.profile by exclamaforte · Pull Request #149697 · pytorch/pytorch · GitHub

Inductor logging + analysis of torch.profile #149697


Open

exclamaforte wants to merge 1 commit into main from exclamaforte/log_mul

Conversation

@exclamaforte exclamaforte (Contributor) commented Mar 21, 2025

Prereqs:

  • #152708 (refactor of estimate_flops and get_estimated_runtime on scheduler nodes)

Features:

  1. Adds inductor's estimates of flops and bandwidth to the JSON trace events that perfetto uses.
  2. Only uses the tflops estimate from Triton when we don't have the info from the datasheet, because Triton's estimates are inaccurate (I have a backlog item to fix Triton flops estimation upstream). Adds a new DeviceInfo class and a new function get_device_tflops.
  3. New helpers countable_fx and count_flops_fx help compute the flops of an fx.Node.
  4. Extends Triton torch.profiler logging to DebugAutotuner.
  5. New script profile_analysis.py: --augment_trace adds perf estimates to any perfetto json trace, --analyze creates a summary table of these perf estimates, and --diff will compare two traces side by side:
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth % 
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541           
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022           
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074           
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253           
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562           
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014           
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926           
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157           
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846           
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301           
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008            
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523           
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864          
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592           
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555           
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175           
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928           
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714           
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896            
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644            
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923           
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114            
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924            
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854           
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492           
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997          
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183           
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836           
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834              
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088           
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724          
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646            
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704            
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038           
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007          
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305           
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694           
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863           
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275          
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014           
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0                              
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667           
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051          
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709          
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356           
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935          
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05          
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611            
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048            
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994          
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0                              

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Mar 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149697

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit 35daac3 with merge base a4459cd:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch from ced9053 to 09d34f9 Compare March 21, 2025 03:39
@exclamaforte exclamaforte changed the title WIP python logging of flops WIP inductor logging of flops and bw Mar 21, 2025
@exclamaforte exclamaforte changed the title WIP inductor logging of flops and bw inductor logging of flops and bw Mar 28, 2025
@exclamaforte exclamaforte changed the title inductor logging of flops and bw WIP inductor logging of flops and bw Mar 28, 2025
@exclamaforte exclamaforte changed the title WIP inductor logging of flops and bw Inductor logging of flops and bw Mar 28, 2025
@exclamaforte exclamaforte changed the title Inductor logging of flops and bw Inductor logging of flops and bw in torch.profile Mar 28, 2025
@exclamaforte exclamaforte changed the title Inductor logging of flops and bw in torch.profile Inductor logging of flops and bandwidth in torch.profile Mar 28, 2025
@exclamaforte exclamaforte changed the title Inductor logging of flops and bandwidth in torch.profile WIP Inductor logging of flops and bandwidth in torch.profile Mar 28, 2025
@exclamaforte exclamaforte added topic: improvements topic category and removed topic: new features topic category labels Mar 28, 2025
@exclamaforte exclamaforte changed the title WIP Inductor logging of flops and bandwidth in torch.profile Inductor logging of flops and bandwidth in torch.profile Mar 28, 2025
@shunting314
Contributor

Why is matmul/convolution missing from the diff table?

@eellison eellison (Contributor) left a comment

Looking good, a few comments.

General questions:

  • What is this enabled under? I see config.profile_bandwidth and config.benchmark_kernel?

  • I think not being on by default makes sense. But just to say it - if this is on by default, we need to be careful about interfering with caching by annotating the configs. But anyways, not a real issue.

  • Could be nice to express flops/memory bandwidth as expressions of symints, because today they could be misleading with different shapes.

  • Ultimately we are going to want to look at memory bandwidth and flops as a % of hardware limits, so it would be great to enable that as part of perfetto if possible.

For my own understanding - how does this get into the perfetto trace? And is there any way to augment it more outside of perfetto (like for calculating memory bw based on symint inputs)? cc @davidberard98 who is knowledgeable on profiler.

@@ -736,6 +737,48 @@ def get_buf_bytes(

return buf_byte_accesses

@cache_on_self
def estimate_flops(self) -> int | None:
Contributor

How do we want to handle symints? We would need to augment the caching autotuner with a sympy expression based on the inputs.

FlopCounter should work with sympy; see the colab here: https://colab.research.google.com/drive/1zjAisRrc8R6uixKsrs1DRm3lwz5MWN68#scrollTo=zbl9to6G02vQ

Contributor Author

@eellison So symints ended up breaking this feature when used in max autotune, because it would generate an inductor_meta like:

inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': '969E2A31A8F620FC2E1E37975F57012FF2CF1D2549D5BE17FB534BA8B4F5EF84', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'kernel_num_gb': 0.042842176, 'kernel_flop': 1536*s27*s38},

and would throw an error:

File "/tmp/tmpbx3d5iaf/4y/c4yldhzvilvh5h35ih35aud3c42bbzeuuvxspmufnlojgth7yxgy.py", line 20, in <module>
torch._inductor.exc.InductorError: NameError: name 's27' is not defined

So I added a line like:

        resolved_flops = V.graph.sizevars.size_hints((flops,))[0]
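
For illustration, a minimal sketch (not the PR's code) of what that resolution amounts to, using the symbolic flop count from the error above with made-up hint values:

```python
import sympy

# Hypothetical reproduction: a symbolic flop count like 1536*s27*s38 can't be
# written into the generated kernel file as-is, so it gets resolved to a
# concrete int using the size hints recorded for each symint.
s27, s38 = sympy.symbols("s27 s38")
kernel_flop = 1536 * s27 * s38
size_hints = {s27: 64, s38: 128}  # hypothetical hint values
resolved_flops = int(kernel_flop.subs(size_hints))  # 12582912
```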

@exclamaforte
Contributor Author

@eellison

  • What is this enabled under? I see config.profile_bandwidth and config.benchmark_kernel?

The inductor_meta annotations are triggered if either is set. They are only added to the profile when autograd_profiler._is_profiler_enabled is true, i.e. when the code runs inside a torch.profiler.profile block.

  • I think not being on by default makes sense. But just to say it - if this is on by default, we need to be careful about interfering with caching by annotating the configs. But anyways, not a real issue.

You mean adding to inductor_meta will prevent caching?

  • Could be nice to express flops/memory bandwidth as expressions of symints, because today they could be misleading with different shapes.

Good point, I'll take a look.

  • Ultimately we are going to want to look at memory bandwidth and flops as a % of hardware limits, so it would be great to enable that as part of perfetto if possible.

On it.

For my own understanding - how does this get into the perfetto trace? And is there any way to augment it more outside of perfetto (like for calculating memory bw based on symint inputs)? cc @davidberard98 who is knowledgeable on profiler.

                "kernel_bandwidth": self.inductor_meta.get("kernel_num_gb", None),
                "kernel_flops": self.inductor_meta.get("kernel_flops", None),

These two lines add it to the args field on a "slice" in the JSON profile. In perfetto, you can click on a "slice" and see all the values in args. We can add any string we want to these, so adding % of hardware or values in terms of symints is no problem.
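
For concreteness, here is a hypothetical sketch of one augmented slice in the chrome-trace JSON (all values invented; only the two args keys come from the lines above):

```python
# A chrome-trace "complete" event ("X"), which perfetto renders as a slice.
# The inductor estimates ride along in "args"; every value here is invented.
event = {
    "name": "triton_poi_fused_convolution_0",
    "ph": "X",        # complete event, i.e. a slice
    "ts": 1234567.0,  # start timestamp (us)
    "dur": 2773.0,    # duration (us)
    "args": {
        "kernel_flops": 0,             # from inductor_meta["kernel_flops"]
        "kernel_bandwidth": 0.012972,  # from inductor_meta["kernel_num_gb"]
    },
}
```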

@eellison
Contributor
eellison commented Apr 1, 2025

The symints are not needed for this PR! I think other things are more important, but it could be cool in the future.

@davidberard98 davidberard98 requested a review from sraikund16 April 1, 2025 18:43
@sraikund16
Contributor

Overall seems fine to me. Just wondering if we think we will want this on by default? Many users have been complaining about the size of the traces.

@exclamaforte exclamaforte requested a review from eellison April 11, 2025 11:09
@eellison
Contributor

Test failures?

@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch 2 times, most recently from 0d7e9f6 to f1ae257 Compare April 15, 2025 20:07
@exclamaforte exclamaforte added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Apr 16, 2025
@eellison eellison (Contributor) left a comment

Excited for this! Left some comments.

cc @sraikund16 can you review as well? At least the profiler-relevant parts.

Also, today, a user has to manually run the post-processing step. It would be great if there were an API to register a post-processing fn that gets called before the profile is saved, so this gets added automatically.

# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceInfo] = {
# Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
"NVIDIA H100": DeviceInfo(
Contributor

How do we distinguish between H100 SXM and H100 NVL?

Contributor Author

This is based on torch.cuda.get_device_name(), which is stored by the profiler and available at runtime too. I'm not sure how to distinguish them, even at runtime.
Some ideas:

>>> torch.cuda.get_device_properties()
_CudaDeviceProperties(name='NVIDIA H100', major=9, minor=0, total_memory=97272MB, multi_processor_count=132, uuid=6efc17fa-5b7e-0452-613b-df241e45f2b8, L2_cache_size=60MB)
>>> torch.cuda.mem_get_info()
(99949740032, 101997215744)
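
One hypothetical heuristic along these lines (not in the PR): disambiguate by total memory, since the device name alone is ambiguous. The thresholds are assumptions taken from public datasheets:

```python
import torch

def guess_h100_variant(device_index: int = 0) -> str:
    # Hypothetical disambiguation by memory size: H100 PCIe/SXM ship with
    # 80 GB, H100 NVL with roughly 94 GB per GPU. Thresholds are guesses.
    props = torch.cuda.get_device_properties(device_index)
    total_gb = props.total_memory / 2**30
    if total_gb > 90:
        return "NVIDIA H100 NVL"
    return "NVIDIA H100"  # PCIe vs. SXM still ambiguous by memory alone
```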

dram_gb=80,
),
# Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
"NVIDIA A100": DeviceInfo(
Contributor

Similarly: [screenshot of A100 datasheet variants]

Contributor Author

Yeah, I saw that. I'm not sure how we solve this in general for bandwidth (flops seem fine). On an 8x machine, for example, the interconnect bandwidth could be more important than DRAM bandwidth.

Comment on lines +25 to +36
"NVIDIA H100": DeviceInfo(
tops={
torch.float64: 9.7,
torch.float32: 19.5,
torch.bfloat16: 1979.0,
torch.float16: 1979.0,
torch.float8_e8m0fnu: 3958.0,
torch.float8_e8m0fnu: 3958.0,
torch.float8_e4m3fnuz: 3958.0,
Contributor

I know the fbcode servers are clock-rate limited, so the numbers will be off for those. I think this is actually somewhat important for getting accurate numbers.

cc @bertmaher - who did similar analysis here - https://fb.workplace.com/groups/420659799592399/posts/761265522198490/

How would you adjust for clock rate? Is something simple like current_clock_rate/default sufficient? I don't have a good sense of this.

Contributor Author

Interested in Bert's opinion too, but I would think that current_clock_rate/default would be fine, considering that most flops calculations are just clock rate * core count * flops per core.
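
A minimal sketch of that adjustment, assuming pynvml (nvidia-ml-py) is available; none of this is in the PR:

```python
import pynvml

def clock_adjusted_tflops(datasheet_tflops: float, device_index: int = 0) -> float:
    # Scale the datasheet peak by current_sm_clock / max_sm_clock - the simple
    # current_clock_rate/default adjustment discussed above.
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        current = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)
        rated = pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM)
        return datasheet_tflops * current / rated
    finally:
        pynvml.nvmlShutdown()
```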

Comment on lines +61 to +71
def conv_adapter(
    shapes: tuple[Any, ...], concrete: tuple[Any, ...]
) -> tuple[tuple[Any], dict[Any, Any]]:
    tmp = list(shapes)
    if len(tmp) == 4:
        transposed = False

    transposed = bool(tmp[6])
    tmp[6] = transposed

    kwargs = {}
Contributor

Do you have any more details on what the format is for the JSON trace? Ideally we would be able to do this automatically with a schema.

Contributor Author

The JSON trace args are just the input arg dimensions in a list (convenient). For convolution, if transposed is set, then the flop counter also needs "out_val", which at runtime contains the output tensor. I don't think it's strictly necessary - basically we could move the conv_out_dims function inside the flop counter, but I didn't want to mess with the flop counter in this PR.

Other than that, the len(tmp) == 4 case comes from the fact that sometimes the input args only had 4 dims; I'm not sure why. Maybe convolution vs. _convolution...

In summary, I think we can do it from just the schema, but it requires the flop counter to be smarter. I expected other ops to need more conversions, but it's really just convolution.

Contributor

Couldn't we create FakeTensors with the input shapes, then run FlopCounter with them?

Contributor Author

So I just implemented this, but I think the code ends up looking worse: when I call the size function directly, it knows to ignore all of the other inputs (like the SymInt groups and bool deterministic), whereas if I run the op, they get type-checked in C++.

Contributor

I still think the current impl is wrong: we should just instantiate the tensors, run FlopCounter mode, and run the op. You don't need to run the size function directly.
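
For reference, a minimal sketch of the suggested approach (shapes are hypothetical):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils.flop_counter import FlopCounterMode

def conv_flops_from_shapes(x_shape, w_shape) -> int:
    # Build FakeTensors from the recorded shapes and let FlopCounterMode do
    # the counting by dispatching the op - no size logic called by hand.
    with FakeTensorMode():
        x = torch.empty(x_shape)
        w = torch.empty(w_shape)
        with FlopCounterMode(display=False) as counter:
            torch.nn.functional.conv2d(x, w)
    return counter.get_total_flops()

# Hypothetical usage: a 1x3x224x224 input against a 64x3x7x7 weight.
# conv_flops_from_shapes((1, 3, 224, 224), (64, 3, 7, 7))
```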

@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch 4 times, most recently from 4ec6ab6 to 3948b47 Compare April 30, 2025 22:34
@exclamaforte exclamaforte changed the title Inductor logging of flops and bandwidth in torch.profile Inductor logging + analysis of torch.profile May 1, 2025
@exclamaforte exclamaforte requested a review from eellison May 1, 2025 23:50
@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch from c5df71f to 57716b8 Compare May 2, 2025 18:04
@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch 2 times, most recently from da97b18 to 1d36147 Compare May 2, 2025 19:45
pytorchmergebot pushed a commit that referenced this pull request May 9, 2025
This refactors `estimate_flops` and `get_estimated_runtime` on scheduler nodes:
1. New function on BaseSchedulerNode: `estimate_flops`. Works with all types of IR nodes now, not just `ExternalKernels`.
2. Extends `get_estimated_runtime` to work with non-`ExternalKernels`.

Prelude to: #149697

Testing:
New unit tests cover functionality.

Pull Request resolved: #152708
Approved by: https://github.com/xmfan, https://github.com/eellison
@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch from 1d36147 to a2c29a4 Compare May 12, 2025 19:56
@exclamaforte exclamaforte force-pushed the exclamaforte/log_mul branch from a2c29a4 to 35daac3 Compare May 14, 2025 00:11
Comment on lines +3006 to +3010
    and not torch._inductor.utils.use_triton_template(
        FixedLayout(torch.device("cuda"), torch.float16, [400, 800])
    ),
    "Solo triton backend not possible",
)
Contributor

Ugh, this is a bit indirect. Could we just do the same checks we do in test_max_autotune that check whether we can run it? See is_big_gpu. I don't like reaching into implementation details when it doesn't add any benefit.

Comment on lines +3015 to +3017
in1 = torch.randn((400, 600), device="cuda", dtype=torch.float16)
in2 = torch.randn((600, 800), device="cuda", dtype=torch.float16)

Contributor

nit: make the tensors aligned so we won't do padding
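
For example (hypothetical shapes; the multiple-of-16 alignment is an assumption), dimensions like these would avoid the padding path:

```python
# Hypothetical aligned shapes: every dim is a multiple of 16, so inductor's
# mm-padding pass should leave the operands alone (alignment is an assumption).
in1 = torch.randn((384, 640), device="cuda", dtype=torch.float16)
in2 = torch.randn((640, 768), device="cuda", dtype=torch.float16)
```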

Comment on lines +3053 to +3054
prof1.export_chrome_trace(trace1)
prof2.export_chrome_trace(trace2)
Contributor

What does this test?

Comment on lines +24 to +26
# Indexing is based on `torch.cuda.get_device_name()`
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceInfo] = {
Contributor

Can we file an issue for this as a follow-up? It's not great that we are not doing this programmatically.

node_type = "in_out"
else:
node_type = "default"
counted_flops = count_flops_fx(node)
Contributor

Can we make count_flops_fx a little bit smarter so that it does not try to run this unless the node has a registration in flop_counter? Currently we always instantiate the tensors and FakeTensorMode, and run?
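
A minimal sketch of such a guard (the PR's actual countable_fx may differ): consult the flop registry before building any FakeTensors:

```python
from torch.utils.flop_counter import flop_registry

def has_flop_formula(node) -> bool:
    # Only worth instantiating FakeTensors and dispatching the op if the
    # node's target (or its overload packet) has a registered flop formula.
    target = node.target
    packet = getattr(target, "_overloadpacket", target)
    return packet in flop_registry
```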

-@register_flop_formula([aten.convolution, aten._convolution])
-def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
+@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution])
+def conv_flop(x_shape, w_shape, bias, stride, padding, dilation, transposed, *args, out_shape=None, **kwargs) -> int:
Contributor

nit: any reason why we needed to change the arg names?

Comment on lines +65 to +69
if len(tmp) == 4:
    transposed = False

transposed = bool(tmp[6])
tmp[6] = transposed
Contributor
If len(tmp) == 4, why are we indexing into tmp[6]?


return flop_function(*args, **kwargs)


def _estimate_gb(event: dict[str, Any]) -> float:
Contributor

When is this running? For the Triton kernels, we should already have this, right?

I mentioned this in a review internally - if we're going to estimate GB for mms, we should use the mm data-access formula.

Generally - we have the GB from inductor, right? Also, inductor will dedupe inputs with the same buffer, so I don't know that this comment is accurate.
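
For reference, a sketch of the mm data-access formula mentioned here, assuming each operand is read or written exactly once:

```python
import torch

def mm_gb(m: int, k: int, n: int, dtype: torch.dtype = torch.float16) -> float:
    # Ideal memory traffic for an (m, k) @ (k, n) matmul: read both inputs
    # once, write the output once. Real traffic is higher when tiles spill.
    elems = m * k + k * n + m * n
    return elems * dtype.itemsize / 1e9

# e.g. mm_gb(400, 600, 800) == (240000 + 480000 + 320000) * 2 / 1e9 ≈ 0.00208
```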

Comment on lines +160 to +161
if "Input Dims" not in event["args"] or "Concrete Inputs" not in event["args"]:
breakpoint()
Contributor

breakpoint() in code - is this being tested?

Labels: ciflow/inductor, module: inductor, release notes: inductor, suppress-bc-linter, topic: improvements