From 1bcd7977ce7119870a3fd4ca4f6eee9eba323728 Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Wed, 23 Jul 2025 19:54:57 +0000
Subject: [PATCH 1/3] [ROCm] add flag torch.backends.miopen.immediate

---
 aten/src/ATen/Context.cpp                   |  8 +++
 aten/src/ATen/Context.h                     |  3 ++
 aten/src/ATen/native/miopen/Conv_miopen.cpp | 18 +++---
 torch/_C/__init__.pyi.in                    |  2 +
 torch/_dynamo/trace_rules.py                |  1 +
 torch/backends/__init__.py                  |  1 +
 torch/backends/miopen/__init__.py           | 57 +++++++++++++++++++++
 torch/csrc/Module.cpp                       | 21 ++++++++
 8 files changed, 99 insertions(+), 12 deletions(-)
 create mode 100644 torch/backends/miopen/__init__.py

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index ded7743c4d86..05e4d3444dc0 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -334,6 +334,14 @@ void Context::setBenchmarkLimitCuDNN(int b) {
   benchmark_limit_cudnn = b;
 }
 
+bool Context::immediateMiopen() const {
+  return immediate_miopen;
+}
+
+void Context::setImmediateMiopen(bool b) {
+  immediate_miopen = b;
+}
+
 bool Context::allowTF32CuBLAS() const {
 #ifdef USE_ROCM
   const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index d612030461c2..945076f3f012 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -205,6 +205,8 @@ class TORCH_API Context {
   void setBenchmarkCuDNN(bool);
   int benchmarkLimitCuDNN() const;
   void setBenchmarkLimitCuDNN(int);
+  bool immediateMiopen() const;
+  void setImmediateMiopen(bool);
   bool deterministicCuDNN() const;
   void setDeterministicCuDNN(bool);
   bool deterministicMkldnn() const;
@@ -440,6 +442,7 @@ class TORCH_API Context {
   bool enabled_overrideable = true;
   bool allow_fp16_bf16_reduction_mathSDP = false;
   bool benchmark_cudnn = false;
+  bool immediate_miopen = false;
   Float32MatmulPrecision float32_matmul_precision =
       c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
       ? at::Float32MatmulPrecision::HIGH
diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp
index d2cef0ca6218..2d0d98bdd436 100644
--- a/aten/src/ATen/native/miopen/Conv_miopen.cpp
+++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp
@@ -724,8 +724,7 @@ void raw_miopen_convolution_forward_out(
   args.odesc.set(output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);
 
@@ -833,8 +832,7 @@ void raw_miopen_depthwise_convolution_forward_out(
   args.odesc.set(output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);
 
@@ -989,8 +987,7 @@ void raw_miopen_convolution_backward_weight_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);
 
@@ -1034,8 +1031,7 @@ void raw_miopen_depthwise_convolution_backward_weight_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);
 
@@ -1240,8 +1236,7 @@ void raw_miopen_convolution_backward_input_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);
 
@@ -1350,8 +1345,7 @@ void raw_miopen_depthwise_convolution_backward_input_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);
 
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index dea17d26ef21..9ee4f9ddfedc 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1213,6 +1213,8 @@ def _get_mkldnn_enabled() -> _bool: ...  # THPModule_userEnabledMkldnn
 def _set_mkldnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledMkldnn
 def _get_cudnn_benchmark() -> _bool: ...  # THPModule_benchmarkCuDNN
 def _set_cudnn_benchmark(arg: _bool) -> None: ...  # THPModule_setBenchmarkCuDNN
+def _get_miopen_immediate() -> _bool: ...  # THPModule_immediateMiopen
+def _set_miopen_immediate(arg: _bool) -> None: ...  # THPModule_setImmediateMiopen
 def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
 def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
 def _get_mkldnn_deterministic() -> _bool: ...  # THPModule_deterministicMkldnn
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 4ff88a25bce3..fc3665bac8c9 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -659,6 +659,7 @@
         "torch._C._get_cublas_allow_tf32",
         "torch._C._get_cudnn_allow_tf32",
         "torch._C._get_cudnn_benchmark",
+        "torch._C._get_miopen_immediate",
         "torch._C._get_cudnn_deterministic",
         "torch._C._get_cudnn_enabled",
         "torch._C._get_custom_class_python_wrapper",
diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py
index c5ab7640386a..c02a8c36fd08 100644
--- a/torch/backends/__init__.py
+++ b/torch/backends/__init__.py
@@ -131,6 +131,7 @@ def __init__(self, m, name):
     cusparselt as cusparselt,
     kleidiai as kleidiai,
     mha as mha,
+    miopen as miopen,
     mkl as mkl,
     mkldnn as mkldnn,
     mps as mps,
diff --git a/torch/backends/miopen/__init__.py b/torch/backends/miopen/__init__.py
new file mode 100644
index 000000000000..85f85914a943
--- /dev/null
+++ b/torch/backends/miopen/__init__.py
@@ -0,0 +1,57 @@
+# mypy: allow-untyped-defs
+import sys
+from contextlib import contextmanager
+
+import torch
+from torch.backends import (
+    __allow_nonbracketed_mutation,
+    ContextProp,
+    PropModule,
+)
+
+
+def set_flags(
+    _immediate=None,
+):
+    orig_flags = (
+        torch._C._get_miopen_immediate(),
+    )
+    if _immediate is not None:
+        torch._C._set_miopen_immediate(_immediate)
+    return orig_flags
+
+
+@contextmanager
+def flags(
+    immediate=False,
+):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(
+            immediate,
+        )
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.miopen.immediate = True
+
+
+class MiopenModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    immediate = ContextProp(torch._C._get_miopen_immediate, torch._C._set_miopen_immediate)
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__)
+
+# Add type annotation for the replaced module
+immediate: bool
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 15efa62ae978..b51e1029af3a 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1172,6 +1172,25 @@ static PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
   Py_RETURN_FALSE;
 }
 
+static PyObject* THPModule_setImmediateMiopen(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      PyBool_Check(arg),
+      "set_immediate_miopen expects a bool, "
+      "but got ",
+      THPUtils_typename(arg));
+  at::globalContext().setImmediateMiopen(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPModule_immediateMiopen(PyObject* _unused, PyObject* noargs) {
+  if (at::globalContext().immediateMiopen()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
+
 static PyObject* THPModule_setAllowTF32CuBLAS(
     PyObject* _unused,
     PyObject* arg) {
@@ -1642,6 +1661,8 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     {"_set_onednn_allow_tf32", THPModule_setAllowTF32OneDNN, METH_O, nullptr},
     {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
     {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
+    {"_get_miopen_immediate", THPModule_immediateMiopen, METH_NOARGS, nullptr},
+    {"_set_miopen_immediate", THPModule_setImmediateMiopen, METH_O, nullptr},
     {"_get_cudnn_deterministic",
      THPModule_deterministicCuDNN,
      METH_NOARGS,

From b6aff027c096d4cc5301e821bd7e0cd5ac6403a2 Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Wed, 23 Jul 2025 21:37:57 +0000
Subject: [PATCH 2/3] lint

---
 torch/backends/miopen/__init__.py | 14 +++++---------
 torch/csrc/Module.cpp             |  8 ++++++--
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/torch/backends/miopen/__init__.py b/torch/backends/miopen/__init__.py
index 85f85914a943..93453cc11592 100644
--- a/torch/backends/miopen/__init__.py
+++ b/torch/backends/miopen/__init__.py
@@ -3,19 +3,13 @@
 from contextlib import contextmanager
 
 import torch
-from torch.backends import (
-    __allow_nonbracketed_mutation,
-    ContextProp,
-    PropModule,
-)
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
 
 
 def set_flags(
     _immediate=None,
 ):
-    orig_flags = (
-        torch._C._get_miopen_immediate(),
-    )
+    orig_flags = (torch._C._get_miopen_immediate(),)
     if _immediate is not None:
         torch._C._set_miopen_immediate(_immediate)
     return orig_flags
@@ -46,7 +40,9 @@ class MiopenModule(PropModule):
     def __init__(self, m, name):
         super().__init__(m, name)
 
-    immediate = ContextProp(torch._C._get_miopen_immediate, torch._C._set_miopen_immediate)
+    immediate = ContextProp(
+        torch._C._get_miopen_immediate, torch._C._set_miopen_immediate
+    )
 
 
 # This is the sys.modules replacement trick, see
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index b51e1029af3a..aab2a31402aa 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1172,7 +1172,9 @@ static PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
   Py_RETURN_FALSE;
 }
 
-static PyObject* THPModule_setImmediateMiopen(PyObject* _unused, PyObject* arg) {
+static PyObject* THPModule_setImmediateMiopen(
+    PyObject* _unused,
+    PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK(
       PyBool_Check(arg),
@@ -1184,7 +1186,9 @@ static PyObject* THPModule_setImmediateMiopen(PyObject* _unused, PyObject* arg)
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPModule_immediateMiopen(PyObject* _unused, PyObject* noargs) {
+static PyObject* THPModule_immediateMiopen(
+    PyObject* _unused,
+    PyObject* noargs) {
   if (at::globalContext().immediateMiopen()) {
     Py_RETURN_TRUE;
   }

From 88136b97e783ad4674feae0caa14ff0e518088ff Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Thu, 24 Jul 2025 19:10:29 +0000
Subject: [PATCH 3/3] document new module

---
 docs/source/backends.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/docs/source/backends.md b/docs/source/backends.md
index 6b8cc8bd7072..3e6cdc9697bf 100644
--- a/docs/source/backends.md
+++ b/docs/source/backends.md
@@ -253,6 +253,19 @@ These backends include:
 
 ```
 
+## torch.backends.miopen
+
+```{eval-rst}
+.. automodule:: torch.backends.miopen
+```
+
+```{eval-rst}
+.. attribute:: immediate
+
+    A :class:`bool` that, if True, causes MIOpen to use Immediate Mode
+    (https://rocm.docs.amd.com/projects/MIOpen/en/latest/how-to/find-and-immediate.html).
+```
+
 ## torch.backends.mps
 
 ```{eval-rst}
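
Usage sketch (editorial note, not part of the patch series): assuming the three
patches above are applied to a ROCm build of PyTorch, the new flag can be set
globally through the module attribute or scoped with the flags() context manager
added in torch/backends/miopen/__init__.py. The tensors and shapes below are
illustrative only.

    import torch
    import torch.nn.functional as F

    # On ROCm builds, MIOpen backs convolutions on the "cuda" device.
    x = torch.randn(8, 3, 32, 32, device="cuda")
    w = torch.randn(16, 3, 3, 3, device="cuda")

    # Enable MIOpen Immediate Mode globally.
    torch.backends.miopen.immediate = True
    y = F.conv2d(x, w)

    # Or enable it only for a block; flags() restores the previous value on exit.
    torch.backends.miopen.immediate = False
    with torch.backends.miopen.flags(immediate=True):
        y = F.conv2d(x, w)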