Inductor logging + analysis of torch.profile · pytorch/pytorch@35daac3 · GitHub

Commit 35daac3

Inductor logging + analysis of torch.profile
1 parent a4459cd commit 35daac3

File tree: 16 files changed, +1659 −55 lines changed


test/inductor/test_analysis.py

Lines changed: 704 additions & 0 deletions
Large diffs are not rendered by default.

test/profiler/test_profiler.py

Lines changed: 59 additions & 0 deletions
@@ -27,6 +27,7 @@
 import torch.optim
 import torch.utils.data
 from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
+from torch._inductor.ir import FixedLayout
 from torch.autograd.profiler import KinetoStepTracker, profile as _profile
 from torch.autograd.profiler_legacy import profile as _profile_legacy
 from torch.profiler import (
@@ -2998,6 +2999,64 @@ def validate_json(prof):
         assert "Overload Name" in key_averages.table()
         validate_json(prof)

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+    # this checks that we can use only a Triton backend for max autotune
+    @unittest.skipIf(
+        torch.cuda.is_available()
+        and not torch._inductor.utils.use_triton_template(
+            FixedLayout(torch.device("cuda"), torch.float16, [400, 800])
+        ),
+        "Solo triton backend not possible",
+    )
+    def test_profiler_debug_autotuner(self):
+        """
+        This test makes sure that profiling events will be present when the kernel is run using the DebugAutotuner.
+        """
+        in1 = torch.randn((400, 600), device="cuda", dtype=torch.float16)
+        in2 = torch.randn((600, 800), device="cuda", dtype=torch.float16)
+
+        def mm():
+            return torch.mm(in1, in2)
+
+        pb_mm = torch.compile(
+            mm,
+            options={
+                "benchmark_kernel": True,
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+                "profile_bandwidth": True,
+            },
+        )
+        comp_mm = torch.compile(
+            mm,
+            options={
+                "benchmark_kernel": True,
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+        )
+
+        with profile() as prof1:
+            pb_mm()
+        with profile() as prof2:
+            comp_mm()
+
+        def names(prof):
+            return {
+                ev.name
+                for ev in prof.events()
+                if "mm" in ev.name or "triton" in ev.name
+            }
+
+        trace1 = "/tmp/trace1_pb.json"
+        trace2 = "/tmp/trace2_nopb.json"
+        prof1.export_chrome_trace(trace1)
+        prof2.export_chrome_trace(trace2)
+
+        n1 = names(prof1)
+        n2 = names(prof2)
+        self.assertEqual(n1, n2)
+

 if __name__ == "__main__":
     run_tests()
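
For context on what the new test exercises: per its docstring, setting the Inductor option profile_bandwidth=True (together with max_autotune and a Triton-only GEMM backend) runs the autotuned kernel via the DebugAutotuner, and the test asserts that the same mm/triton kernel names still show up in the profiler events. A minimal standalone sketch of that pattern follows; it assumes a CUDA device with Triton available, and the shapes are illustrative only.

import torch
from torch.profiler import profile

a = torch.randn((400, 600), device="cuda", dtype=torch.float16)
b = torch.randn((600, 800), device="cuda", dtype=torch.float16)

def mm():
    return torch.mm(a, b)

# Same Inductor options as the test's pb_mm: profile_bandwidth enables the DebugAutotuner path.
compiled = torch.compile(
    mm,
    options={
        "benchmark_kernel": True,
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
        "profile_bandwidth": True,
    },
)

with profile() as prof:
    compiled()

# Kernel events should still be visible to the profiler and exportable as a chrome trace.
print({ev.name for ev in prof.events() if "mm" in ev.name or "triton" in ev.name})
prof.export_chrome_trace("/tmp/trace_pb.json")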

test/test_flop_counter.py

Lines changed: 1 addition & 0 deletions
@@ -854,5 +854,6 @@ def test_scaled_mm(self):

         self.assertExpectedInline(get_total_flops(mode), """860160""")

+
 if __name__ == "__main__":
     run_tests()

torch/_inductor/analysis/README.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# `torch._inductor.analysis`
+Contains scripts for inductor performance analysis.

torch/_inductor/analysis/__init__.py

Whitespace-only changes.

torch/_inductor/analysis/device_info.py

Lines changed: 150 additions & 0 deletions
from dataclasses import dataclass
from logging import info
from typing import Optional

import torch


@dataclass(frozen=True)
class DeviceInfo:
    """
    Theoretical numbers from the datasheet. If two numbers are given (Tensor/Matrix Core vs. not),
    the higher number is reported. Sparsity is not considered.

    Bandwidth numbers are tricky, because there are platform differences that may not show up in
    the profiler trace.
    """

    tops: dict[torch.dtype, float]
    dram_bw_gbs: float
    dram_gb: float


# Indexing is based on `torch.cuda.get_device_name()`
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceInfo] = {
    # Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    "NVIDIA H100": DeviceInfo(
        tops={
            torch.float64: 9.7,
            torch.float32: 19.5,
            torch.bfloat16: 1979.0,
            torch.float16: 1979.0,
            torch.float8_e8m0fnu: 3958.0,
            torch.float8_e4m3fnuz: 3958.0,
            torch.float8_e5m2: 3958.0,
            torch.float8_e5m2fnuz: 3958.0,
            torch.int8: 3958.0,
        },
        dram_bw_gbs=3350,
        dram_gb=80,
    ),
    # Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    "NVIDIA A100": DeviceInfo(
        tops={
            torch.float64: 19.5,
            torch.float32: 19.5,
            torch.bfloat16: 312.5,
            torch.float16: 312.5,
            # Not in datasheet: float8
            torch.int8: 624.0,
        },
        dram_bw_gbs=2039.0,
        dram_gb=80.0,
    ),
    # Source: https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
    "NVIDIA L4": DeviceInfo(
        tops={
            # This is a guess, not in datasheet
            torch.float64: 15.1,
            torch.float32: 30.3,
            torch.bfloat16: 242.0,
            torch.float16: 242.0,
            torch.float8_e8m0fnu: 485.0,
            torch.float8_e4m3fnuz: 485.0,
            torch.float8_e5m2: 485.0,
            torch.float8_e5m2fnuz: 485.0,
            torch.int8: 485.0,
        },
        dram_bw_gbs=3350,
        dram_gb=24,
    ),
    # Source: https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
    "AMD MI300A": DeviceInfo(
        tops={
            torch.float64: 122.6,
            torch.float32: 122.6,
            # torch.tf32: 490.3,
            torch.bfloat16: 980.6,
            torch.float16: 980.6,
            torch.float8_e8m0fnu: 1961.2,
            torch.float8_e4m3fnuz: 1961.2,
            torch.float8_e5m2: 1961.2,
            torch.float8_e5m2fnuz: 1961.2,
            torch.int8: 1961.2,
        },
        dram_bw_gbs=5300.0,
        dram_gb=128.0,
    ),
    # Source: https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
    "AMD MI300X": DeviceInfo(
        tops={
            torch.float64: 163.4,
            torch.float32: 163.4,
            torch.bfloat16: 1307.4,
            torch.float16: 1307.4,
            torch.float8_e8m0fnu: 2614.9,
            torch.float8_e4m3fnuz: 2614.9,
            torch.float8_e5m2: 2614.9,
            torch.float8_e5m2fnuz: 2614.9,
            torch.int8: 2614.9,
        },
        dram_bw_gbs=5300.0,
        dram_gb=192.0,
    ),
}


def lookup_device_info(name: str) -> Optional[DeviceInfo]:
    """
    Problem: when diffing profiles between AMD and NVIDIA, we don't have access to the device information
    of the other one. Also, since the analysis is static, we should be able to do it on another device,
    unrelated to the recorded device. Therefore, _device_mapping statically contains the information for
    many devices. If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.

    name (str): Name of the device to look up. Should map onto torch.cuda.get_device_name().
    """
    if name not in _device_mapping:
        return None
    return _device_mapping[name]


def datasheet_tops(dtype: torch.dtype) -> Optional[float]:
    """
    Get the theoretical TOPS of the current device for a given dtype. Returns None if the device
    or dtype is not in the datasheet table above.
    """
    name: Optional[str] = torch.cuda.get_device_name()
    if name is None:
        info("No device found, returning None")
        return None
    device_info = lookup_device_info(name)
    if device_info is None:
        log_str = f"Device {name} not in datasheet, returning None"
        info(log_str)
        return None
    if dtype not in device_info.tops:
        log_str = (
            f"Device {name} does not have a datasheet entry for {dtype}, returning None"
        )
        info(log_str)
        return None
    return device_info.tops[dtype]
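
Not part of the diff above, but a quick illustration of how these datasheet helpers might be consumed: a roofline-style utilization estimate for a matmul. This is a sketch under two assumptions: that the new module is importable as torch._inductor.analysis.device_info, and mm_utilization is a hypothetical helper written here for illustration; the peak numbers come straight from _device_mapping.

from typing import Optional

import torch

from torch._inductor.analysis.device_info import datasheet_tops, lookup_device_info


def mm_utilization(m: int, k: int, n: int, seconds: float, dtype: torch.dtype) -> Optional[float]:
    # Fraction of the datasheet peak achieved by an (m, k) @ (k, n) matmul that took `seconds`.
    peak_tops = datasheet_tops(dtype)  # None if the device or dtype is not in the table
    if peak_tops is None or seconds <= 0:
        return None
    ops = 2 * m * k * n  # count each multiply-accumulate as 2 ops
    achieved_tops = ops / seconds / 1e12
    return achieved_tops / peak_tops


if torch.cuda.is_available():
    # e.g. a 4096x4096x4096 fp16 matmul measured at 1.5 ms in a profiler trace
    print(mm_utilization(4096, 4096, 4096, 1.5e-3, torch.float16))
    print(lookup_device_info(torch.cuda.get_device_name()))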
