diff --git a/test/inductor/test_analysis.py b/test/inductor/test_analysis.py new file mode 100644 index 00000000000000..19e232c9cea9f2 --- /dev/null +++ b/test/inductor/test_analysis.py @@ -0,0 +1,704 @@ +# Owner(s): ["module: inductor"] + +import json +import re +import tempfile +import unittest +import uuid +from io import StringIO +from unittest.mock import patch + +import torch +import torch.nn.functional as F +from torch._inductor.analysis.profile_analysis import ( + _augment_trace_helper, + _create_extern_mapping, + JsonProfile, + main, +) +from torch._inductor.ir import FixedLayout +from torch._inductor.utils import ( + fresh_inductor_cache, + run_and_get_code, + tabulate_2d, + zip_dicts, +) +from torch.testing._internal.common_cuda import SM70OrLater +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, + skipIf, +) +from torch.testing._internal.common_utils import parametrize, run_tests, TestCase + + +example_profile = """ +{ + "schemaVersion": 1, + "deviceProperties": [ + { + "id": 0, "name": "NVIDIA H100", "totalGlobalMem": 101997215744, + "computeMajor": 9, "computeMinor": 0, + "maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048, + "regsPerBlock": 65536, "warpSize": 32, + "sharedMemPerBlock": 49152, "numSms": 132 + , "regsPerMultiprocessor": 65536, "sharedMemPerBlockOptin": 232448, "sharedMemPerMultiprocessor": 233472 + } + ], + "cupti_version": 24, + "cuda_runtime_version": 12060, + "with_flops": 1, + "record_shapes": 1, + "cuda_driver_version": 12040, + "profile_memory": 1, + "trace_id": "301995E163ED42048FBD783860E6E7DC", + "displayTimeUnit": "ms", + "baseTimeNanoseconds": 1743521598000000000, + "traceEvents": [ + { + "ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039, + "ts": 198093488368.463, "dur": 425.453, + "args": { + "External id": 1340,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \ +["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1"], "Input type": ["float", "float", "", \ +"ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar"], "Input Strides": [[150528, 1, 672, 3],\ +[147, 1, 21, 3], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], \ +[], []], "Ev Idx": 1339 + } + }, + { + "ph": "X", "cat": "cpu_op", "name": "aten::_convolution", "pid": 1147039, "tid": 1147039, + "ts": 198093488444.498, "dur": 341.867, + "args": { + "External id": 1341,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]",\ + "False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList",\ + "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input Strides": \ +[[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], \ +[64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340 + } + }, + { + "ph": "X", "cat": "cpu_op", "name": "aten::addmm", "pid": 1147039, "tid": 1147039, + "ts": 198093513655.849, "dur": 251.130, + "args": { + "External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \ +["", "", "", "1", "1", ""], "Input type": ["float", "float", "float", "Scalar", "Scalar", "float"], "Input Strides":\ + [[1], [0, 1], [1, 2048], [], [], [1000, 1]], "Input Dims": [[1000], [1, 2048], [2048, 1000], [], [], [1, 1000]], \ +"Ev Idx": 
1618 + } + }, + { + "ph": "X", "cat": "kernel", "name": "void cutlass_addmm", "pid": 1147039, "tid": 1147039, + "ts": 198093513655.849, "dur": 251.130, + "args": { + "External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618 + } + }, + { + "ph": "X", "cat": "kernel", "name": "void convolution_kernel", "pid": 1147039, "tid": 1147039, + "ts": 198093513655.849, "dur": 200.130, + "args": { + "External id": 1342, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618 + } + }, + { + "ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039, + "ts": 198093488444.498, "dur": 341.867, + "args": { + "External id": 1342,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", \ +"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList", \ +"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input \ +Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": \ +[[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340 + } + } +], + "traceName": "/tmp/compiled_module_profile.json" +} +""" + + +def verify_flops(self, expected_flops, out_profile): + j = 0 + for i in range(len(out_profile["traceEvents"])): + if "kernel_flop" in out_profile["traceEvents"][i]["args"]: + self.assertEqual( + out_profile["traceEvents"][i]["args"]["kernel_flop"], + expected_flops[j], + ) + j += 1 + + +def random_tensor(size, dtype, **kwargs): + if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]: + return torch.randn(size, dtype=dtype, **kwargs) + elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]: + return torch.randint(0, 100, size, dtype=dtype, **kwargs) + else: + raise ValueError("Unsupported data type") + + +def cT(device, dtype): + def T(*shape, requires_grad=False): + return random_tensor( + shape, requires_grad=requires_grad, device=device, dtype=dtype + ) + + return T + + +def FlopCounterMode(*args, **kwargs): + return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False) + + +TMP_DIR = tempfile.mkdtemp() + + +def trace_files(): + TRACE1 = f"{TMP_DIR}/trace1-{uuid.uuid4()}.json" + TRACE2 = f"{TMP_DIR}/trace2-{uuid.uuid4()}.json" + return TRACE1, TRACE2 + + +def omni_model(device, dtype, compile=True): + T = cT(device, dtype) + + def model(): + input_conv = T(1, 3, 56, 56) + conv_weight = T(12, 3, 5, 5) + + # Increased matrix sizes + mat1 = T(400, 600) + mat2 = T(600, 800) + + batch_mat1 = T(1, 600, 800) + batch_mat2 = T(1, 800, 20 * 48) + + # Convolution operation + conv_output = F.conv2d(input_conv, conv_weight) + + # a pointwise op + conv_output = conv_output * 10 + + # Matrix multiplication (addmm) operation + addmm_output = torch.addmm( + torch.zeros(400, 800, device=mat1.device, dtype=mat1.dtype), mat1, mat2 + ) + + # Batch matrix multiplication (bmm) operation + bmm_output = torch.bmm(batch_mat1, batch_mat2) + + # Batch addition matrix multiplication (baddbmm) operation + baddbmm_output = torch.baddbmm( + torch.zeros( + 1, 600, 20 * 48, device=batch_mat1.device, dtype=batch_mat1.dtype + ), + batch_mat1, + batch_mat2, + ) + + mm_output = torch.mm(mat1, mat2) + + return torch.cat( + [ + conv_output.flatten(), + addmm_output.flatten(), + bmm_output.flatten(), + baddbmm_output.flatten(), + mm_output.flatten(), + ] + ) + + if compile: + return 
torch.compile( + model, options={"benchmark_kernel": True, "profile_bandwidth": True} + ) + return model + + +def omni_model_no_addmm(device, dtype, compile=True): + T = cT(device, dtype) + + def model(): + input_conv = T(1, 3, 56, 56) + conv_weight = T(12, 3, 5, 5) + + # Increased matrix sizes + mat1 = T(400, 600) + mat2 = T(600, 800) + + batch_mat1 = T(1, 600, 800) + batch_mat2 = T(1, 800, 20 * 48) + + # Convolution operation + conv_output = F.conv2d(input_conv, conv_weight) + + # a pointwise op + conv_output = conv_output * 10 + + # Batch matrix multiplication (bmm) operation + bmm_output = torch.bmm(batch_mat1, batch_mat2) + + # Batch addition matrix multiplication (baddbmm) operation + baddbmm_output = torch.baddbmm( + torch.zeros( + 1, 600, 20 * 48, device=batch_mat1.device, dtype=batch_mat1.dtype + ), + batch_mat1, + batch_mat2, + ) + + mm_output = torch.mm(mat1, mat2) + + return torch.cat( + [ + conv_output.flatten(), + bmm_output.flatten(), + baddbmm_output.flatten(), + mm_output.flatten(), + ] + ) + + if compile: + return torch.compile( + model, options={"benchmark_kernel": True, "profile_bandwidth": True} + ) + return model + + +def omni_model_no_bmm(device, dtype, compile=True): + T = cT(device, dtype) + + def model(): + input_conv = T(1, 3, 56, 56) + conv_weight = T(12, 3, 5, 5) + + # Increased matrix sizes + mat1 = T(400, 600) + mat2 = T(600, 800) + + # Convolution operation + conv_output = F.conv2d(input_conv, conv_weight) + + # a pointwise op + conv_output = conv_output * 10 + + # Matrix multiplication (addmm) operation + addmm_output = torch.addmm( + torch.zeros(400, 800, device=mat1.device, dtype=mat1.dtype), mat1, mat2 + ) + + mm_output = torch.mm(mat1, mat2) + + return torch.cat( + [ + conv_output.flatten(), + addmm_output.flatten(), + mm_output.flatten(), + ] + ) + + if compile: + return torch.compile( + model, options={"benchmark_kernel": True, "profile_bandwidth": True} + ) + return model + + +prefix = ["profile.py"] + + +class TestUtils(TestCase): + def test_tabulate2d(self): + headers = ["Kernel", "Self H100 TIME (ms)", "Count", "Percent"] + rows = [ + ["aten::mm", 0.500, 7, 0.0], + ["aten::bmm", 0.400, 6, 0.0], + ["aten::baddbmm", 0.300, 5, 0.0], + ["aten::convolution", 0.200, 4, 0.0], + ["aten::cudnn_convolution", 0.100, 3, 0.0], + ] + table = [ + " Kernel | Self H100 TIME (ms) | Count | Percent ", + "-----------------------------------------------------------------", + " aten::mm | 0.5 | 7 | 0.0 ", + " aten::bmm | 0.4 | 6 | 0.0 ", + " aten::baddbmm | 0.3 | 5 | 0.0 ", + " aten::convolution | 0.2 | 4 | 0.0 ", + " aten::cudnn_convolution | 0.1 | 3 | 0.0 ", + ] + res = tabulate_2d(rows, headers) + for r, t in zip(res.split("\n"), table): + self.assertEqual(r, t) + + def test_zip_dicts(self): + d1 = {"a": 1, "b": 2} + d2 = {"a": 3, "c": 4} + res = zip_dicts(d1, d2, d1_default="foo", d2_default="bar") + self.assertEqual(set(res), {("a", 1, 3), ("b", 2, "bar"), ("c", "foo", 4)}) + res = zip_dicts(d1, d2) + self.assertEqual(set(res), {("a", 1, 3), ("b", 2, None), ("c", None, 4)}) + + +class TestAnalysis(TestCase): + @skipIf(not SM70OrLater, "Requires sm70") + def test_noop(self): + with ( + patch("sys.stdout", new_callable=StringIO) as mock_stdout, + patch("sys.argv", [*prefix]), + ): + main() + self.assertEqual(mock_stdout.getvalue(), "") + + @skipIf(not SM70OrLater, "Requires sm70") + @dtypes(torch.float, torch.double, torch.float16) + def test_diff(self, device, dtype): + """ + diff, testing out the nruns feature too. 
+ """ + if device == "cpu": + # TODO cpu support + return + om = omni_model(device, dtype) + REPEAT = 5 + trace1, trace2 = trace_files() + print("first trace") + torch._dynamo.reset() # reset the cache + with fresh_inductor_cache(): + with torch.profiler.profile(record_shapes=True) as p: + om() + p.export_chrome_trace(trace1) + + print("second trace") + torch._dynamo.reset() # reset the cache + with fresh_inductor_cache(): + with torch.profiler.profile(record_shapes=True) as p: + for _ in range(REPEAT): + om() + p.export_chrome_trace(trace2) + + print("diffing...") + with patch( + "sys.argv", + [ + *prefix, + "--diff", + trace1, + "1", + "foo", + trace2, + str(REPEAT), + "bar", + "--name_limit", + "200", + ], + ): + main() + + @skipIf(not SM70OrLater, "Requires sm70") + def test_augment_trace_helper_unit(self): + js = json.loads(example_profile) + out_profile = _augment_trace_helper(js) + expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0] + verify_flops(self, expected_flops, out_profile) + + @skipIf(not SM70OrLater, "Requires sm70") + @dtypes(torch.float, torch.double, torch.float16) + def test_augment_trace_helper_args(self, device, dtype): + if device == "cpu": + # cpu doesn't produce traces currently + return + om = omni_model(device, dtype) + torch._dynamo.reset() # reset the cache + with fresh_inductor_cache(): + with torch.profiler.profile(record_shapes=True) as p: + om() + trace1, trace2 = trace_files() + p.export_chrome_trace(trace1) + + with patch("sys.argv", [*prefix, "--augment_trace", trace1, trace2]): + main() + profile = JsonProfile(trace2, 1, "foo") + rep = profile.report() + self.assertTrue(len(rep.split("\n")) > 3, f"Error, empty table:\n{rep}") + # If these fail, just update them. They could change over time + self.assertIn("Kernel Name", rep) + self.assertIn("Kernel Count", rep) + self.assertIn("FLOPS", rep) + self.assertIn("bw gbps", rep) + self.assertIn("Dur (ms)", rep) + self.assertIn("Achieved", rep) + self.assertIn("|", rep) + self.assertIn("-----", rep) + + tables = profile._create_tables(profile._devices) + # check to make sure none of the cols are all zero, no empty columns + for tab in tables.values(): + header, rows = tab + ncols = len(header) - 1 + seen = [False] * ncols + for row in rows.values(): + for i in range(len(row)): + try: + val = float(row[i]) + except Exception: + continue + seen[i] = seen[i] or (val != 0.0) + + for i in range(len(seen)): + self.assertTrue( + seen[i], + f"column values from column {i + 1} with header '{header[i + 1]}' are all zero", + ) + + # check to make sure all % values are less than 100% + percents = [] + for tab in tables.values(): + header, rows = tab + for i, h in enumerate(header): + if "%" in h: + percents.append(i) + self.assertTrue(len(percents) > 0, "There are no headers with % in them") + for row in rows.values(): + for p in percents: + idx = p - 1 + self.assertTrue( + float(row[idx]) <= 100.0, + f"column values from column {idx} with header '{header[idx]}' is greater than 100%: {row[idx]}", + ) + self.assertTrue( + float(row[idx]) >= 0.0, + f"column values from column {idx} with header '{header[idx]}' is less than 0%: {row[idx]}", + ) + + @skipIf(not SM70OrLater, "Requires sm70") + @dtypes(torch.float, torch.float16) + @parametrize("maxat", [(True, "TRITON")]) + # this tests to see if we can only use a Triton backend for max autotune + @unittest.skipIf( + torch.cuda.is_available() + and not torch._inductor.utils.use_triton_template( + FixedLayout(torch.device("cuda"), torch.float16, [400, 800]) + ), + 
"Solo triton backend not possible", + ) + def test_inductor_meta_flop_gb_annotations(self, device, dtype, maxat): + if device == "cpu": + return + max_autotune, backends = maxat + om = omni_model_no_bmm(device, dtype) + comp_omni = torch.compile( + om, + options={ + "benchmark_kernel": True, + "profile_bandwidth": True, + "max_autotune_gemm_backends": backends, + "force_disable_caches": True, + "max_autotune": max_autotune, + }, + ) + code_string = run_and_get_code(comp_omni)[1][0] + triton_mm_string_name = r"triton_.*_fused_mm.* = async_compile\.triton" + self.assertRegex(code_string, triton_mm_string_name) + lines = code_string.split("\n") + lookforward = 50 + seen = False + for line_number, line in enumerate(lines): + if re.search(triton_mm_string_name, line): + seen = True + surrounding_lines = "\n".join( + lines[line_number : min(len(lines), line_number + lookforward)] + ) + if re.search(r"kernel_flop", surrounding_lines): + res = re.search(r"'kernel_flop': (\d+)", surrounding_lines) + self.assertNotEqual(res, None) + assert res is not None + kernel_flop_number = int(res.group(1)) + + self.assertNotEqual( + kernel_flop_number, 0, "kernel_flop should be nonzero" + ) + else: + self.assertTrue(False, "kernel_flop not found in last 10 lines") + break + self.assertTrue(seen) + + @skipIf(not SM70OrLater, "Requires sm70") + @dtypes(torch.float, torch.double, torch.float16) + @parametrize( + "maxat", + [ + (True, "TRITON"), + ], + ) + # this tests to see if we can only use a Triton backend for max autotune + @unittest.skipIf( + torch.cuda.is_available() + and not torch._inductor.utils.use_triton_template( + FixedLayout(torch.device("cuda"), torch.float16, [400, 800]) + ), + "Solo triton backend not possible", + ) + def test_triton_has_metadata(self, device, dtype, maxat): + """ + make sure that the chrome trace of triton kernels contains certain values + """ + if device == "cpu": + return + + T = cT(device, dtype) + input_conv = T(1, 3, 56, 56) + conv_weight = T(12, 3, 5, 5) + + def om(i, w): + # Convolution operation + conv_output = F.conv2d(i, w) + return conv_output + + max_autotune, backends = maxat + comp_omni = torch.compile( + om, + options={ + "benchmark_kernel": True, + "max_autotune_gemm_backends": backends, + "force_disable_caches": True, + "max_autotune": max_autotune, + }, + ) + + def verify_triton(comp): + torch._dynamo.reset() # reset the cache + with fresh_inductor_cache(): + with torch.profiler.profile(record_shapes=True) as profile: + comp(input_conv, conv_weight) + + trace1, _ = trace_files() + profile.export_chrome_trace(trace1) + with open(trace1) as f: + out_profile = json.load(f) + seen = False + for event in out_profile["traceEvents"]: + if "triton" in event["name"] and "conv" in event["name"]: + seen = True + self.assertTrue(seen, "no triton conv found") + + verify_triton(comp_omni) + + @skipIf(not SM70OrLater, "Requires sm70") + @dtypes(torch.float, torch.float16) + @parametrize( + "maxat", + [ + (False, "ATEN,TRITON"), + (True, "ATEN,TRITON"), + (True, "ATEN"), + (True, "TRITON"), + ], + ) + def test_augment_trace_against_flop_counter(self, device, dtype, maxat): + # this tests to see if we can only use a Triton backend for max autotune + max_autotune, backends = maxat + if ( + backends == "TRITON" + and torch.cuda.is_available() + and not torch._inductor.utils.use_triton_template( + FixedLayout(torch.device("cuda"), torch.float16, [400, 800]) + ) + ): + return + if device == "cpu": + return + om = omni_model(device, dtype, compile=False) + + comp_omni = 
torch.compile( + om, + options={ + "benchmark_kernel": True, + "max_autotune_gemm_backends": backends, + "force_disable_caches": True, + "max_autotune": max_autotune, + }, + ) + comp_omni() + + torch._dynamo.reset() # reset the cache + with fresh_inductor_cache(): + with torch.profiler.profile(record_shapes=True) as profile: + comp_omni() + + torch._dynamo.reset() # reset the cache + with fresh_inductor_cache(): + with FlopCounterMode() as mode: + comp_omni() + + trace1, trace2 = trace_files() + profile.export_chrome_trace(trace1) + with patch("sys.argv", [*prefix, "--augment_trace", trace1, trace2]): + main() + + with open(trace2) as f: + out_profile = json.load(f) + + flop_counts = mode.flop_counts + extern_mapping = _create_extern_mapping(out_profile) + + seen_mm = False + seen_bmm = False + seen_baddbmm = False + seen_conv = False + for event in out_profile["traceEvents"]: + if ( + "cat" not in event + or event["cat"] != "kernel" + or "args" not in event + or "External id" not in event["args"] + ): + continue + + external_op = extern_mapping[event["args"]["External id"]][0] + name: str = external_op["name"] + self.assertNotEqual(name, None) + self.assertEqual(type(name), str) + if name.startswith("aten::mm") or "_mm_" in name: + seen_mm = True + self.assertEqual( + event["args"]["kernel_flop"], + flop_counts["Global"][torch.ops.aten.mm], + ) + if ( + name.startswith( + ( + "aten::cudnn_convolution", + "aten::convolution", + "aten::_convolution", + ) + ) + or "_convolution_" in name + ): + seen_conv = True + self.assertEqual( + event["args"]["kernel_flop"], + flop_counts["Global"][torch.ops.aten.convolution], + ) + if name.startswith("aten::baddbmm") or "_baddbmm_" in name: + seen_baddbmm = True + self.assertEqual( + event["args"]["kernel_flop"], + flop_counts["Global"][torch.ops.aten.baddbmm], + ) + if name.startswith("aten::bmm") or "_bmm_" in name: + seen_bmm = True + self.assertEqual( + event["args"]["kernel_flop"], + flop_counts["Global"][torch.ops.aten.bmm], + ) + self.assertTrue(seen_mm) + self.assertTrue(seen_bmm) + self.assertTrue(seen_baddbmm) + self.assertTrue(seen_conv) + + +instantiate_device_type_tests(TestAnalysis, globals()) + +if __name__ == "__main__": + run_tests() diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index b204fbfcd227f0..e2cb7a8240ecbb 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -27,6 +27,7 @@ import torch.optim import torch.utils.data from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall +from torch._inductor.ir import FixedLayout from torch.autograd.profiler import KinetoStepTracker, profile as _profile from torch.autograd.profiler_legacy import profile as _profile_legacy from torch.profiler import ( @@ -2998,6 +2999,64 @@ def validate_json(prof): assert "Overload Name" in key_averages.table() validate_json(prof) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + # this tests to see if we can only use a Triton backend for max autotune + @unittest.skipIf( + torch.cuda.is_available() + and not torch._inductor.utils.use_triton_template( + FixedLayout(torch.device("cuda"), torch.float16, [400, 800]) + ), + "Solo triton backend not possible", + ) + def test_profiler_debug_autotuner(self): + """ + This test makes sure that profiling events will be present when the kernel is run using the DebugAutotuner. 
+ """ + in1 = torch.randn((400, 600), device="cuda", dtype=torch.float16) + in2 = torch.randn((600, 800), device="cuda", dtype=torch.float16) + + def mm(): + return torch.mm(in1, in2) + + pb_mm = torch.compile( + mm, + options={ + "benchmark_kernel": True, + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "profile_bandwidth": True, + }, + ) + comp_mm = torch.compile( + mm, + options={ + "benchmark_kernel": True, + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + }, + ) + + with profile() as prof1: + pb_mm() + with profile() as prof2: + comp_mm() + + def names(prof): + return { + ev.name + for ev in prof.events() + if "mm" in ev.name or "triton" in ev.name + } + + trace1 = "/tmp/trace1_pb.json" + trace2 = "/tmp/trace2_nopb.json" + prof1.export_chrome_trace(trace1) + prof2.export_chrome_trace(trace2) + + n1 = names(prof1) + n2 = names(prof2) + self.assertEqual(n1, n2) + if __name__ == "__main__": run_tests() diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index 58400d86a81506..88276e91b06f49 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -854,5 +854,6 @@ def test_scaled_mm(self): self.assertExpectedInline(get_total_flops(mode), """860160""") + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/analysis/README.md b/torch/_inductor/analysis/README.md new file mode 100644 index 00000000000000..1e655bdd8e987f --- /dev/null +++ b/torch/_inductor/analysis/README.md @@ -0,0 +1,2 @@ +# `torch._inductor.analysis` +Contains scripts for inductor performance analysis. diff --git a/torch/_inductor/analysis/__init__.py b/torch/_inductor/analysis/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/torch/_inductor/analysis/device_info.py b/torch/_inductor/analysis/device_info.py new file mode 100644 index 00000000000000..96b68920df05a7 --- /dev/null +++ b/torch/_inductor/analysis/device_info.py @@ -0,0 +1,150 @@ +from dataclasses import dataclass +from logging import info +from typing import Optional + +import torch + + +@dataclass(frozen=True) +class DeviceInfo: + """ + Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not, + then the higher number is reported. Sparsity is not considered. + + + Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace. + For example, + """ + + tops: dict[torch.dtype, float] + dram_bw_gbs: float + dram_gb: float + + +# Indexing is based on `torch.cuda.get_device_name()` +# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on. 
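+# Illustrative lookup (a sketch only; names and numbers come from the mapping below):
+#   info = lookup_device_info(torch.cuda.get_device_name())
+#   if info is not None:
+#       peak_bf16_tflops = info.tops[torch.bfloat16]  # 312.5 for "NVIDIA A100"
+#       peak_dram_gbs = info.dram_bw_gbs              # 2039.0 for "NVIDIA A100"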
+_device_mapping: dict[str, DeviceInfo] = {
+    # Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
+    "NVIDIA H100": DeviceInfo(
+        tops={
+            torch.float64: 9.7,
+            torch.float32: 19.5,
+            torch.bfloat16: 1979.0,
+            torch.float16: 1979.0,
+            torch.float8_e8m0fnu: 3958.0,
+            torch.float8_e4m3fnuz: 3958.0,
+            torch.float8_e5m2: 3958.0,
+            torch.float8_e5m2fnuz: 3958.0,
+            torch.int8: 3958.0,
+        },
+        dram_bw_gbs=3350,
+        dram_gb=80,
+    ),
+    # Source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
+    "NVIDIA A100": DeviceInfo(
+        tops={
+            torch.float64: 19.5,
+            torch.float32: 19.5,
+            torch.bfloat16: 312.5,
+            torch.float16: 312.5,
+            # Not in datasheet: float8
+            torch.int8: 624.0,
+        },
+        dram_bw_gbs=2039.0,
+        dram_gb=80.0,
+    ),
+    # Source: https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
+    "NVIDIA L4": DeviceInfo(
+        tops={
+            # This is a guess, not in datasheet
+            torch.float64: 15.1,
+            torch.float32: 30.3,
+            torch.bfloat16: 242.0,
+            torch.float16: 242.0,
+            torch.float8_e8m0fnu: 485.0,
+            torch.float8_e4m3fnuz: 485.0,
+            torch.float8_e5m2: 485.0,
+            torch.float8_e5m2fnuz: 485.0,
+            torch.int8: 485.0,
+        },
+        dram_bw_gbs=3350,
+        dram_gb=24,
+    ),
+    # Source: https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
+    "AMD MI300A": DeviceInfo(
+        tops={
+            torch.float64: 122.6,
+            torch.float32: 122.6,
+            # torch.tf32: 490.3,
+            torch.bfloat16: 980.6,
+            torch.float16: 980.6,
+            torch.float8_e8m0fnu: 1961.2,
+            torch.float8_e4m3fnuz: 1961.2,
+            torch.float8_e5m2: 1961.2,
+            torch.float8_e5m2fnuz: 1961.2,
+            torch.int8: 1961.2,
+        },
+        dram_bw_gbs=5300.0,
+        dram_gb=128.0,
+    ),
+    # Source: https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
+    "AMD MI300X": DeviceInfo(
+        tops={
+            torch.float64: 163.4,
+            torch.float32: 163.4,
+            torch.bfloat16: 1307.4,
+            torch.float16: 1307.4,
+            torch.float8_e8m0fnu: 2614.9,
+            torch.float8_e4m3fnuz: 2614.9,
+            torch.float8_e5m2: 2614.9,
+            torch.float8_e5m2fnuz: 2614.9,
+            torch.int8: 2614.9,
+        },
+        dram_bw_gbs=5300.0,
+        dram_gb=192.0,
+    ),
+}
+
+
+def lookup_device_info(name: str) -> Optional[DeviceInfo]:
+    """
+    Problem: when diffing profiles between AMD and NVIDIA, we don't have access to the device information
+    of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
+    to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
+    If one is missing, please add a DeviceInfo entry for it to _device_mapping.
+    name (str): name of the device to look up. Should map onto torch.cuda.get_device_name().
+    """
+    if name not in _device_mapping:
+        return None
+    return _device_mapping[name]
+
+
+def datasheet_tops(dtype: torch.dtype) -> Optional[float]:
+    """
+    Get the theoretical TFLOPS of the current device for a given dtype. Returns None if the device
+    or dtype is not in the datasheet list above.
+ """ + name: Optional[str] = torch.cuda.get_device_name() + if name is None: + info("No device found, returning None") + return None + device_info = lookup_device_info(name) + if device_info is None: + log_str = f"Device {name} not in datasheet, returning None" + info(log_str) + return None + if dtype not in device_info.tops: + log_str = ( + f"Device {name} does not have a datasheet entry for {dtype}, returning None" + ) + info(log_str) + return None + return device_info.tops[dtype] diff --git a/torch/_inductor/analysis/profile_analysis.py b/torch/_inductor/analysis/profile_analysis.py new file mode 100644 index 00000000000000..2b6ff55c46efab --- /dev/null +++ b/torch/_inductor/analysis/profile_analysis.py @@ -0,0 +1,580 @@ +import json +import math +from collections import defaultdict +from dataclasses import dataclass +from logging import info +from typing import Any, Callable, Optional, Union + +import torch +from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info +from torch._inductor.utils import tabulate_2d, zip_dicts +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils.flop_counter import flop_registry + + +ATEN_PREFIX = "aten::" + + +@dataclass +class ProfileEvent: + category: str + key: str + self_device_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +# adapters convert the json trace into a format that works with flops_counter +ArgsType = tuple[tuple[Any, ...], dict[Any, Any]] +AdapterType = Callable[[tuple[Any, ...], tuple[Any, ...]], ArgsType] +adapters_map: dict[str, AdapterType] = {} + + +def parse_list(lst: str) -> list[int]: + lst = lst.replace("[", "").replace("]", "") + substrings = lst.split(",") + return [int(substring.strip()) for substring in substrings] + + +def register_adapter( + aten: Union[str, list[str]], +) -> Callable[ + [AdapterType], + AdapterType, +]: + def decorator(func: AdapterType) -> AdapterType: + global _adapters_map + + if isinstance(aten, str): + adapters_map[aten] = func + else: + for at in aten: + adapters_map[at] = func + return func + + return decorator + + +@register_adapter(["convolution", "_convolution", "cudnn_convolution"]) +def conv_adapter( + shapes: tuple[Any, ...], concrete: tuple[Any, ...] +) -> tuple[tuple[Any], dict[Any, Any]]: + tmp = list(shapes) + if len(tmp) == 4: + transposed = False + + transposed = bool(tmp[6]) + tmp[6] = transposed + + kwargs = {} + if not transposed: + # calculate output shape if not transposed. 
def conv_out_dims(x: int, kernel: int, stride: int) -> int:
+            return (x - kernel) // stride + 1
+
+        stride = parse_list(concrete[3])
+        inp = shapes[0]
+        w = shapes[1]
+        out_x_y = [conv_out_dims(*args) for args in zip(inp[2:], w[2:], stride)]
+        out = [inp[0], w[0]] + out_x_y  # we only need the xy values
+        kwargs["out_val"] = out
+
+    return tuple(tmp[:-1]), kwargs
+
+
+def default_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    return shapes, {}
+
+
+@register_adapter("addmm")
+def addmm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)[:3]
+    return tuple(tmp), {}
+
+
+@register_adapter("bmm")
+def bmm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)
+    return tuple(tmp[:2]), {}
+
+
+@register_adapter("baddbmm")
+def baddbmm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)[:3]
+    return tuple(tmp), {}
+
+
+@register_adapter("mm")
+def mm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    return shapes, {}
+
+
+def _parse_kernel_name(name: str) -> Optional[str]:
+    if name.startswith(ATEN_PREFIX):
+        return name[len(ATEN_PREFIX) :]
+    elif "convolution" in name:
+        return "convolution"
+    elif "addmm" in name:
+        return "addmm"
+    elif "bmm" in name:
+        return "bmm"
+    elif "baddbmm" in name:
+        return "baddbmm"
+    elif "_mm" in name:
+        return "mm"
+    else:
+        return None
+
+
+def _calculate_flops(event: dict[str, Any]) -> int:
+    """
+    This function has to parse the kernel name, which is error-prone. There doesn't seem to be another solution that
+    will support all the different backends that can generate kernels, so make sure to update this function when new
+    ops and backends are desired.
+    """
+    name = event["name"]
+    if "kernel_flop" in event["args"] and event["args"]["kernel_flop"] != 0:
+        return event["args"]["kernel_flop"]
+    op_name = _parse_kernel_name(name)
+    if op_name is None:
+        return 0
+
+    op_obj = getattr(torch.ops.aten, op_name, None)
+    if op_obj is None or op_obj not in flop_registry:
+        return 0
+
+    flop_function = flop_registry[op_obj]
+
+    if "Input Dims" not in event["args"] or "Concrete Inputs" not in event["args"]:
+        # Shape information is missing, so we cannot estimate flops for this op.
+        return 0
+    input_shapes = event["args"]["Input Dims"]
+    concrete = event["args"]["Concrete Inputs"]
+    if op_name in adapters_map:
+        args, kwargs = adapters_map[op_name](input_shapes, concrete)
+    else:
+        args, kwargs = default_adapter(input_shapes, concrete)
+    return flop_function(*args, **kwargs)
+
+
+def _estimate_gb(event: dict[str, Any]) -> float:
+    """
+    This estimate isn't the best because it doesn't know if two input buffers are the same buffer, leading to an
+    overestimate of the real achieved bandwidth.
+ """ + if "Input type" not in event["args"] or "Input Dims" not in event["args"]: + return 0 + sizes_and_types = zip(event["args"]["Input Dims"], event["args"]["Input type"]) + bw = 0 + for size, typ in sizes_and_types: + if not hasattr(torch, typ): + isize = 0 + else: + isize = getattr(torch, typ).itemsize + bw += isize * math.prod(pytree.tree_flatten(size)[0]) + return bw / 1e9 + + +def _create_extern_mapping( + data: dict[str, Any], +) -> defaultdict[int, list[dict[str, Any]]]: + """ + compute a mapping from exteral ids to non kernels, which contain the information we need to estimate flops etc + """ + extern_mapping: defaultdict[int, list[dict[str, Any]]] = defaultdict(list) + for event in data["traceEvents"]: + if ( + "args" not in event + or "External id" not in event["args"] + or event["cat"] != "cpu_op" + ): + continue + if len(extern_mapping[event["args"]["External id"]]) > 0: + raise ParseException("duplicate external id in event") + extern_mapping[event["args"]["External id"]].append(event) + return extern_mapping + + +def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]: + extern_mapping = _create_extern_mapping(data) + + for event in data["traceEvents"]: + if "cat" not in event or event["cat"] != "kernel": + continue + if "args" not in event: + raise ParseException(f"kernel has no args: {event}") + if "External id" not in event["args"]: + event_str = f"kernel has no External id: {event}" + info(event_str) + continue + + external_op = extern_mapping[event["args"]["External id"]][0] + flops = _calculate_flops(external_op) + if flops == 0: + flops = _calculate_flops(event) + external_op["args"]["kernel_flop"] = flops + external_op["args"]["kernel_num_gb"] = _estimate_gb(external_op) + event["args"]["kernel_flop"] = external_op["args"]["kernel_flop"] + event["args"]["kernel_num_gb"] = external_op["args"]["kernel_num_gb"] + return data + + +_dtype_map = { + "float": torch.float, + "int": torch.int, + "long": torch.long, + "long int": torch.long, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + + +@dataclass(frozen=True) +class KernelStats: + flops: int + bw: float + latency: float + achieved_flops: float + achieved_bandwidth: float + + +KernelNameMap = defaultdict[str, OrderedSet[KernelStats]] + + +@dataclass(frozen=False) +class Device: + name: str + index: int + info: DeviceInfo + stats: KernelNameMap + + def __repr__(self) -> str: + return f"Device({self.name}, {self.index})" + + +DeviceMap = dict[int, Device] +Table = tuple[list[str], dict[str, list[str]]] + + +class JsonProfile: + _devices: DeviceMap + + def __init__( + self, + path: str, + nruns: int, + benchmark_name: Optional[str] = None, + ): + """ + Convienence class for running common operations on chrome/perfetto json traces. + """ + self.path = path + with open(path) as f: + self.data = json.load(f) + self.events = self.data["traceEvents"] + self.nruns = nruns + self.benchmark_name = benchmark_name + self._create_devices() + + def convert_dtype(self, event: dict[str, Any]) -> torch.dtype: + """ + Each op has a list of dtypes for each input arg. We need to convert these into a single dtype for flop estimation. + Issues: + - converting the strings to concrete torch.dtypes + - What if we have float32, float, float16 all in the inputs? Our choice is to use the largest buffer dtype. 
+ """ + + if ( + "Input Dims" not in event["args"] + or "Input type" not in event["args"] + or "Concrete Inputs" not in event["args"] + ): + if "bfloat16" in event["name"]: + return torch.bfloat16 + elif "float16" in event["name"]: + return torch.float16 + else: + return torch.float + + input_sizes = event["args"]["Input Dims"] + input_types = event["args"]["Input type"] + concrete_inputs = event["args"]["Concrete Inputs"] + assert len(input_sizes) == len(input_types) + assert len(input_types) == len(concrete_inputs) + + if len(input_sizes) == 0: + raise RuntimeError("Empty input_sizes and input_types") + + biggest_size = 0 + biggest_index = 0 + for i in range(len(input_sizes)): + if concrete_inputs[i] != "": + # concrete inputs are usually small tensors, so we can just skip + continue + my_size = input_sizes[i] + total_size = sum(parse_list(my_size)) + if total_size > biggest_size: + biggest_size = total_size + biggest_index = i + ret_type = input_types[biggest_index] + if ret_type in _dtype_map: + return _dtype_map[ret_type] + raise RuntimeError(f"Unknown type: {ret_type}. Please add to _dtype_map.") + + def _create_devices(self) -> None: + self._devices = {} + for dev in self.data["deviceProperties"]: + name = dev["name"] + device_info = lookup_device_info(name) + if device_info is None: + raise RuntimeError( + f"Unsupported device in profile: {name}, please consider contributing to _device_mapping." + ) + self._devices[dev["id"]] = Device( + name, dev["id"], device_info, defaultdict(OrderedSet) + ) + + def calculate_flops(self, event: dict[str, Any]) -> int: + return _calculate_flops(event) + + def estimate_gb(self, event: dict[str, Any]) -> float: + """ + This estimate isn't the best because it doesn't know if two input buffers are the same buffer, leading to an + overestimate of the real achieved bandwidth. 
+ """ + return _estimate_gb(event) + + def augment_trace(self) -> None: + self.data = _augment_trace_helper(self.data) + + def _compute_stats(self) -> None: + """populates the name -> stats map""" + for event in self.events: + if "cat" not in event or "args" not in event or event["cat"] != "kernel": + continue + dev = self._devices[event["args"]["device"]] + dur = event["dur"] + if "kernel_flop" in event["args"]: + assert dur != 0 + # 1000ms/s * flop / ms + op_flops = 1e3 * event["args"]["kernel_flop"] / dur + if op_flops == 0: + achieved_flops = 0 + else: + dtype = self.convert_dtype(event) + achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype]) + else: + op_flops = 0 + achieved_flops = 0 + + if "kernel_num_gb" in event["args"]: + assert dur != 0 + # 1000ms/s * gb / ms = gb/s + op_gbps = 1e3 * event["args"]["kernel_num_gb"] / dur + achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs + else: + op_gbps = 0 + achieved_bandwidth = 0 + + dev.stats[event["name"]].add( + KernelStats( + flops=op_flops, + bw=op_gbps, + latency=dur, + achieved_bandwidth=achieved_bandwidth, + achieved_flops=achieved_flops, + ) + ) + + def _create_single_table(self, dev: Device) -> Table: + """Create a table with the devices mapped to indices.""" + headers = [ + "Kernel Name", + "Kernel Count", + "FLOPS", + "bw gbps", + "Dur (ms)", + "Achieved FLOPS %", + "Achieved Bandwidth %", + ] + rows: dict[str, list[str]] = {} + + for kernel_name, stats_set in dev.stats.items(): + ker_count = 0 + flops = 0 + flops_count = 0 + achieved_flops = 0.0 + bw = 0.0 + bw_count = 0 + achieved_bandwidth = 0.0 + latency = 0.0 + for stats in stats_set: + if stats.flops != 0: + flops += stats.flops + achieved_flops += stats.achieved_flops + flops_count += 1 + if stats.bw != 0: + bw += stats.bw + achieved_bandwidth += stats.achieved_bandwidth + bw_count += 1 + latency += stats.latency + ker_count += 1 + assert ker_count != 0 + rows[kernel_name] = [ + str(ker_count), + str(flops / flops_count if flops_count != 0 else 0), + str(bw / bw_count if bw_count != 0 else 0), + str(latency / ker_count if ker_count != 0 else 0), + str(achieved_flops / flops_count if flops_count != 0 else 0), + str(achieved_bandwidth / bw_count if bw_count != 0 else 0), + ] + + return headers, rows + + def _create_tables(self, devs: DeviceMap) -> dict[int, Table]: + return {idx: self._create_single_table(dev) for idx, dev in devs.items()} + + def _combine_tables( + self, table1: Table, table1_name: str, table2: Table, table2_name: str + ) -> Table: + new_headers = ( + ["Kernel Name"] + + [f"{table1_name} {head}" for head in table1[0][1:]] + + [f"{table2_name} {head}" for head in table2[0][1:]] + ) + t1_length = len(table1[0][1:]) + t2_length = len(table2[0][1:]) + new_rows = {} + + for key, row1, row2 in zip_dicts( + table1[1], + table2[1], + d1_default=["Empty"] * t1_length, + d2_default=["Empty"] * t2_length, + ): + new_rows[key] = row1 + row2 + return new_headers, new_rows + + def report( + self, other: Optional["JsonProfile"] = None, name_limit: int = 40 + ) -> str: + def create_ret( + table_headers: list[str], table_rows: dict[str, list[str]] + ) -> str: + table_flattened = [ + [kernel_name[:name_limit], *kernel_vals] + for kernel_name, kernel_vals in table_rows.items() + ] + return tabulate_2d(table_flattened, headers=table_headers) + + if other is not None: + self._compute_stats() + other._compute_stats() + + self_tables = self._create_tables(self._devices) + other_tables = self._create_tables(other._devices) + + self_name = ( + self.benchmark_name 
if self.benchmark_name is not None else "Table 1" + ) + other_name = ( + other.benchmark_name if other.benchmark_name is not None else "Table 2" + ) + + ret = [] + assert self._devices.keys() == other._devices.keys() + for device_idx, t1, t2 in zip_dicts( + self_tables, other_tables, d1_default=None, d2_default=None + ): + table_headers, table_rows = self._combine_tables( + t1, self_name, t2, other_name + ) + tab_string = create_ret(table_headers, table_rows) + ret.append(f"{self._devices[device_idx]}:\n{tab_string}") + return "\n".join(ret) + self._compute_stats() + + self_tables = self._create_tables(self._devices) + + ret = [] + for idx, table in self_tables.items(): + table_headers, table_rows = table + tab_string = create_ret(table_headers, table_rows) + ret.append(f"{self._devices[idx]}:\n{tab_string}") + return "\n".join(ret) + + def dump(self, out: str) -> None: + with open(out, "w") as f: + json.dump(self.data, f) + + +class ParseException(RuntimeError): + pass + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--diff", + nargs=6, + metavar=("input_file1", "nruns1", "name1", "input_file2", "nruns2", "name2"), + help="Two json traces to compare with, specified as ", + ) + parser.add_argument( + "--name_limit", + type=int, + help="the maximum name size in the final report", + ) + parser.add_argument( + "--augment_trace", + "-a", + type=str, + nargs=2, + metavar=("input_file", "output_file"), + help="Augment a trace with inductor meta information. Provide input and output file paths.", + ) + parser.add_argument( + "--analysis", + nargs=3, + metavar=("input_file", "nruns", "name"), + help="Run analysis on a single trace, specified as ", + ) + args = parser.parse_args() + + if args.diff: + p1 = JsonProfile(args.diff[0], int(args.diff[1]), args.diff[2]) + p1.augment_trace() + p2 = JsonProfile(args.diff[3], int(args.diff[4]), args.diff[5]) + p2.augment_trace() + if args.name_limit: + print(p1.report(p2, name_limit=args.name_limit)) + else: + print(p1.report(p2)) + if args.analysis: + p1 = JsonProfile(args.analysis[0], args.analysis[1], args.analysis[2]) + p1.augment_trace() + if args.name_limit: + print(p1.report(name_limit=args.name_limit)) + else: + print(p1.report()) + if args.augment_trace: + p = JsonProfile(args.augment_trace[0], 1) + p.augment_trace() + p.dump(args.augment_trace[1]) + + +if __name__ == "__main__": + main() diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 04c1a010fae9c4..0b457e02a51807 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -923,6 +923,13 @@ def _map_tuple_or_scalar(fn, value): return tuple(map(fn, value)) return fn(value) + def estimate_flops(self) -> Optional[int]: + flops = [ + node.estimate_flops() + for node in NodeScheduleMarker.only_nodes(self.features.node_schedule) + ] + return sum(filter(None, flops)) + def estimate_kernel_num_bytes(self): """ Try the best to estimate the total size (in bytes) of the @@ -1554,7 +1561,9 @@ def codegen_template( kernel.cse.invalidate(OrderedSet()) if not isinstance(partial_code, str): - partial_code.finalize_hook("") + # This is used to calculate flops in TritonTemplateKernels + with ir.IRNode.current_origins(template_node.node.origins): + partial_code.finalize_hook("") partial_code.finalize_hook("", strict=False) # finalize must be called after adding epilogue above diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 8f3ddb77129091..0ece3a957b3e29 
100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3648,6 +3648,8 @@ def add_constexpr_arg(arg_name): if config.benchmark_kernel or config.profile_bandwidth: num_gb = self.estimate_kernel_num_bytes() / 1e9 inductor_meta["kernel_num_gb"] = num_gb + if config.benchmark_kernel: + inductor_meta["kernel_flop"] = self.estimate_flops() triton_meta["configs"] = [config_of(signature)] diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index f5774833ef5169..b689e8b9f81bd3 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -64,6 +64,7 @@ MissingOperatorWithDecomp, MissingOperatorWithoutDecomp, ) +from .fx_utils import count_flops_fx from .ir import ( Constant, DonatedBuffer, @@ -652,32 +653,24 @@ def is_small_channel(n: torch.fx.Node) -> bool: # only grouped convolutions benchmarked as slower in conv samples for inference only if is_inference: - from torch.utils.flop_counter import FlopCounterMode - flop_counts: dict[str, float] = defaultdict(float) for node in conv_nodes: - success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( - node - ) - - if success: - with FlopCounterMode(display=False) as flop_counter_mode: - with V.fake_mode: - node.target(*args, **kwargs) - - counted_flops = flop_counter_mode.get_total_flops() - if is_grouped(node): - node_type = "grouped" - elif is_small_channel(node): - node_type = "small" - elif is_in_out_channel(node): - node_type = "in_out" - else: - node_type = "default" + counted_flops = count_flops_fx(node) + if counted_flops is None: + continue - flop_counts[node_type] += counted_flops + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" else: - log.debug("Conv inputs meta not found") + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") # average benchmarked channels last speedup / slowdown, < 1 is speedup. 
# taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 999c87c56ae985..a69dc6329c758a 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -699,14 +699,29 @@ def kernel_call(): # reset to zero before evaluating any config self.reset_to_zero_args(*args, **kwargs) args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher) - launcher( - *args_with_constexprs, - **cloned_kwargs, - stream=stream, - ) + if autograd_profiler._is_profiler_enabled: + profiler_kwargs = self.get_profiler_kwargs(stream, launcher) + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args_with_constexprs, + profiler_kwargs, + ): + launcher( + *args_with_constexprs, + **cloned_kwargs, + stream=stream, + ) + + else: + launcher( + *args_with_constexprs, + **cloned_kwargs, + stream=stream, + ) self.restore_args_from_cpu(cpu_copies) - if with_profiler: + # only use profiler when not already in a profiler instance + if with_profiler and not autograd_profiler._is_profiler_enabled: from torch._inductor.utils import do_bench_using_profiling return do_bench_using_profiling(kernel_call, warmup=10, rep=40) @@ -997,6 +1012,23 @@ def benchmark_one_config(config): ).make_launcher() return config2launcher[best_config] + def get_profiler_kwargs(self, stream, launcher): + kernel_kwargs_str = ",".join( + f"{k}={v}" for (k, v) in launcher.config.kwargs.items() + ) + + return { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "stream": stream, + "num_warps": launcher.config.num_warps, + "num_stages": launcher.config.num_stages, + "kernel_kwargs": kernel_kwargs_str, + "kernel_num_gb": self.inductor_meta.get("kernel_num_gb", None), + "kernel_flop": self.inductor_meta.get("kernel_flop", None), + } + def run( self, *args, @@ -1040,19 +1072,7 @@ def run( # it is faster than entering and exiting a context manager, even if the context # manager is a nullcontext. 
if autograd_profiler._is_profiler_enabled: - kernel_kwargs_str = ",".join( - f"{k}={v}" for (k, v) in launcher.config.kwargs.items() - ) - - profiler_kwargs = { - "kernel_file": (self.filename or ""), - "kernel_hash": self.kernel_hash, - "kernel_backend": "triton", - "stream": stream, - "num_warps": launcher.config.num_warps, - "num_stages": launcher.config.num_stages, - "kernel_kwargs": kernel_kwargs_str, - } + profiler_kwargs = self.get_profiler_kwargs(stream, launcher) with torch._C._profiler._RecordFunctionFast( self.inductor_meta.get("kernel_name", "triton kernel"), diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index ddbc54cbcb67e0..db2cbe109527d5 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -61,6 +61,7 @@ from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta from .codegen.wrapper import pexpr from .exc import CUDACompileError +from .fx_utils import count_flops_fx, countable_fx from .ir import ChoiceCaller, PrimitiveInfoType from .ops_handler import StoreMode from .runtime.benchmarking import benchmarker @@ -426,12 +427,21 @@ def estimate_kernel_num_bytes(self): ninplace_args = len(unique(self.args.inplace_buffers.values())) num_bytes = [] for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): - size = V.graph.sizevars.size_hints(inp.get_size()) + size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0) numel = functools.reduce(operator.mul, size, 1) dtype_size = get_dtype_size(inp.get_dtype()) num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) return sum(num_bytes) + def estimate_flops(self) -> int: + for node in self.input_nodes: + for fx_node in node._current_origins: + if countable_fx(fx_node): + f = count_flops_fx(fx_node) + if f is not None: + return V.graph.sizevars.size_hints((f,), fallback=0)[0] + return 0 + def jit_lines(self): if self.use_jit: return "@triton.jit" @@ -467,6 +477,9 @@ def jit_lines(self): if config.profile_bandwidth or config.benchmark_kernel: num_gb = self.estimate_kernel_num_bytes() / 1e9 inductor_meta["kernel_num_gb"] = num_gb + if config.benchmark_kernel: + flops = self.estimate_flops() + inductor_meta["kernel_flop"] = flops template_args = f""" num_stages={self.num_stages}, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index be819dd33a26bd..211fed357459ba 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -22,7 +22,14 @@ import textwrap import time import unittest -from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet +from collections.abc import ( + Collection, + Generator, + Iterator, + Mapping, + MutableMapping, + MutableSet, +) from datetime import datetime from io import StringIO from typing import ( @@ -51,6 +58,7 @@ import sympy import torch +from torch._inductor.analysis.device_info import datasheet_tops from torch._inductor.runtime.hints import DeviceProperties from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map_only @@ -2104,16 +2112,24 @@ def get_backend_num_stages() -> int: @functools.lru_cache(None) -def get_device_tflops(dtype: torch.dtype) -> int: +def get_device_tflops(dtype: torch.dtype) -> float: + """ + We don't want to throw errors in this function. First check to see if the device is in device_info.py, + then fall back to the inaccurate triton estimation. 
+ """ + ds_tops = datasheet_tops(dtype) + if ds_tops is not None: + return ds_tops + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops assert dtype in (torch.float16, torch.bfloat16, torch.float32) if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): # Triton API change in https://github.com/triton-lang/triton/pull/2293 - from torch._utils_internal import max_clock_rate + from torch._utils_internal import max_clock_rate_mhz - sm_clock = max_clock_rate() + sm_clock = max_clock_rate_mhz() if dtype in (torch.float16, torch.bfloat16): return get_max_tensorcore_tflops(dtype, sm_clock) @@ -3072,3 +3088,50 @@ def get_ld_library_path() -> str: path = os.pathsep.join([lib_path, path]) if path else lib_path return path + + +def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str: + widths = [len(str(e)) for e in headers] + for row in elements: + assert len(row) == len(headers) + for i, e in enumerate(row): + widths[i] = max(widths[i], len(str(e))) + lines = [] + lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) + # widths whitespace horizontal separators + total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1) + lines.append("-" * total_width) + for row in elements: + lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) + return "\n".join(lines) + + +def zip_dicts( + dict1: dict[Any, Any], + dict2: dict[Any, Any], + d1_default: Any = None, + d2_default: Any = None, +) -> Generator[tuple[Any, Any, Any], None, None]: + """ + Zip two dictionaries together, indicating missing keys. + + Args: + dict1 (dict): The first dictionary. + dict2 (dict): The second dictionary. + d1_default (Any): the default value for the first dictionary + d2_default (Any): the default value for the second dictionary + + Yields: + tuple: A tuple containing the key, the value from dict1 (or d1_default if missing), + and the value from dict2 (or d2_default if missing). + """ + # Find the union of all keys + all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys()) + + # Iterate over all keys + for key in all_keys: + # Get the values from both dictionaries, or default if missing + value1 = dict1.get(key, d1_default) + value2 = dict2.get(key, d2_default) + + yield key, value1, value2 diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 911b719fc607ca..8722ec9aa93df3 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -1,8 +1,8 @@ import argparse -import dataclasses import datetime import tempfile from collections import defaultdict +from dataclasses import dataclass from types import ModuleType from typing import Any, Optional, Protocol @@ -159,7 +159,7 @@ def get_info_str( ) -@dataclasses.dataclass +@dataclass class ProfileEvent: category: str key: str @@ -176,6 +176,10 @@ def parse_profile_event_list( nruns: int, device_name: str, ) -> None: + """ + Parse and generate a report for an event_list. 
+ """ + def get_self_device_time( ev: torch.autograd.profiler_util.EventList, ) -> float: @@ -295,6 +299,10 @@ def report() -> None: report() +PROFILE_DIR = tempfile.gettempdir() +PROFILE_PATH = f"{PROFILE_DIR}/compiled_module_profile.json" + + def perf_profile( wall_time_ms: float, times: int, @@ -305,14 +313,14 @@ def perf_profile( with torch.profiler.profile(record_shapes=True) as p: benchmark_compiled_module_fn(times=times, repeat=repeat) - path = f"{tempfile.gettempdir()}/compiled_module_profile.json" + path = PROFILE_PATH p.export_chrome_trace(path) print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") print(f"Chrome trace for the profile is written to {path}") event_list = p.key_averages(group_by_input_shape=True) print(event_list.table(sort_by="self_device_time_total", row_limit=10)) parse_profile_event_list( - benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device + benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device or "" ) diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 89fbd6787281cd..5e14000fb5aa02 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -211,7 +211,7 @@ def is_fb_unit_test() -> bool: @functools.lru_cache(None) -def max_clock_rate(): +def max_clock_rate_mhz(): if not torch.version.hip: from triton.testing import nvsmi diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 255978dd6de077..d7435fa73266ad 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -127,7 +127,6 @@ def conv_flop_count( Returns: int: the number of flops """ - batch_size = x_shape[0] conv_shape = (x_shape if transposed else out_shape)[2:] c_out, c_in, *filter_size = w_shape @@ -146,8 +145,8 @@ def conv_flop_count( flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2 return flop -@register_flop_formula([aten.convolution, aten._convolution]) -def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: +@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution]) +def conv_flop(x_shape, w_shape, bias, stride, padding, dilation, transposed, *args, out_shape=None, **kwargs) -> int: """Count flops for convolution.""" return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) @@ -561,6 +560,7 @@ def _efficient_attention_backward_flop( aten._scaled_mm: _scaled_mm_flop, aten.convolution: conv_flop, aten._convolution: conv_flop, + aten.cudnn_convolution: conv_flop, aten.convolution_backward: conv_backward_flop, aten._scaled_dot_product_efficient_attention: sdpa_flop, aten._scaled_dot_product_flash_attention: sdpa_flop,
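
Usage sketch (not part of the patch): the new `torch._inductor.analysis.profile_analysis` module added above can be driven from the CLI (`python -m torch._inductor.analysis.profile_analysis --augment_trace in.json out.json`, or `--diff` / `--analysis` as exercised in test_analysis.py) or programmatically via `JsonProfile`. The file paths and benchmark name below are illustrative.

    # Augment an exported chrome trace with flop/bandwidth estimates and print the per-kernel report.
    from torch._inductor.analysis.profile_analysis import JsonProfile

    prof = JsonProfile("/tmp/trace1.json", nruns=1, benchmark_name="baseline")  # example path/name
    prof.augment_trace()                     # attach kernel_flop / kernel_num_gb to kernel events
    print(prof.report(name_limit=40))        # per-device table: FLOPS, bw gbps, achieved %
    prof.dump("/tmp/trace1_augmented.json")  # example output path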