Create device_info to support datasheet estimations · pytorch/pytorch@f1ae257 · GitHub

Commit f1ae257

Create device_info to support datasheet estimations
1 parent aead262 commit f1ae257

File tree

torch/_inductor/analysis/device_info.py
torch/_inductor/analysis/profile_analysis.py
torch/_inductor/utils.py

3 files changed: +128 −55 lines changed

torch/_inductor/analysis/device_info.py

+112

@@ -0,0 +1,112 @@
import torch
from dataclasses import dataclass


@dataclass(frozen=True)
class DeviceInfo:
    """
    Theoretical numbers from the data sheet. If the sheet gives two numbers (Tensor/Matrix Core vs. not),
    the higher number is reported. Sparsity is not considered.

    Bandwidth numbers are tricky, because there are platform differences that may not show up in the
    profiler trace.
    """

    tops: dict[torch.dtype, float]
    dram_bw_tbs: float
    dram_gb: float


# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceInfo] = {
    # Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    "NVIDIA H100": DeviceInfo(
        tops={
            torch.float64: 9.7,
            torch.float32: 19.5,
            torch.bfloat16: 1979.0,
            torch.float16: 1979.0,
            torch.float8_e8m0fnu: 3958.0,
            torch.float8_e4m3fn: 3958.0,
            torch.float8_e4m3fnuz: 3958.0,
            torch.float8_e5m2: 3958.0,
            torch.float8_e5m2fnuz: 3958.0,
            torch.int8: 3958.0,
        },
        dram_bw_tbs=3350,
        dram_gb=80,
    ),
    # Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    "NVIDIA A100": DeviceInfo(
        tops={
            torch.float64: 19.5,
            torch.float32: 19.5,
            torch.bfloat16: 312.5,
            torch.float16: 312.5,
            # Not in datasheet: float8
            torch.int8: 624.0,
        },
        dram_bw_tbs=2039.0,
        dram_gb=80.0,
    ),
    # Source: https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
    "AMD MI300A": DeviceInfo(
        tops={
            torch.float64: 122.6,
            torch.float32: 122.6,
            # torch.tf32: 490.3,
            torch.bfloat16: 980.6,
            torch.float16: 980.6,
            torch.float8_e8m0fnu: 1961.2,
            torch.float8_e4m3fn: 1961.2,
            torch.float8_e4m3fnuz: 1961.2,
            torch.float8_e5m2: 1961.2,
            torch.float8_e5m2fnuz: 1961.2,
            torch.int8: 1961.2,
        },
        dram_bw_tbs=5300.0,
        dram_gb=128.0,
    ),
    # Source: https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
    "AMD MI300X": DeviceInfo(
        tops={
            torch.float64: 163.4,
            torch.float32: 163.4,
            torch.bfloat16: 1307.4,
            torch.float16: 1307.4,
            torch.float8_e8m0fnu: 2614.9,
            torch.float8_e4m3fn: 2614.9,
            torch.float8_e4m3fnuz: 2614.9,
            torch.float8_e5m2: 2614.9,
            torch.float8_e5m2fnuz: 2614.9,
            torch.int8: 2614.9,
        },
        dram_bw_tbs=5300.0,
        dram_gb=192.0,
    ),
}


def lookup_device_info(name: str) -> "DeviceInfo":
    """
    Problem: when diffing profiles between AMD and NVIDIA, we don't have access to the device information
    of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
    to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
    If one is missing, please add an entry to _device_mapping from the device's datasheet.
    name (str): name of the device to look up. Should map onto torch.cuda.get_device_name().
    """
    if name not in _device_mapping:
        raise RuntimeError(
            f"Unsupported device in profile: {name}, please consider contributing to _device_mapping."
        )
    return _device_mapping[name]


def datasheet_tops(dtype: torch.dtype) -> float:
    """
    Get the theoretical TOPS of the current device for a given dtype. This can throw an exception if the device
    is not in the datasheet list above.
    """
    name = torch.cuda.get_device_name()
    device_info = lookup_device_info(name)
    return device_info.tops[dtype]
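
For reference, a minimal usage sketch of the new module (the device string and printed values are only illustrative; lookup is keyed on torch.cuda.get_device_name()):

    import torch
    from torch._inductor.analysis.device_info import datasheet_tops, lookup_device_info

    # Static lookup by name, e.g. when analyzing a trace recorded on a different machine.
    info = lookup_device_info("NVIDIA H100")
    print(info.tops[torch.bfloat16])       # 1979.0 peak TOPS from the datasheet
    print(info.dram_bw_tbs, info.dram_gb)  # peak DRAM bandwidth and capacity

    # Query the currently visible CUDA device instead; this raises a RuntimeError
    # if the device name is not present in _device_mapping.
    bf16_tops = datasheet_tops(torch.bfloat16)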

torch/_inductor/analysis/profile_analysis.py

+4 −55

@@ -8,13 +8,8 @@
 from typing import Any, Optional, Union

 import torch
-from torch._inductor.utils import (
-    flatten,
-    get_device_tflops,
-    get_gpu_dram_gbps,
-    tabulate_2d,
-    zip_dicts,
-)
+from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
+from torch._inductor.utils import flatten, tabulate_2d, zip_dicts
 from torch.autograd import DeviceType
 from torch.utils._ordered_set import OrderedSet
 from torch.utils.flop_counter import flop_registry

@@ -212,50 +207,6 @@ def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]:
     }


-@dataclass(frozen=True)
-class DeviceInfo:
-    tflops: dict[torch.dtype, float]
-    dram_bw_gbs: float
-
-    @staticmethod
-    def get_device_info() -> tuple[dict[torch.dtype, int], float]:
-        """
-        This is the info that populates DeviceInfo, but it needs to be run on each device separately.
-        For new hardware, run this function and then add the information to `_device_mapping`
-        """
-        # TODO support int dtypes
-        floats = [torch.float, torch.bfloat16, torch.float16]
-        return {
-            dtype: get_device_tflops(dtype) for dtype in floats
-        }, get_gpu_dram_gbps()
-
-
-_device_mapping: dict[str, DeviceInfo] = {
-    "NVIDIA H100": DeviceInfo(
-        tflops={
-            torch.float32: 0.033454080000000004,
-            torch.bfloat16: 0.5352652800000001,
-            torch.float16: 0.5352652800000001,
-        },
-        dram_bw_gbs=2446.848,
-    )
-}
-
-
-def lookup_device_info(name: str) -> "DeviceInfo":
-    """
-    problem: when diffing profiles between amd and nvidia, we don't have access to the device information
-    of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
-    to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
-    If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
-    """
-    if name not in _device_mapping:
-        raise RuntimeError(
-            f"Unsupported device in profile: {name}, consider contributing to _device_mapping."
-        )
-    return _device_mapping[name]
-
-
 @dataclass(frozen=True)
 class KernelStats:
     flops: int

@@ -386,9 +337,7 @@ def _compute_stats(self) -> None:
             achieved_flops = 0
         else:
             dtype = self.convert_dtype(event)
-            if event["name"].startswith("sm80_xmma_gemm_f32f32"):
-                breakpoint()
-            achieved_flops = 100 * op_flops / (1e12 * dev.info.tflops[dtype])
+            achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype])
     else:
         op_flops = 0
         achieved_flops = 0

@@ -397,7 +346,7 @@ def _compute_stats(self) -> None:
         assert dur != 0
         # 1000ms/s * gb / ms = gb/s
         op_gbps = 1e3 * event["args"]["kernel_num_gb"] / dur
-        achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs
+        achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_tbs
     else:
         op_gbps = 0
         achieved_bandwidth = 0
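
To make the change above concrete, a small numeric sketch of the utilization math (the kernel figures are invented; op_flops is assumed to already be the kernel's achieved FLOP/s and kernel_num_gb its memory traffic in GB, as computed earlier in _compute_stats):

    import torch
    from torch._inductor.analysis.device_info import lookup_device_info

    info = lookup_device_info("NVIDIA H100")

    op_flops = 5.0e14                                                     # hypothetical achieved FLOP/s
    achieved_flops = 100 * op_flops / (1e12 * info.tops[torch.bfloat16])  # ~25% of the 1979 TOPS peak

    dur = 0.5                                                # kernel duration in ms
    kernel_num_gb = 1.0                                      # GB moved by the kernel
    op_gbps = 1e3 * kernel_num_gb / dur                      # 1000 ms/s * GB / ms = 2000 GB/s
    achieved_bandwidth = 100 * op_gbps / info.dram_bw_tbs    # ~60% of the 3350 GB/s datasheet figure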

torch/_inductor/utils.py

+12

@@ -56,6 +56,7 @@
 import sympy

 import torch
+from torch._inductor.analysis.device_info import datasheet_tops
 from torch._inductor.runtime.hints import DeviceProperties
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._ordered_set import OrderedSet

@@ -1895,6 +1896,16 @@ def get_backend_num_stages() -> int:

 @functools.lru_cache(None)
 def get_device_tflops(dtype: torch.dtype) -> int:
+    """
+    We don't want to throw errors in this function. First check to see if the device is in device_info.py,
+    then fall back to the inaccurate Triton estimation.
+    """
+    try:
+        return datasheet_tops(dtype)
+    except Exception:
+        # Not all devices are supported; fall back to Triton's theoretical estimate.
+        pass
+
     from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

     assert dtype in (torch.float16, torch.bfloat16, torch.float32)

@@ -2846,6 +2857,7 @@ def get_ld_library_path() -> str:

     return path

+
 def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
     widths = [len(str(e)) for e in headers]
     for row in elements:
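
The net effect on get_device_tflops is sketched below (illustrative only; the Triton-based path is the pre-existing code and is unchanged):

    import torch
    from torch._inductor.utils import get_device_tflops

    # On a device listed in device_info.py (e.g. an H100), this now returns the
    # datasheet number for the dtype; on any other device datasheet_tops() raises,
    # the exception is swallowed, and the Triton-based estimate is used instead.
    tflops = get_device_tflops(torch.bfloat16)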

0 commit comments
