10000 [Intel GPU] qlinear.pointwise with mixed dtype support · pytorch/pytorch@349ab48 · GitHub
[go: up one dir, main page]

Skip to content

Commit 349ab48

Browse files
committed
[Intel GPU] qlinear.pointwise with mixed dtype support
ghstack-source-id: 9d217b7 Pull Request resolved: #136753
1 parent 22fffb3 commit 349ab48

File tree

2 files changed

+315
-30
lines changed

2 files changed

+315
-30
lines changed

aten/src/ATen/native/mkldnn/xpu/qlinear.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ using namespace at::native::onednn;
77

88
namespace at::native::xpu {
99

10+
static inline c10::ScalarType qlinear_decide_out_dtype(
11+
const at::Tensor& act,
12+
const std::optional<c10::ScalarType> output_dtype) {
13+
bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
14+
bool bfloat16_output =
15+
output_dtype.has_value() && (output_dtype == c10::kBFloat16);
16+
auto dst_dtype = fp32_output
17+
? c10::kFloat
18+
: (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
19+
return dst_dtype;
20+
}
21+
1022
Tensor q_linear_pointwise(
1123
Tensor act,
1224
double act_scale,
@@ -37,9 +49,9 @@ Tensor q_linear_pointwise(
3749

3850
std::vector<int64_t> src_dims = {M, K};
3951
std::vector<int64_t> dst_dims = {M, N};
40-
auto out_dtype =
41-
output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
42-
Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
52+
53+
auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
54+
Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
4355

4456
quantized_matmul(
4557
act.contiguous(),
@@ -96,9 +108,9 @@ Tensor q_linear_pointwise_tensor(
96108

97109
std::vector<int64_t> src_dims = {M, K};
98110
std::vector<int64_t> dst_dims = {M, N};
99-
auto out_dtype =
100-
output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
101-
Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
111+
112+
auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
113+
Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
102114

103115
quantized_matmul(
104116
act.contiguous(),
@@ -159,9 +171,8 @@ Tensor q_linear_pointwise_binary(
159171

160172
std::vector<int64_t> src_dims = {M, K};
161173
std::vector<int64_t> dst_dims = {M, N};
162-
auto out_dtype =
163-
output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
164-
Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
174+
auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
175+
Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
165176

166177
quantized_matmul(
167178
act.contiguous(),
@@ -222,9 +233,8 @@ Tensor q_linear_pointwise_binary_tensor(
222233

223234
std::vector<int64_t> src_dims = {M, K};
224235
std::vector<int64_t> dst_dims = {M, N};
225-
auto out_dtype =
226-
output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
227-
Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
236+
auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
237+
Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
228238

229239
quantized_matmul(
230240
act.contiguous(),

0 commit comments

Comments (0)
0