[Intel GPU] qconv_pointwise.binary XPU support (#135189) · pytorch/pytorch@91a8c8a · GitHub

Commit 91a8c8a

ZhiweiYan-96 authored and guangyey committed
[Intel GPU] qconv_pointwise.binary XPU support (#135189)
# Motivation
This PR intends to enable the quantized fusions `qconv+add` and `qconv+add+relu` on the Intel GPU backend.

At the backend level, we register the op via the schema `TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary")`, which is the one already defined for `x86InductorQuantizer`. At the Inductor level, we make a small modification in `torch/_inductor/fx_passes/quantization.py` to allow the signed int8 data type (s8) during op lowering. As for the pattern matching, we largely reuse the code that already exists for x86InductorQuantizer.

# UT verification
```bash
python test/inductor/test_mkldnn_pattern_matcher.py -v \
  -k test_qconv2d_add_xpu \
  -k test_qconv2d_add_relu_xpu 2>&1
```

# Runtime exemplification
Following is the oneDNN verbose output collected from the UT:
```bash
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_s8::blocked:acdb::f0 wei_s8::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_s8::blocked:acdb::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:1:f32 attr-zero-points:src0:0:s32+dst:0:s32 attr-post-ops:eltwise_linear:1:0.337704+sum:0.0241217+eltwise_relu,alg:convolution_direct,mb1_ic3oc6_ih8oh6kh3sh1dh0ph0_iw8ow6kw3sw1dw0pw0,0.151123
```

Pull Request resolved: #135189
Approved by: https://github.com/liangan1, https://github.com/EikanWang, https://github.com/guangyey, https://github.com/jerryzh168
ghstack dependencies: #133307

Co-authored-by: guangyey <guangye.yu@intel.com>
1 parent 21ac321 commit 91a8c8a
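The `attr-post-ops` field in the verbose line above is the most direct evidence of the fusion: `eltwise_linear:1:0.337704+sum:0.0241217+eltwise_relu` is the `qconv+add+relu` epilogue built by the new `construct_attr_by_post_op` branch. As a rough standalone illustration (not code from this PR; it assumes the oneDNN v3.x C++ API and simply reuses the constants printed by the UT run), the same chain could be expressed directly as:

```cpp
#include <oneapi/dnnl/dnnl.hpp>

// Illustrative sketch: rebuild the post-op chain reported in the verbose log,
//   eltwise_linear:1:0.337704 + sum:0.0241217 + eltwise_relu
// The two constants are the values printed by the UT run above.
dnnl::post_ops make_logged_post_ops() {
  dnnl::post_ops po;
  // linear(x) = 1.0 * x + 0.337704: zero-point compensation for the s8
  // accumulated tensor (the new code computes beta as -accum_zp * accum_scale).
  po.append_eltwise(dnnl::algorithm::eltwise_linear, /*alpha=*/1.0f, /*beta=*/0.337704f);
  // sum with scale 0.0241217: adds the rescaled residual, i.e. qconv + add.
  po.append_sum(/*scale=*/0.0241217f);
  // relu: the fused activation, i.e. qconv + add + relu.
  po.append_eltwise(dnnl::algorithm::eltwise_relu, /*alpha=*/0.f, /*beta=*/0.f);
  return po;
}
```

A `dnnl::primitive_attr` carrying these post-ops (via `set_post_ops`) is what lets the whole conv + dequantized add + relu epilogue run inside a single convolution primitive.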

File tree

7 files changed: +256 −74 lines


aten/src/ATen/native/mkldnn/xpu/detail/Attr.h

Lines changed: 42 additions & 14 deletions

```diff
@@ -177,7 +177,7 @@ class Attr {
       float sum_q_scale = 1.f,
       int64_t zp = 0) {
     ops_params_.push_back(
-        PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum));
+        PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, zp, kind_t::sum));
     return *this;
   }

@@ -261,10 +261,7 @@ class Attr {
     return *this;
   }

-  dnnl::post_ops extract_post_ops(
-      const at::Tensor& dst,
-      bool is_quantized = false,
-      bool int8_output = false) {
+  dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
     // this function is used to extract post ops params from the ops_params_
     // and put them into onednn post ops
     for (size_t i = 0; i < ops_params_.size(); ++i) {
@@ -303,11 +300,6 @@ class Attr {
       }
     }

-    // if output is quantized, then append the eltwise linear to adjust the
-    // output scale/zero_point
-    if (is_quantized && int8_output) {
-      dnnl_post_ops_.append_eltwise(kind_with_linear, q_scale_, q_zero_point_);
-    }
     return dnnl_post_ops_;
   }

@@ -410,6 +402,7 @@ static inline void construct_attr_by_post_op(
     double binary_alpha,
     double input1_scale,
     int64_t input1_zero_point,
+    std::optional<at::Tensor> accum,
     const std::string_view& unary_post_op,
     const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
     const std::string_view& unary_post_op_algorithm,
@@ -418,11 +411,46 @@ static inline void construct_attr_by_post_op(
       (binary_post_op == "none" && unary_post_op == "none"); // not post-ops
   bool is_unary_post_op_only =
       (binary_post_op == "none" && unary_post_op != "none"); // ex., conv + relu
+  bool is_valid_binary_combination =
+      (binary_post_op == "add" || binary_post_op == "sum") &&
+      (unary_post_op == "none" || unary_post_op == "relu");
   TORCH_INTERNAL_ASSERT(
-      is_unary_post_op_only || is_none_post_op,
-      "Currently, quantization backend for Intel GPU only supports convolution or convolution with unary post operation like ReLU");
-  construct_attr_for_unary(
-      unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr);
+      is_unary_post_op_only || is_none_post_op || is_valid_binary_combination,
+      "Please provide valid combination of unary post operators and binary post operators");
+
+  if (binary_post_op == "none") {
+    construct_attr_for_unary(
+        unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr);
+  } else if (binary_post_op == "sum") {
+    if (unary_post_op == "none") {
+      if (input1_zero_point != 0)
+        attr = attr.append_post_eltwise(
+            /*scale*/ 1.f,
+            /*alpha*/ 1.f,
+            -input1_zero_point * input1_scale,
+            attr.kind_with_linear);
+      attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0);
+    } else if (unary_post_op == "relu") {
+      if (input1_zero_point != 0)
+        attr = attr.append_post_eltwise(
+            /*scale*/ 1.f,
+            /*alpha*/ 1.f,
+            -input1_zero_point * input1_scale,
+            attr.kind_with_linear);
+      attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0);
+      attr = attr.append_post_eltwise(
+          /* scale */ 1.f,
+          /* alpha */ 0.f,
+          /* beta */ 0.f,
+          attr.kind_with_relu);
+    }
+  } else if (binary_post_op == "add") {
+    TORCH_CHECK(accum.has_value());
+    attr = attr.append_post_binary(attr.kind_with_binary_add, accum.value());
+    if (unary_post_op == "relu") {
+      attr = attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu);
+    }
+  }
 }

 } // namespace at::native::onednn
```
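The new `add` branch differs from `sum`: rather than rescaling the destination buffer in place, it attaches the accumulated tensor as a genuine binary post-op operand (`append_post_binary(attr.kind_with_binary_add, accum)`), which `construct_post_binary` later binds as an execution argument in QConv.cpp. The sketch below shows roughly what this boils down to against the raw oneDNN v3.x C++ API; the function name and the post-op index are illustrative assumptions, not code from this PR.

```cpp
#include <unordered_map>

#include <oneapi/dnnl/dnnl.hpp>

// Illustrative sketch (not the Attr helper itself): attach "accum" as a
// binary_add post-op operand, optionally followed by a fused ReLU, and bind
// the operand at execution time. Index 0 assumes binary_add is the first
// post-op in the chain.
void attach_binary_add_sketch(
    dnnl::primitive_attr& pattr,
    const dnnl::memory::desc& accum_md,
    const dnnl::memory& accum_mem,
    std::unordered_map<int, dnnl::memory>& exec_args,
    bool fuse_relu) {
  dnnl::post_ops po;
  po.append_binary(dnnl::algorithm::binary_add, accum_md);
  if (fuse_relu)
    po.append_eltwise(dnnl::algorithm::eltwise_relu, /*alpha=*/0.f, /*beta=*/0.f);
  pattr.set_post_ops(po);

  // Unlike "sum", the binary operand is a separate tensor, passed as a
  // runtime argument keyed by its post-op position and SRC_1.
  exec_args.insert(
      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, accum_mem});
}
```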

aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp

Lines changed: 42 additions & 15 deletions

```diff
@@ -11,14 +11,19 @@

 namespace at::native::onednn {

-static std::tuple<dnnl::memory::desc, dnnl::memory::desc, dnnl::memory::desc>
+static std::tuple<
+    dnnl::memory::desc,
+    dnnl::memory::desc,
+    dnnl::memory::desc,
+    dnnl::memory::desc>
 qconv_get_md(
     const at::Tensor& src,
     const at::Tensor& wgh,
+    std::optional<at::Tensor> bias,
     const at::Tensor& dst,
     int64_t groups) {
   // create dnnl::memory desc from the src/wgh/dst tensors
-  dnnl::memory::desc src_usr_md, wgh_usr_md, dst_usr_md;
+  dnnl::memory::desc src_usr_md, wgh_usr_md, dst_usr_md, bias_usr_md;
   auto ndim = src.ndimension();
   bool src_is_cl =
       (src.suggest_memory_format() == at::MemoryFormat::ChannelsLast) ||
@@ -44,7 +49,14 @@ qconv_get_md(
   auto fmt_wgh = conv_weight_fmt(ndim, groups != 1, wgh_is_cl);
   wgh_usr_md = dnnl::memory::desc(wgh_tz, wei_data_t, fmt_wgh);

-  return {src_usr_md, wgh_usr_md, dst_usr_md};
+  if (bias.has_value()) {
+    bias_usr_md = dnnl::memory::desc(
+        bias.value().sizes().vec(),
+        dnnl::memory::data_type::f32,
+        dnnl::memory::format_tag::x);
+  }
+
+  return {src_usr_md, wgh_usr_md, bias_usr_md, dst_usr_md};
 }

 at::Tensor quantized_convolution(
@@ -76,14 +88,12 @@ at::Tensor quantized_convolution(
       Attr(/*q_scale=*/1.0 / inv_output_scale, /*zp=*/output_zero_point);

   auto ndim = act.ndimension();
-  if (bias.has_value()) {
-    attr = attr.append_bias(bias.value(), ndim - 2);
-  }
   construct_attr_by_post_op(
       binary_attr.has_value() ? binary_attr.value() : "none",
       binary_alpha.has_value() ? binary_alpha.value().to<double>() : 1.0,
       accum_scale,
       accum_zero_point,
+      accum,
       unary_attr.has_value() ? unary_attr.value() : "none",
       unary_scalars,
       unary_algorithm.has_value() ? unary_algorithm.value() : "",
@@ -110,10 +120,7 @@ at::Tensor quantized_convolution(
   dnnl::memory::dims _dilation = compatible_dilation(dilation);
   dnnl::post_ops po;
   // extract post ops
-  po = attr.extract_post_ops(
-      output,
-      /*is_quantized*/ true,
-      output.scalar_type() == at::kByte || output.scalar_type() == at::kChar);
+  po = attr.extract_post_ops(output);
   int mask_ac = 0, mask_weight;
   // [Note: Per-channel quantization mask setting]
   // Per-channel quantization is on weight output channel mostly, mask_weight=
@@ -127,10 +134,11 @@ at::Tensor quantized_convolution(
   dnnl::primitive_attr pattr;

   bool src_need_zp = (act_scale != 0);
+  bool dst_need_zp = (output_zero_point != 0);

   // create usr_md for tensors, and md for conv primitive
-  auto [src_md, weight_md, output_md] =
-      qconv_get_md(act, weight, output, groups);
+  auto [src_md, weight_md, bias_md, output_md] =
+      qconv_get_md(act, weight, bias, output, groups);

   // get tensor md
   auto ic = act.size(1);
@@ -139,11 +147,14 @@ at::Tensor quantized_convolution(
       compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());

   pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
+  pattr.set_scales_mask(DNNL_ARG_DST, mask_ac);
   pattr.set_scales_mask(DNNL_ARG_WEIGHTS, mask_weight);
   pattr.set_post_ops(po);

   if (src_need_zp)
     pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
+  if (dst_need_zp)
+    pattr.set_zero_points_mask(DNNL_ARG_DST, mask_ac);
   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

   // create primitive
@@ -153,7 +164,7 @@ at::Tensor quantized_convolution(
       dnnl::algorithm::convolution_direct,
       src_md,
       weight_md,
-      dnnl::memory::desc(),
+      bias.has_value() ? bias_md : dnnl::memory::desc(),
       output_md,
       _stride,
       _dilation,
@@ -164,18 +175,24 @@ at::Tensor quantized_convolution(
   dnnl::convolution_forward conv_forward =
       dnnl::convolution_forward(conv_fwd_pd);

-  dnnl::memory src_m, weight_m, output_m;
+  dnnl::memory src_m, weight_m, output_m, bias_m;

   src_m = make_onednn_memory(src_md, engine, act.data_ptr());
   output_m = make_onednn_memory(output_md, engine, output.data_ptr());
   weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
+  if (bias.has_value()) {
+    bias_m = make_onednn_memory(bias_md, engine, bias.value().data_ptr());
+  }

   std::unordered_map<int, dnnl::memory> args;
   if (attr.with_binary())
     attr.construct_post_binary(conv_fwd_pd, args);
   args.insert({DNNL_ARG_SRC, src_m});
   args.insert({DNNL_ARG_WEIGHTS, weight_m});
   args.insert({DNNL_ARG_DST, output_m});
+  if (bias.has_value()) {
+    args.insert({DNNL_ARG_BIAS, bias_m});
+  }

   dnnl::memory src_sc_m, src_zp_m;
   Tensor src_sc_tensor, src_zp_tensor;
@@ -188,7 +205,17 @@ at::Tensor quantized_convolution(
     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_m});
   }

-  // dst scale is no need for setting, since it is fused in postop via linear
+  dnnl::memory dst_sc_m, dst_zp_m;
+  Tensor dst_sc_tensor, dst_zp_tensor;
+  dst_sc_m = dnnl_memory_from_host_scalar(
+      static_cast<float>(inv_output_scale), dst_sc_tensor, engine);
+  dst_zp_m = dnnl_memory_from_host_scalar(
+      static_cast<int32_t>(output_zero_point), dst_zp_tensor, engine);
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
+  if (dst_need_zp) {
+    args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_m});
+  }
+
   size_t scratchpad_size = conv_fwd_pd.scratchpad_desc().get_size();
   Tensor scratchpad_tensor = at::empty(
       {static_cast<int64_t>(scratchpad_size)},
```
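Two parameterization changes stand out in this file: the bias now goes in as a real `DNNL_ARG_BIAS` operand (previously it was folded in through `attr.append_bias`), and the destination scale/zero point are declared on the primitive attributes and bound as runtime arguments (previously folded into a trailing `eltwise_linear` post-op). The condensed sketch below, assuming the oneDNN v3.x C++ API, isolates just that wiring; names are illustrative and memory creation is elided.

```cpp
#include <unordered_map>

#include <oneapi/dnnl/dnnl.hpp>

// Condensed, illustrative sketch of the new dst-quantization and bias wiring.
// In the real code the mask calls happen before the primitive descriptor is
// created and the argument insertions happen right before execution; they are
// grouped into one function here only for brevity.
void wire_dst_quant_and_bias_sketch(
    dnnl::primitive_attr& pattr,
    std::unordered_map<int, dnnl::memory>& exec_args,
    const dnnl::memory& dst_scale_mem, // 1-element f32 (inv_output_scale here)
    const dnnl::memory& dst_zp_mem,    // 1-element s32 (output_zero_point)
    const dnnl::memory& bias_mem,      // f32 bias, format_tag::x
    bool has_bias,
    bool dst_need_zp) {
  // mask = 0 -> a single per-tensor destination scale / zero point.
  pattr.set_scales_mask(DNNL_ARG_DST, /*mask=*/0);
  if (dst_need_zp)
    pattr.set_zero_points_mask(DNNL_ARG_DST, /*mask=*/0);

  // The actual values are supplied as runtime arguments rather than being
  // baked into the primitive descriptor.
  exec_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_mem});
  if (dst_need_zp)
    exec_args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_mem});

  // Bias is now a first-class convolution argument, not an attr post-op.
  if (has_bias)
    exec_args.insert({DNNL_ARG_BIAS, bias_mem});
}
```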

aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp

Lines changed: 18 additions & 4 deletions

```diff
@@ -118,6 +118,7 @@ void quantized_matmul(
       binary_alpha,
       input_scale,
       input_zero_point,
+      other,
       unary_post_op,
       unary_post_op_args,
       unary_post_op_algorithm,
@@ -210,11 +211,9 @@ void quantized_matmul(
   std::unordered_map<int, dnnl::memory> args;

   dnnl::post_ops po;
-  po = attr.extract_post_ops(
-      dst,
-      true,
-      dst.scalar_type() == at::kByte || dst.scalar_type() == at::kChar);
+  po = attr.extract_post_ops(dst);
   bool m1_need_zp = (input_zero_point != 0);
+  bool dst_need_zp = (output_zero_point != 0);
   bool wgh_is_per_channel = weight_scales.numel() > 1;

   dnnl::matmul matmul_p;
@@ -242,6 +241,10 @@ void quantized_matmul(
   if (m1_need_zp) {
     pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
   }
+  pattr.set_scales_mask(DNNL_ARG_DST, mask_ac);
+  if (dst_need_zp) {
+    pattr.set_zero_points_mask(DNNL_ARG_DST, mask_ac);
+  }

   if (with_bias) {
     b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
@@ -309,6 +312,17 @@ void quantized_matmul(
     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m});
   }

+  dnnl::memory dst_sc_m, dst_zp_m;
+  Tensor dst_sc_tensor, dst_zp_tensor;
+  dst_sc_m = dnnl_memory_from_host_scalar(
+      static_cast<float>(output_scale), dst_sc_tensor, engine);
+  dst_zp_m = dnnl_memory_from_host_scalar(
+      static_cast<int32_t>(output_zero_point), dst_zp_tensor, engine);
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
+  if (dst_need_zp) {
+    args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_m});
+  }
+
   auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args);

   if (!dst.is_same(result))
```
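Both QConv.cpp and QMatmul.cpp rely on a `dnnl_memory_from_host_scalar` helper to turn the host-side output scale/zero point into one-element oneDNN memories bound to `DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST` and `DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST`. The helper's implementation is not part of this diff; the sketch below is a hypothetical equivalent (names and signature are assumptions) that also shows why a separate `Tensor` holder is declared at each call site: it keeps the one-element device allocation alive until the primitive has executed.

```cpp
#include <ATen/ATen.h>

#include <oneapi/dnnl/dnnl.hpp>

// Hypothetical equivalent of the dnnl_memory_from_host_scalar helper used
// above (assumed, not taken from this diff): materialize a host scalar as a
// one-element tensor on the target device and wrap that allocation in a
// dnnl::memory. The caller-owned "holder" (dst_sc_tensor / dst_zp_tensor in
// the diffs) owns the storage.
template <typename T>
dnnl::memory memory_from_host_scalar_sketch(
    T value,
    at::Tensor& holder,
    const dnnl::engine& engine,
    const at::TensorOptions& opts,  // device/dtype matching `dt`
    dnnl::memory::data_type dt) {   // e.g. f32 for scales, s32 for zero points
  // One-element tensor on the target device holding the scalar.
  holder = at::full({1}, value, opts);
  dnnl::memory::desc md({1}, dt, dnnl::memory::format_tag::x);
  // Wrap the existing device allocation; no extra copy is made here.
  return dnnl::memory(md, engine, holder.data_ptr());
}
```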
