[Intel GPU] allow_tf32 for oneDNN backend - XPU part (#137570)

Commit ae351d4 · pytorch/pytorch

ZhiweiYan-96 authored and guangyey committed
# Motivation

Add the context variable `torch.backends.mkldnn.allow_tf32` to control tf32 computation in convolution kernels on the XPU side. The tf32 data type can improve the performance of deep learning workloads during training/inference. This PR uses the oneDNN [fpmath_mode API](https://oneapi-src.github.io/oneDNN/dev_guide_attributes_fpmath_mode.html#the-floating-point-math-mode-attribute) to trigger tf32 acceleration in convolution kernels.

# Validation

* Unit test for the context variable: `python test/xpu/test_conv.py -k test_onednn_allow_tf32_get_set`
* Runtime exemplification:

```
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic16oc33_ih50oh24kh3sh2dh0ph0_iw100ow49kw3sw2dw0pw0,0.649902
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.151855
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_data,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_undef::undef::: dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.167969
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_weights,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.26709
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_weights,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic16oc33_ih50oh24kh3sh2dh0ph0_iw100ow49kw3sw2dw0pw0,0.219971
```

The `attr-fpmath:tf32` field in the verbose output shows that the context setting successfully triggers tf32 computation in the conv forward, backward_data, and backward_weights kernels.

Pull Request resolved: #137570

Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/atalman, https://github.com/malfet

Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
1 parent: 198ffbd
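For orientation, here is a minimal usage sketch of the new flag (a hypothetical script, not part of this commit, assuming a PyTorch build with Intel GPU/XPU support and an available `xpu` device); running it with `ONEDNN_VERBOSE=1` should reproduce the `attr-fpmath:tf32` entries quoted above:

```python
# Usage sketch, assuming an XPU-enabled PyTorch build.
# Run as: ONEDNN_VERBOSE=1 python tf32_conv_demo.py
import torch

# The context variable added by this PR; defaults to False.
torch.backends.mkldnn.allow_tf32 = True

# Shapes chosen to match the first verbose line above (mb20, ic16, oc33, 50x100 input).
conv = torch.nn.Conv2d(16, 33, kernel_size=3, stride=2).to("xpu")
x = torch.randn(20, 16, 50, 100, device="xpu", requires_grad=True)

out = conv(x)         # forward kernel should log attr-fpmath:tf32
out.sum().backward()  # backward_data / backward_weights kernels likewise
```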

File tree: 9 files changed (+89, −5 lines)

aten/src/ATen/Context.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -137,6 +137,18 @@ std::array<at::SDPBackend, at::num_sdp_backends> Context::sDPPriorityOrder() {
   return sdp_priority_order;
 }
 
+bool Context::allowTF32OneDNN() const {
+  return allow_tf32_onednn;
+}
+
+void Context::setAllowTF32OneDNN(bool b) {
+#ifdef USE_XPU
+  allow_tf32_onednn = b;
+#else
+  TORCH_WARN("TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support.");
+#endif
+}
+
 bool Context::userEnabledFlashSDP() const {
   return enabled_flashSDP;
 }
```
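Note that on builds compiled without `USE_XPU`, the setter above leaves the flag untouched and only emits a warning. A hedged sketch of how that surfaces in Python (assuming a CPU-only build, and that `TORCH_WARN` is delivered to Python as a `UserWarning`, as is usual for PyTorch bindings):

```python
# Sketch only: behavior on a PyTorch build without Intel GPU (XPU) support.
import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.backends.mkldnn.allow_tf32 = True  # hits the #else branch above
    for w in caught:
        print(w.message)  # "TF32 acceleration on top of oneDNN is available ..."

# The stored value is unchanged on such builds:
print(torch.backends.mkldnn.allow_tf32)  # False
```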

aten/src/ATen/Context.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -333,6 +333,8 @@ class TORCH_API Context {
   void setFloat32MatmulPrecision(const std::string& s);
   bool allowTF32CuDNN() const;
   void setAllowTF32CuDNN(bool);
+  bool allowTF32OneDNN() const;
+  void setAllowTF32OneDNN(bool);
   bool allowTF32CuBLAS() const;
   void setAllowTF32CuBLAS(bool);
   Float32MatmulPrecision float32MatmulPrecision() const;
@@ -422,6 +424,7 @@ class TORCH_API Context {
   bool allow_bf16_reduction_cublas = true;
   bool allow_fp16_accumulation_cublas = false;
   bool enabled_mkldnn = true;
+  bool allow_tf32_onednn = false;
   bool enabled_nnpack = true;
   at::LinalgBackend linalg_preferred_backend =
       c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
```

aten/src/ATen/native/mkldnn/xpu/detail/Conv.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -120,6 +120,8 @@ sycl::event convolution(
   }
 #endif
 
+  at::native::onednn::apply_tf32_if_allowed(pattr);
+
   auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
       engine,
       dnnl::prop_kind::forward,
@@ -211,6 +213,8 @@ sycl::event convolution_backward_weights(
   }
 #endif
 
+  at::native::onednn::apply_tf32_if_allowed(pattr);
+
   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
   auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
       engine,
@@ -319,6 +323,9 @@ sycl::event convolution_backward_data(
   dnnl::memory::dims _padding_back_bottom_right =
       padding_back_bottom_right.vec();
   dnnl::memory::dims _dilation = compatible_dilation(dilation);
+
+  at::native::onednn::apply_tf32_if_allowed(pattr);
+
   auto conv_forward_pd = dnnl::convolution_forward::primitive_desc(
       engine,
       dnnl::prop_kind::forward,
```

aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -1,5 +1,8 @@
+#include <ATen/Context.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/mkldnn/xpu/detail/Utils.h>
+#include <dnnl.hpp>
+#include <dnnl_common.hpp>
 
 namespace at::native::onednn {
 
@@ -487,4 +490,12 @@ dnnl::memory::format_tag conv_weight_fmt(
   }
 }
 
+void apply_tf32_if_allowed(dnnl::primitive_attr& pattr) {
+  auto& ctx = at::globalContext();
+  bool allow_tf32 = ctx.allowTF32OneDNN();
+  if (allow_tf32) {
+    pattr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
+  }
+}
+
 } // namespace at::native::onednn
```

aten/src/ATen/native/mkldnn/xpu/detail/Utils.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -49,6 +49,8 @@ bool is_broadcast_from_other_to_self(
 
 at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim);
 
+void apply_tf32_if_allowed(dnnl::primitive_attr& primitive_attr);
+
 bool binary_valid(
     const at::Tensor& self,
     const at::Tensor& other,
```

test/xpu/test_conv.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -1258,6 +1258,17 @@ def test_channels_last_ouput_stride(self, device, dtype):
         # input NHWC, output NHWC
         assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))
 
+    @onlyXPU
+    def test_onednn_allow_tf32_get_set(self):
+        with torch.backends.mkldnn.flags(
+            enabled=None, deterministic=None, allow_tf32=False
+        ):
+            self.assertFalse(torch.backends.mkldnn.allow_tf32)
+        with torch.backends.mkldnn.flags(
+            enabled=None, deterministic=None, allow_tf32=True
+        ):
+            self.assertTrue(torch.backends.mkldnn.allow_tf32)
+
 
 instantiate_device_type_tests(
     TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
```

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
```diff
@@ -1181,6 +1181,8 @@ def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
 def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
 def _get_mkldnn_deterministic() -> _bool: ... # THPModule_deterministicMkldnn
 def _set_mkldnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicMkldnn
+def _get_onednn_allow_tf32() -> _bool: ... # THPModule_allowTF32OneDNN
+def _set_onednn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32OneDNN
 def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
 def _get_deterministic_algorithms_warn_only() -> _bool: ... # THPModule_deterministicAlgorithmsWarnOnly
 def _set_deterministic_algorithms(
```
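These private `torch._C` bindings back the public `torch.backends.mkldnn.allow_tf32` property. A quick round-trip sketch (internal API, shown for illustration only):

```python
import torch

# Save, set, read back, and restore via the private bindings declared above.
prev = torch._C._get_onednn_allow_tf32()
torch._C._set_onednn_allow_tf32(True)  # on non-XPU builds this only warns
print(torch._C._get_onednn_allow_tf32())  # True on XPU-enabled builds
torch._C._set_onednn_allow_tf32(prev)
```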

torch/backends/mkldnn/__init__.py

Lines changed: 16 additions & 5 deletions
```diff
@@ -64,18 +64,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return False
 
 
-def set_flags(_enabled, _deterministic=None):
-    orig_flags = (torch._C._get_mkldnn_enabled(), torch._C._get_mkldnn_deterministic())
-    torch._C._set_mkldnn_enabled(_enabled)
+def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
+    orig_flags = (
+        torch._C._get_mkldnn_enabled(),
+        torch._C._get_mkldnn_deterministic(),
+        torch._C._get_onednn_allow_tf32(),
+    )
+    if _enabled is not None:
+        torch._C._set_mkldnn_enabled(_enabled)
     if _deterministic is not None:
         torch._C._set_mkldnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_onednn_allow_tf32(_allow_tf32)
     return orig_flags
 
 
 @contextmanager
-def flags(enabled=False, deterministic=False):
+def flags(enabled=False, deterministic=False, allow_tf32=True):
     with __allow_nonbracketed_mutation():
-        orig_flags = set_flags(enabled, deterministic)
+        orig_flags = set_flags(enabled, deterministic, allow_tf32)
     try:
         yield
     finally:
@@ -91,10 +98,14 @@ def __init__(self, m, name):
     deterministic = ContextProp(
         torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic
     )
+    allow_tf32 = ContextProp(
+        torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32
+    )
 
 
 if TYPE_CHECKING:
     enabled: ContextProp
     deterministic: ContextProp
+    allow_tf32: ContextProp
 
 sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
```
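With the extended signature, a caller can scope tf32 to a block: `flags()` snapshots all three values via `set_flags` and restores them on exit. A short usage sketch:

```python
import torch

# Temporarily force fp32-strict convolutions; passing None for the other
# flags leaves them unchanged, and the originals are restored on exit.
with torch.backends.mkldnn.flags(enabled=None, deterministic=None, allow_tf32=False):
    assert not torch.backends.mkldnn.allow_tf32
    # ... run numerically sensitive convolutions here ...
```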

torch/csrc/Module.cpp

Lines changed: 25 additions & 0 deletions
```diff
@@ -947,6 +947,29 @@ static PyObject* THPModule_setDeterministicAlgorithms(
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPModule_setAllowTF32OneDNN(
+    PyObject* _unused,
+    PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      PyBool_Check(arg),
+      "_set_onednn_allow_tf32 expects a bool, "
+      "but got ",
+      THPUtils_typename(arg));
+  at::globalContext().setAllowTF32OneDNN(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPModule_allowTF32OneDNN(
+    PyObject* _unused,
+    PyObject* noargs) {
+  if (at::globalContext().allowTF32OneDNN())
+    Py_RETURN_TRUE;
+  else
+    Py_RETURN_FALSE;
+}
+
 static PyObject* THPModule_deterministicAlgorithms(
     PyObject* _unused,
     PyObject* noargs) {
@@ -1527,6 +1550,8 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
     {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
     {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
+    {"_get_onednn_allow_tf32", THPModule_allowTF32OneDNN, METH_NOARGS, nullptr},
+    {"_set_onednn_allow_tf32", THPModule_setAllowTF32OneDNN, METH_O, nullptr},
     {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
     {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
     {"_get_cudnn_deterministic",
```
