|
9 | 9 | import torch
|
10 | 10 | import torch.nn.functional as F
|
11 | 11 | import torch.utils.flop_counter
|
12 |
| -from torch._inductor.analysis.profile_analysis import _augment_trace_helper, main |
| 12 | +from torch._inductor.analysis.profile_analysis import ( |
| 13 | + _augment_trace_helper, |
| 14 | + JsonProfile, |
| 15 | + main, |
| 16 | +) |
13 | 17 | from torch._inductor.utils import flatten, tabulate_2d, zip_dicts
|
14 | 18 | from torch.testing._internal.common_device_type import (
|
15 | 19 | dtypes,
|
@@ -297,10 +301,44 @@ def test_augment_trace_helper_args(self, device, dtype):
|
297 | 301 | om()
|
298 | 302 | trace1, trace2 = trace_files()
|
299 | 303 | p.export_chrome_trace(trace1)
|
300 |
| - # patch('sys.stdout', new_callable=StringIO) as mock_stdout, |
301 | 304 | with patch("sys.argv", [*prefix, "--augment_trace", trace1, trace2]):
|
302 | 305 | main()
|
303 |
| - # self.assertEqual(mock_stdout.getvalue(), "") |
| 306 | + profile = JsonProfile(trace2, 1, "foo") |
| 307 | + rep = profile.report() |
| 308 | + # If these fail, just update them. They could change over time |
| 309 | + if device != "cpu": |
| 310 | + self.assertTrue(len(rep.split("\n")) > 4) |
| 311 | + self.assertIn("Kernel Name", rep) |
| 312 | + self.assertIn("Kernel Count", rep) |
| 313 | + self.assertIn("FLOPS", rep) |
| 314 | + self.assertIn("bw gbps", rep) |
| 315 | + self.assertIn("Dur (ms)", rep) |
| 316 | + self.assertIn("Achieved", rep) |
| 317 | + self.assertIn("|", rep) |
| 318 | + self.assertIn("-----", rep) |
| 319 | + |
| 320 | + # TODO we need a robust way of checking this report. |
| 321 | + # In the mean time, make sure that no column is empty. |
| 322 | + # TODO check to make sure all % values are less than 100% |
| 323 | + tables = profile._create_tables(profile._devices) |
| 324 | + for tab in tables.values(): |
| 325 | + header, rows = tab |
| 326 | + ncols = len(header) - 1 |
| 327 | + seen = [False] * ncols |
| 328 | + for row in rows.values(): |
| 329 | + for i in range(len(row)): |
| 330 | + try: |
| 331 | + val = float(row[i]) |
| 332 | + except Exception: |
| 333 | + continue |
| 334 | + seen[i] = seen[i] or (val != 0.0) |
| 335 | + |
| 336 | + if device != "cpu": |
| 337 | + for i in range(len(seen)): |
| 338 | + self.assertTrue( |
| 339 | + seen[i], |
| 340 | + f"column values from column {i + 1} with header '{header[i + 1]}' are all zero", |
| 341 | + ) |
304 | 342 |
|
305 | 343 | @dtypes(torch.float, torch.double)
|
306 | 344 | def test_augment_trace_against_flop_counter(self, device, dtype):
|
|
0 commit comments