fix bw estimation · pytorch/pytorch@0d7e9f6 · GitHub
Commit 0d7e9f6

committed
fix bw estimation
1 parent 169c156 commit 0d7e9f6

File tree

2 files changed: +45 -7 lines changed


test/inductor/test_analysis.py

Lines changed: 41 additions & 3 deletions
@@ -9,7 +9,11 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.flop_counter
-from torch._inductor.analysis.profile_analysis import _augment_trace_helper, main
+from torch._inductor.analysis.profile_analysis import (
+    _augment_trace_helper,
+    JsonProfile,
+    main,
+)
 from torch._inductor.utils import flatten, tabulate_2d, zip_dicts
 from torch.testing._internal.common_device_type import (
     dtypes,
@@ -297,10 +301,44 @@ def test_augment_trace_helper_args(self, device, dtype):
             om()
         trace1, trace2 = trace_files()
         p.export_chrome_trace(trace1)
-        # patch('sys.stdout', new_callable=StringIO) as mock_stdout,
         with patch("sys.argv", [*prefix, "--augment_trace", trace1, trace2]):
             main()
-        # self.assertEqual(mock_stdout.getvalue(), "")
+        profile = JsonProfile(trace2, 1, "foo")
+        rep = profile.report()
+        # If these fail, just update them. They could change over time
+        if device != "cpu":
+            self.assertTrue(len(rep.split("\n")) > 4)
+            self.assertIn("Kernel Name", rep)
+            self.assertIn("Kernel Count", rep)
+            self.assertIn("FLOPS", rep)
+            self.assertIn("bw gbps", rep)
+            self.assertIn("Dur (ms)", rep)
+            self.assertIn("Achieved", rep)
+            self.assertIn("|", rep)
+            self.assertIn("-----", rep)
+
+        # TODO we need a robust way of checking this report.
+        # In the mean time, make sure that no column is empty.
+        # TODO check to make sure all % values are less than 100%
+        tables = profile._create_tables(profile._devices)
+        for tab in tables.values():
+            header, rows = tab
+            ncols = len(header) - 1
+            seen = [False] * ncols
+            for row in rows.values():
+                for i in range(len(row)):
+                    try:
+                        val = float(row[i])
+                    except Exception:
+                        continue
+                    seen[i] = seen[i] or (val != 0.0)
+
+            if device != "cpu":
+                for i in range(len(seen)):
+                    self.assertTrue(
+                        seen[i],
+                        f"column values from column {i + 1} with header '{header[i + 1]}' are all zero",
+                    )
 
     @dtypes(torch.float, torch.double)
     def test_augment_trace_against_flop_counter(self, device, dtype):
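
The added assertions boil down to one invariant: after the trace is augmented, every numeric column in every report table should contain at least one non-zero value. Below is a standalone sketch of that check with a made-up (header, rows) table shaped the way the test treats the output of profile._create_tables; the real structure and values may differ.

# Standalone sketch of the "no column is all zeros" check added above.
# The table shape (a header list plus a dict of rows) and the example values
# are made up for illustration; only the checking logic mirrors the test.
header = ["Kernel Name", "Kernel Count", "FLOPS", "bw gbps", "Dur (ms)"]
rows = {
    "triton_mm_1": ["2", "1.2e12", "850.0", "0.41"],
    "triton_red_2": ["4", "0.0", "912.3", "0.10"],
}

ncols = len(header) - 1  # the first header entry labels the row key itself
seen = [False] * ncols
for row in rows.values():
    for i, cell in enumerate(row):
        try:
            val = float(cell)
        except ValueError:
            continue  # non-numeric cells are simply skipped
        seen[i] = seen[i] or (val != 0.0)

for i, ok in enumerate(seen):
    assert ok, f"column '{header[i + 1]}' is all zeros"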

torch/_inductor/analysis/profile_analysis.py

Lines changed: 4 additions & 4 deletions
@@ -160,7 +160,7 @@ def _estimate_gb(event: dict[str, Any]) -> float:
     This estimate isn't the best because it doesn't know if two input buffers are the same buffer, leading to an
     overestimate of the real achieved bandwidth.
     """
-    if "Input Type" not in event["args"] or "Input Dims" not in event["args"]:
+    if "Input type" not in event["args"] or "Input Dims" not in event["args"]:
         return 0
     sizes_and_types = zip(event["args"]["Input Dims"], event["args"]["Input type"])
     bw = 0
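
This hunk only fixes the key name ("Input Type" vs. "Input type") so the early return checks the same key that the zip on the next line actually reads. For context, here is a hedged sketch of that kind of traffic estimate: sum element count times element size over the recorded inputs. The dtype-size table and the function itself are illustrative, not the module's real implementation, and, as the docstring notes, aliased inputs get double counted.

# Illustrative sketch of estimating GB moved from "Input Dims" / "Input type".
# The dtype-size table is an assumption for the sketch, not PyTorch's lookup.
from math import prod
from typing import Any

_DTYPE_BYTES = {"float": 4, "double": 8, "c10::Half": 2, "c10::BFloat16": 2}

def estimate_gb(event: dict[str, Any]) -> float:
    args = event.get("args", {})
    if "Input type" not in args or "Input Dims" not in args:
        return 0.0
    total_bytes = 0
    for dims, type_name in zip(args["Input Dims"], args["Input type"]):
        if not dims:
            continue  # scalars / missing shapes contribute nothing in this sketch
        total_bytes += prod(dims) * _DTYPE_BYTES.get(type_name, 4)
    return total_bytes / 1e9

example_event = {
    "args": {
        "Input Dims": [[1024, 1024], [1024, 1024]],
        "Input type": ["float", "float"],
    }
}
print(estimate_gb(example_event))  # ~0.0084 GB for two fp32 1024x1024 inputs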
@@ -386,7 +386,7 @@ def _compute_stats(self) -> None:
                     achieved_flops = 0
                 else:
                     dtype = self.convert_dtype(event)
-                    achieved_flops = op_flops / (1e12 * dev.info.tflops[dtype])
+                    achieved_flops = 100 * op_flops / (1e12 * dev.info.tflops[dtype])
             else:
                 op_flops = 0
                 achieved_flops = 0
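
The only change here is the factor of 100: the ratio of achieved to peak FLOPS is now reported as a percentage rather than a 0-1 fraction, matching the "Achieved" columns the new test looks for. A small worked example with illustrative numbers, assuming op_flops is the kernel's achieved FLOP/s and the device peak is given in TFLOPS:

op_flops = 2.5e12      # illustrative: kernel achieves 2.5 TFLOP/s
peak_tflops = 10.0     # illustrative: device peak of 10 TFLOPS
achieved_flops = 100 * op_flops / (1e12 * peak_tflops)
print(achieved_flops)  # 25.0, i.e. 25% of peak rather than 0.25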
@@ -395,7 +395,7 @@ def _compute_stats(self) -> None:
                 assert dur != 0
                 # 1000ms/s * gb / ms = gb/s
                 op_gbps = 1e3 * event["args"]["kernel_num_gb"] / dur
-                achieved_bandwidth = op_gbps / dev.info.dram_bw_gbs
+                achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs
             else:
                 op_gbps = 0
                 achieved_bandwidth = 0
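
The same fix applies to bandwidth: op_gbps already follows the unit comment (1000 ms/s * GB / ms = GB/s), and the utilization is now scaled to a percentage of the device's DRAM bandwidth. With illustrative numbers:

kernel_num_gb = 0.008    # illustrative: kernel moves ~8 MB
dur = 0.02               # illustrative: kernel duration in ms
dram_bw_gbs = 1600.0     # illustrative: device DRAM bandwidth in GB/s

op_gbps = 1e3 * kernel_num_gb / dur               # 400.0 GB/s achieved
achieved_bandwidth = 100 * op_gbps / dram_bw_gbs  # 25.0, i.e. 25% of peak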
@@ -534,7 +534,7 @@ def dump(self, out: str) -> None:
 
 def parse_profile_event_list(
     benchmark_name: str,
-    event_list: torch.autograd.profiler_util.EventList | dict[str, Any],
+    event_list: Union[torch.autograd.profiler_util.EventList, dict[str, Any]],
     wall_time_ms: float,
     nruns: int,
     device_name: str,
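
The annotation change swaps PEP 604 syntax for typing.Union. The commit does not say why, but a plausible reason is that `X | Y` in an annotation evaluated at definition time requires Python 3.10+ (absent `from __future__ import annotations`), while Union[...] also works on older interpreters. A minimal sketch, unrelated to the real function body:

from typing import Any, Union

def parse_event_list(event_list: Union[list, dict[str, Any]]) -> int:
    # Stand-in body; only the annotation style is the point of this sketch.
    return len(event_list)

print(parse_event_list({"ev": 1}))  # 1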

0 commit comments