[Intel GPU][pt2e]: Collapse 3D input to 2D for matmul in qlinear_pointwise_binary fusion #148423


Closed · ZhiweiYan-96 wants to merge 3 commits
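The change collapses a rank-3 activation into a single [M, K] x [K, N] GEMM: the activation is reshaped to 2D (M is the product of the leading dims), the `other` operand of the binary post-op is reshaped the same way, and the 2D result is reshaped back to the original leading dimensions on return. A minimal PyTorch sketch of that reshape pattern, with a plain fp32 `matmul + add` standing in for the fused quantized kernel (all names and shapes here are illustrative assumptions, not code from the PR):

```python
import torch

B, S, K, N = 2, 4, 8, 16        # illustrative batch, sequence, in/out features
act = torch.randn(B, S, K)      # 3D activation
weight = torch.randn(K, N)      # [M, K] x [K, N], as in the kernel comment
other = torch.randn(B, S, N)    # extra input of the binary post-op

x2d = act.reshape(-1, K)        # collapse (B, S, K) -> (B*S, K)
other2d = other.reshape(-1, N)  # collapse the post-op operand the same way
out2d = x2d @ weight + other2d  # stands in for qlinear + binary "add"
out = out2d.reshape(B, -1, N)   # mirrors qout.reshape({act.size(0), -1, N})

assert out.shape == (B, S, N)
```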
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
@@ -116,8 +116,8 @@ void quantized_matmul(
   construct_attr_by_post_op(
       binary_post_op,
       binary_alpha,
-      input_scale,
-      input_zero_point,
+      other_scale,
+      other_zero_point,
       other,
       unary_post_op,
       unary_post_op_args,
61 changes: 24 additions & 37 deletions aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
@@ -164,18 +164,27 @@ Tensor q_linear_pointwise_binary(
   Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
 
   const int64_t dim = act.dim();
+  TORCH_CHECK(
+      dim == 2 || dim == 3,
+      "qlinear_pointwise_binary XPU: input dim should be 2 or 3, but got",
+      dim);
   int64_t K = act.size(dim - 1);
   int64_t M = act.numel() / K;
   // [M, K] x [K, N]
   int64_t N = weight.size(1);
-
+  Tensor input = dim == 3 ? act.reshape({-1, K}) : act;
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
   auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
-  Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
-
+  bool has_accum_postop_sum = (binary_post_op == "sum");
+  if (dim == 3) {
+    other = other.has_value() ? other.value().reshape({-1, N}) : other;
+  }
+  Tensor qout = has_accum_postop_sum
+      ? other.value()
+      : at::empty(dst_dims, act.options().dtype(dst_dtype));
   quantized_matmul(
-      act.contiguous(),
+      input.contiguous(),
       act_scale,
       act_zero_point,
       weight.contiguous(),
@@ -196,7 +205,7 @@ Tensor q_linear_pointwise_binary(
       unary_post_op_algorithm,
       /*m2_trans*/ true);
 
-  return qout;
+  return dim == 3 ? qout.reshape({act.size(0), -1, N}) : qout;
 }
 
 Tensor q_linear_pointwise_binary_tensor(
@@ -218,47 +227,25 @@ Tensor q_linear_pointwise_binary_tensor(
     c10::string_view unary_post_op,
     torch::List<std::optional<at::Scalar>> unary_post_op_args,
     c10::string_view unary_post_op_algorithm) {
-  TORCH_CHECK(
-      act.device() == weight.device() &&
-          act.device() == weight_scales.device() &&
-          act.device() == weight_zero_points.device(),
-      "qlinear xpu: input tensors(act, weight, weight scale, weight zero-points) should be on the same device");
-  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
-
-  const int64_t dim = act.dim();
-  int64_t K = act.size(dim - 1);
-  int64_t M = act.numel() / K;
-  // [M, K] x [K, N]
-  int64_t N = weight.size(1);
-
-  std::vector<int64_t> src_dims = {M, K};
-  std::vector<int64_t> dst_dims = {M, N};
-  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
-  Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
-
-  quantized_matmul(
-      act.contiguous(),
+  return q_linear_pointwise_binary(
+      act,
       act_scale.item().toDouble(),
       act_zero_point.item().toLong(),
-      weight.contiguous(),
+      weight,
       weight_scales,
       weight_zero_points,
-      b_raw,
-      qout,
+      other,
+      bias,
       output_scale,
       output_zero_point,
       output_dtype,
-      /*other*/ other,
-      /*other scale*/ other_scale,
-      /*other zp*/ other_zero_point,
-      /*binary post op*/ binary_post_op,
-      /*binary alpha*/ binary_alpha,
+      other_scale,
+      other_zero_point,
+      binary_post_op,
+      binary_alpha,
       unary_post_op,
       unary_post_op_args,
-      unary_post_op_algorithm,
-      /*m2_trans*/ true);
-
-  return qout;
+      unary_post_op_algorithm);
 }
 
 at::Tensor q_linear_prepack_onednn(
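One detail worth noting in the qlinear.cpp change above: when `binary_post_op == "sum"`, `qout` aliases `other.value()` instead of a freshly allocated tensor, so `other` serves as the accumulation buffer the matmul result is summed into, flattened to 2D for the kernel and reshaped back for the caller. A rough eager-mode sketch of that behavior (illustrative only; the in-place `add_` stands in for oneDNN's sum post-op, and none of these names come from the kernel):

```python
import torch

B, S, K, N = 2, 4, 8, 16
act = torch.randn(B, S, K)
weight = torch.randn(K, N)
other = torch.randn(B, S, N)               # accumulation buffer for the "sum" post-op

other2d = other.reshape(-1, N)             # contiguous tensor -> a view sharing storage
other2d.add_(act.reshape(-1, K) @ weight)  # accumulate the matmul result in place
qout = other2d.reshape(B, -1, N)           # reshape back to 3D for the caller

assert torch.equal(qout, other)            # the original buffer holds the fused output
```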
11 changes: 8 additions & 3 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -2654,11 +2654,12 @@ def forward(self, x):
             lambda x, y: y.add_(x),
         ]
         fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False]
-        cases = itertools.product(add_fn_list, fake_quant_x2_list)
-        for add_fn, fq_x2 in cases:
+        shape_list = [(4, 4), [4, 4, 4]]
+        cases = itertools.product(add_fn_list, fake_quant_x2_list, shape_list)
+        for add_fn, fq_x2, shape in cases:
             mod = M(add_fn, use_relu, fq_x2).eval().to(device=device)
             v = torch.randn(
-                (4, 4), dtype=torch.float32, requires_grad=False, device=device
+                shape, dtype=torch.float32, requires_grad=False, device=device
             ).add(1)
 
             def matcher_check_fn():
@@ -2668,6 +2669,10 @@ def matcher_check_fn():
                 )
                 # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm]
                 nodes_per_match = 6 if int8_mixed_bf16 else 4
+                if len(shape) == 3:
+                    # pattern = [dequant_per_tensor, (convert_dtype), (view), \
+                    #     dequant_per_channel, (convert_dtype), (view), permute, addmm]
+                    nodes_per_match += 2
                 self.assertEqual(
                     counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
                     4 * nodes_per_match,