[Intel GPU] allow_tf32 context at XPU backend · pytorch/pytorch@e4ba41a · GitHub

Commit e4ba41a

[Intel GPU] allow_tf32 context at XPU backend
ghstack-source-id: 7c4a906
Pull Request resolved: #137570
1 parent ac8954d commit e4ba41a
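
This commit threads a single allow_tf32 flag from the global ATen Context, through new torch._C bindings, up to the torch.backends.mkldnn module, and finally into the oneDNN convolution primitives used by the XPU backend. A minimal sketch of the resulting user-facing control (assumes a PyTorch build that includes this commit and an Intel GPU runtime):

import torch

# Allow oneDNN to execute fp32 convolutions in tf32 on Intel GPUs
torch.backends.mkldnn.allow_tf32 = True

# ...or scope the change; the previous flags are restored on exit
with torch.backends.mkldnn.flags(enabled=None, deterministic=None, allow_tf32=False):
    pass  # convolutions in this block run in full fp32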

File tree

7 files changed: +80 -5 lines changed

aten/src/ATen/Context.cpp

Lines changed: 8 additions & 0 deletions

@@ -121,6 +121,14 @@ void Context::setAllowTF32CuDNN(bool b) {
   allow_tf32_cudnn = b;
 }
 
+bool Context::allowTF32Mkldnn() const {
+  return allow_tf32_mkldnn;
+}
+
+void Context::setAllowTF32Mkldnn(bool b) {
+  allow_tf32_mkldnn = b;
+}
+
 bool Context::userEnabledFlashSDP() const {
   return enabled_flashSDP;
 }

aten/src/ATen/Context.h

Lines changed: 3 additions & 0 deletions

@@ -312,6 +312,8 @@ class TORCH_API Context {
   void setFloat32MatmulPrecision(const std::string& s);
   bool allowTF32CuDNN() const;
   void setAllowTF32CuDNN(bool);
+  bool allowTF32Mkldnn() const;
+  void setAllowTF32Mkldnn(bool);
   bool allowTF32CuBLAS() const;
   void setAllowTF32CuBLAS(bool);
   Float32MatmulPrecision float32MatmulPrecision() const;

@@ -369,6 +371,7 @@ class TORCH_API Context {
   bool allow_fp16_reduction_cublas = true;
   bool allow_bf16_reduction_cublas = true;
   bool enabled_mkldnn = true;
+  bool allow_tf32_mkldnn = true;
   bool enabled_nnpack = true;
   at::LinalgBackend linalg_preferred_backend =
       c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
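
The new Context members mirror the existing CuDNN TF32 pair, and allow_tf32_mkldnn defaults to true, so TF32 is permitted unless a user turns it off. A quick way to confirm the default from Python (a sketch, assuming this commit is built):

import torch

# allow_tf32_mkldnn is initialized to true in Context.h above,
# so the getter reports True in a fresh process
assert torch.backends.mkldnn.allow_tf32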

aten/src/ATen/native/mkldnn/xpu/detail/Conv.cpp

Lines changed: 19 additions & 0 deletions

@@ -201,6 +201,12 @@ sycl::event convolution(
   }
 #endif
 
+  auto& ctx = at::globalContext();
+  bool allow_tf32 = ctx.allowTF32Mkldnn();
+  if (allow_tf32) {
+    pattr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
+  }
+
   auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
       engine,
       dnnl::prop_kind::forward,

@@ -288,6 +294,12 @@ sycl::event convolution_backward_weights(
   }
 #endif
 
+  auto& ctx = at::globalContext();
+  bool allow_tf32 = ctx.allowTF32Mkldnn();
+  if (allow_tf32) {
+    pattr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
+  }
+
   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
   auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
       engine,

@@ -390,6 +402,13 @@ sycl::event convolution_backward_data(
   dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
   dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
   dnnl::memory::dims _dilation = compatible_dilation(dilation);
+
+  auto& ctx = at::globalContext();
+  bool allow_tf32 = ctx.allowTF32Mkldnn();
+  if (allow_tf32) {
+    pattr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
+  }
+
   auto conv_forward_pd = dnnl::convolution_forward::primitive_desc(
       engine,
       dnnl::prop_kind::forward,
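
In each of the three convolution paths, the flag is read from the global context and, when set, the primitive attributes request dnnl::fpmath_mode::tf32, letting oneDNN compute fp32 convolutions in tf32 (fp32 range, roughly 10 mantissa bits). A hedged sketch of the observable effect at the Python level (illustrative, not from the commit; assumes an XPU build on hardware with tf32 support):

import torch
import torch.nn.functional as F

x = torch.randn(8, 64, 56, 56, device="xpu")
w = torch.randn(128, 64, 3, 3, device="xpu")

torch.backends.mkldnn.allow_tf32 = False
ref = F.conv2d(x, w)   # full fp32 math

torch.backends.mkldnn.allow_tf32 = True
fast = F.conv2d(x, w)  # oneDNN may now use tf32 internally

# tf32 trims mantissa bits, so expect small, bounded differences
print((ref - fast).abs().max())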

test/xpu/test_conv.py

Lines changed: 11 additions & 0 deletions

@@ -1264,6 +1264,17 @@ def test_channels_last_ouput_stride(self, device, dtype):
         # input NHWC, output NHWC
         assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))
 
+    @onlyXPU
+    def test_mkldnn_allow_tf32_get_set(self, device):
+        with torch.backends.mkldnn.flags(
+            enabled=None, deterministic=None, allow_tf32=False
+        ):
+            self.assertFalse(torch.backends.mkldnn.allow_tf32)
+        with torch.backends.mkldnn.flags(
+            enabled=None, deterministic=None, allow_tf32=True
+        ):
+            self.assertTrue(torch.backends.mkldnn.allow_tf32)
+
 
 instantiate_device_type_tests(
     TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
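
The same round-trip can be checked outside the device-type test harness in a plain script (a minimal sketch mirroring the test above):

import torch

with torch.backends.mkldnn.flags(enabled=None, deterministic=None, allow_tf32=False):
    assert not torch.backends.mkldnn.allow_tf32
with torch.backends.mkldnn.flags(enabled=None, deterministic=None, allow_tf32=True):
    assert torch.backends.mkldnn.allow_tf32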

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions

@@ -1170,6 +1170,8 @@ def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
 def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
 def _get_mkldnn_deterministic() -> _bool: ...  # THPModule_deterministicMkldnn
 def _set_mkldnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicMkldnn
+def _get_mkldnn_allow_tf32() -> _bool: ...  # THPModule_allowTF32Mkldnn
+def _set_mkldnn_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32Mkldnn
 def _get_deterministic_algorithms() -> _bool: ...  # THPModule_deterministicAlgorithms
 def _get_deterministic_algorithms_warn_only() -> _bool: ...  # THPModule_deterministicAlgorithmsWarnOnly
 def _set_deterministic_algorithms(
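
These stubs describe the raw bindings registered in torch/csrc/Module.cpp below. They can be exercised directly, though torch.backends.mkldnn is the intended surface (a sketch):

import torch

prev = torch._C._get_mkldnn_allow_tf32()
torch._C._set_mkldnn_allow_tf32(False)
assert torch._C._get_mkldnn_allow_tf32() is False
torch._C._set_mkldnn_allow_tf32(prev)  # restore the previous value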

torch/backends/mkldnn/__init__.py

Lines changed: 16 additions & 5 deletions

@@ -64,18 +64,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return False
 
 
-def set_flags(_enabled, _deterministic=None):
-    orig_flags = (torch._C._get_mkldnn_enabled(), torch._C._get_mkldnn_deterministic())
-    torch._C._set_mkldnn_enabled(_enabled)
+def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
+    orig_flags = (
+        torch._C._get_mkldnn_enabled(),
+        torch._C._get_mkldnn_deterministic(),
+        torch._C._get_mkldnn_allow_tf32(),
+    )
+    if _enabled is not None:
+        torch._C._set_mkldnn_enabled(_enabled)
     if _deterministic is not None:
         torch._C._set_mkldnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_mkldnn_allow_tf32(_allow_tf32)
     return orig_flags
 
 
 @contextmanager
-def flags(enabled=False, deterministic=False):
+def flags(enabled=False, deterministic=False, allow_tf32=True):
     with __allow_nonbracketed_mutation():
-        orig_flags = set_flags(enabled, deterministic)
+        orig_flags = set_flags(enabled, deterministic, allow_tf32)
     try:
         yield
     finally:

@@ -91,10 +98,14 @@ def __init__(self, m, name):
     deterministic = ContextProp(
         torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic
     )
+    allow_tf32 = ContextProp(
+        torch._C._get_mkldnn_allow_tf32, torch._C._set_mkldnn_allow_tf32
+    )
 
 
 if TYPE_CHECKING:
     enabled: ContextProp
     deterministic: ContextProp
+    allow_tf32: ContextProp
 
 sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
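
With the module wrapper in place, the flag is exposed both as an attribute (via ContextProp) and through the flags() context manager; set_flags() now treats None as "leave unchanged" for every flag and returns the previous triple so the context manager can restore it. A short usage sketch:

import torch

# Process-wide toggle through the new ContextProp attribute
torch.backends.mkldnn.allow_tf32 = False
assert not torch.backends.mkldnn.allow_tf32

# Scoped override: set_flags() snapshots (enabled, deterministic, allow_tf32)
# and the snapshot is restored when the block exits
with torch.backends.mkldnn.flags(enabled=None, deterministic=None, allow_tf32=True):
    assert torch.backends.mkldnn.allow_tf32
assert not torch.backends.mkldnn.allow_tf32  # restored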

torch/csrc/Module.cpp

Lines changed: 21 additions & 0 deletions

@@ -888,6 +888,25 @@ PyObject* THPModule_setDeterministicAlgorithms(
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPModule_setAllowTF32Mkldnn(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      PyBool_Check(arg),
+      "set_allow_tf32_mkldnn expects a bool, "
+      "but got ",
+      THPUtils_typename(arg));
+  at::globalContext().setAllowTF32Mkldnn(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THPModule_allowTF32Mkldnn(PyObject* _unused, PyObject* noargs) {
+  if (at::globalContext().allowTF32Mkldnn())
+    Py_RETURN_TRUE;
+  else
+    Py_RETURN_FALSE;
+}
+
 PyObject* THPModule_deterministicAlgorithms(
     PyObject* _unused,
     PyObject* noargs) {

@@ -1410,6 +1429,8 @@ static PyMethodDef TorchMethods[] = { // NOLINT
     {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
     {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
     {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
+    {"_get_mkldnn_allow_tf32", THPModule_allowTF32Mkldnn, METH_NOARGS, nullptr},
+    {"_set_mkldnn_allow_tf32", THPModule_setAllowTF32Mkldnn, METH_O, nullptr},
     {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
     {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
     {"_get_cudnn_deterministic",
