[ROCm] add flag torch.backends.miopen.immediate (#158951) · pytorch/pytorch@98e6c79 · GitHub


Commit 98e6c79

jeffdaily authored and yangw-dev committed
[ROCm] add flag torch.backends.miopen.immediate (#158951)
The MIOpen integration has changed over the years. In the past, the MIOpen default for benchmark was True, and if it was set to False, MIOpen Immediate Mode was used instead. But with #145294 the MIOpen benchmark default changed to False, and to activate immediate mode you would set the deterministic flag to True. This has proved too restrictive, because the benchmark and deterministic flags are independent of immediate mode. Thus, immediate mode needs its own flag.

Though MIOpen still masquerades behind torch.backends.cudnn and its flags, it seemed inappropriate to add an MIOpen-exclusive flag to the set of cudnn flags. This PR adds the first miopen-only flag, which controls its immediate mode.

Pull Request resolved: #158951
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
1 parent e5e1bf2 commit 98e6c79
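As a usage sketch (not part of the commit itself, and only meaningful on a ROCm/MIOpen build), the new flag is driven the same way as the existing cuDNN-style backend flags; the tensor shapes below are illustrative:

    import torch

    # Turn MIOpen Immediate Mode on globally; benchmark and deterministic
    # keep their usual, independent meanings.
    torch.backends.miopen.immediate = True

    # Or scope it with the context manager this PR adds.
    with torch.backends.miopen.flags(immediate=True):
        x = torch.randn(1, 3, 8, 8, device="cuda")  # ROCm tensors use the "cuda" device string
        w = torch.randn(4, 3, 3, 3, device="cuda")
        y = torch.nn.functional.conv2d(x, w)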

9 files changed: 112 additions, 12 deletions

aten/src/ATen/Context.cpp

Lines changed: 8 additions & 0 deletions
@@ -334,6 +334,14 @@ void Context::setBenchmarkLimitCuDNN(int b) {
   benchmark_limit_cudnn = b;
 }
 
+bool Context::immediateMiopen() const {
+  return immediate_miopen;
+}
+
+void Context::setImmediateMiopen(bool b) {
+  immediate_miopen = b;
+}
+
 bool Context::allowTF32CuBLAS() const {
 #ifdef USE_ROCM
   const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);

aten/src/ATen/Context.h

Lines changed: 3 additions & 0 deletions
@@ -205,6 +205,8 @@ class TORCH_API Context {
   void setBenchmarkCuDNN(bool);
   int benchmarkLimitCuDNN() const;
   void setBenchmarkLimitCuDNN(int);
+  bool immediateMiopen() const;
+  void setImmediateMiopen(bool);
   bool deterministicCuDNN() const;
   void setDeterministicCuDNN(bool);
   bool deterministicMkldnn() const;
@@ -440,6 +442,7 @@ class TORCH_API Context {
   bool enabled_overrideable = true;
   bool allow_fp16_bf16_reduction_mathSDP = false;
   bool benchmark_cudnn = false;
+  bool immediate_miopen = false;
   Float32MatmulPrecision float32_matmul_precision =
       c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
           ? at::Float32MatmulPrecision::HIGH

aten/src/ATen/native/miopen/Conv_miopen.cpp

Lines changed: 6 additions & 12 deletions
@@ -724,8 +724,7 @@ void raw_miopen_convolution_forward_out(
   args.odesc.set(output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);
 
@@ -833,8 +832,7 @@ void raw_miopen_depthwise_convolution_forward_out(
   args.odesc.set(output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvFwdAlgorithm_t>(args, &solution_id);
 
@@ -989,8 +987,7 @@ void raw_miopen_convolution_backward_weight_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);
 
@@ -1034,8 +1031,7 @@ void raw_miopen_depthwise_convolution_backward_weight_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdWeightsAlgorithm_t>(args, &solution_id);
 
@@ -1240,8 +1236,7 @@ void raw_miopen_convolution_backward_input_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);
 
@@ -1350,8 +1345,7 @@ void raw_miopen_depthwise_convolution_backward_input_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic);
 
-  if (deterministic && !benchmark) {
-    // immediate mode is triggered for the specific combination of benchmark=off deterministic=on
+  if (at::globalContext().immediateMiopen()) {
     uint64_t solution_id;
     Workspace workspace = chooseSolution<miopenConvBwdDataAlgorithm_t>(args, &solution_id);
 
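The six hunks above all make the same swap: immediate mode is no longer inferred from the benchmark/deterministic combination but read from the new context flag. From the Python side the behavioral difference looks roughly like this (illustrative sketch, ROCm build assumed):

    import torch

    # Before this PR: on ROCm, this exact combination implicitly selected
    # MIOpen Immediate Mode for convolutions.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # After this PR: immediate mode has its own, independent switch.
    torch.backends.miopen.immediate = True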

docs/source/backends.md

Lines changed: 13 additions & 0 deletions
@@ -253,6 +253,19 @@ These backends include:
 
 ```
 
+## torch.backends.miopen
+
+```{eval-rst}
+.. automodule:: torch.backends.miopen
+```
+
+```{eval-rst}
+.. attribute:: immediate
+
+    A :class:`bool` that, if True, causes MIOpen to use Immediate Mode
+    (https://rocm.docs.amd.com/projects/MIOpen/en/latest/how-to/find-and-immediate.html).
+```
+
 ## torch.backends.mps
 
 ```{eval-rst}
torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
@@ -1213,6 +1213,8 @@ def _get_mkldnn_enabled() -> _bool: ...  # THPModule_userEnabledMkldnn
 def _set_mkldnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledMkldnn
 def _get_cudnn_benchmark() -> _bool: ...  # THPModule_benchmarkCuDNN
 def _set_cudnn_benchmark(arg: _bool) -> None: ...  # THPModule_setBenchmarkCuDNN
+def _get_miopen_immediate() -> _bool: ...  # THPModule_userImmediateMiopen
+def _set_miopen_immediate(arg: _bool) -> None: ...  # THPModule_setUserImmediateMiopen
 def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
 def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
 def _get_mkldnn_deterministic() -> _bool: ...  # THPModule_deterministicMkldnn

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
@@ -659,6 +659,7 @@
         "torch._C._get_cublas_allow_tf32",
         "torch._C._get_cudnn_allow_tf32",
         "torch._C._get_cudnn_benchmark",
+        "torch._C._get_miopen_immediate",
         "torch._C._get_cudnn_deterministic",
         "torch._C._get_cudnn_enabled",
         "torch._C._get_custom_class_python_wrapper",

torch/backends/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -131,6 +131,7 @@ def __init__(self, m, name):
     cusparselt as cusparselt,
     kleidiai as kleidiai,
     mha as mha,
+    miopen as miopen,
     mkl as mkl,
     mkldnn as mkldnn,
     mps as mps,

torch/backends/miopen/__init__.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# mypy: allow-untyped-defs
+import sys
+from contextlib import contextmanager
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+def set_flags(
+    _immediate=None,
+):
+    orig_flags = (torch._C._get_miopen_immediate(),)
+    if _immediate is not None:
+        torch._C._set_miopen_immediate(_immediate)
+    return orig_flags
+
+
+@contextmanager
+def flags(
+    immediate=False,
+):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(
+            immediate,
+        )
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.<miopen|mkldnn>.immediate = True
+
+
+class MiopenModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    immediate = ContextProp(
+        torch._C._get_miopen_immediate, torch._C._set_miopen_immediate
+    )
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__)
+
+# Add type annotation for the replaced module
+immediate: bool
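The sys.modules replacement is what lets a plain assignment like torch.backends.miopen.immediate = True be intercepted and forwarded to the C++ context. A minimal standalone sketch of the same pattern, using hypothetical names rather than PyTorch's PropModule/ContextProp helpers:

    import sys
    import types

    _state = {"immediate": False}  # stand-in for the C++ global context


    class _FlagModule(types.ModuleType):
        # A property defined on the module's class makes attribute reads and
        # writes on the (replaced) module object go through these accessors.
        @property
        def immediate(self):
            return _state["immediate"]

        @immediate.setter
        def immediate(self, value):
            _state["immediate"] = bool(value)


    # Swapping the entry in sys.modules means that, after import, setting
    # `thismodule.immediate = True` routes through the property setter.
    sys.modules[__name__] = _FlagModule(__name__)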

torch/csrc/Module.cpp

Lines changed: 25 additions & 0 deletions
@@ -1172,6 +1172,29 @@ static PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
   Py_RETURN_FALSE;
 }
 
+static PyObject* THPModule_setImmediateMiopen(
+    PyObject* _unused,
+    PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      PyBool_Check(arg),
+      "set_immediate_miopen expects a bool, "
+      "but got ",
+      THPUtils_typename(arg));
+  at::globalContext().setImmediateMiopen(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPModule_immediateMiopen(
+    PyObject* _unused,
+    PyObject* noargs) {
+  if (at::globalContext().immediateMiopen()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
+
 static PyObject* THPModule_setAllowTF32CuBLAS(
     PyObject* _unused,
     PyObject* arg) {
@@ -1642,6 +1665,8 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     {"_set_onednn_allow_tf32", THPModule_setAllowTF32OneDNN, METH_O, nullptr},
     {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
     {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
+    {"_get_miopen_immediate", THPModule_immediateMiopen, METH_NOARGS, nullptr},
+    {"_set_miopen_immediate", THPModule_setImmediateMiopen, METH_O, nullptr},
     {"_get_cudnn_deterministic",
      THPModule_deterministicCuDNN,
      METH_NOARGS,
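The two entries registered above are the private bindings that torch.backends.miopen wraps. They can also be exercised directly, which is mainly useful for quick checks; this is an illustrative snippet, the supported surface is the backends module:

    import torch

    # Default comes from `immediate_miopen = false` in Context.h.
    assert torch._C._get_miopen_immediate() is False

    # The setter insists on a real bool (see the PyBool_Check above).
    torch._C._set_miopen_immediate(True)
    assert torch._C._get_miopen_immediate() is True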

0 commit comments
