[Intel GPU][pt2e]: Collapse 3D input to 2D for matmul in qlinear_pointwise_binary fusion · Pull Request #148423 · pytorch/pytorch


Closed · wants to merge 3 commits
Changes from 1 commit
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
@@ -116,8 +116,8 @@ void quantized_matmul(
construct_attr_by_post_op(
binary_post_op,
binary_alpha,
input_scale,
input_zero_point,
other_scale,
other_zero_point,
other,
unary_post_op,
unary_post_op_args,
74 changes: 28 additions & 46 deletions aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
@@ -164,18 +164,22 @@ Tensor q_linear_pointwise_binary(
Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();

const int64_t dim = act.dim();
TORCH_CHECK(dim == 2 || dim == 3, "qlinear_pointwise_binary XPU: input dim should be 2 or 3, but got ", dim);
int64_t K = act.size(dim - 1);
int64_t M = act.numel() / K;
// [M, K] x [K, N]
int64_t N = weight.size(1);

Tensor input = dim == 3 ? act.reshape({-1, K}) : act;
std::vector<int64_t> src_dims = {M, K};
std::vector<int64_t> dst_dims = {M, N};
auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));

bool has_accum_postop_sum = (binary_post_op == "sum");
if(dim == 3){
other = other.has_value() ? other.value().reshape({-1, N}) : other;
}
Tensor qout = has_accum_postop_sum ? other.value() : at::empty(dst_dims, act.options().dtype(dst_dtype));
quantized_matmul(
act.contiguous(),
input.contiguous(),
act_scale,
act_zero_point,
weight.contiguous(),
@@ -196,7 +200,7 @@ Tensor q_linear_pointwise_binary(
unary_post_op_algorithm,
/*m2_trans*/ true);

return qout;
return dim == 3 ? qout.reshape({act.size(0), -1, N}) : qout;
}
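
For context, the reshape handling above is equivalent to the following plain-PyTorch sketch. This is illustrative only: the shapes B, S, K, N are hypothetical, and the eager ops stand in for what the oneDNN kernel does with quantized tensors; scales and zero-points are omitted so only the 2D collapse and shape restoration are shown.

```python
import torch

# Hypothetical shapes for illustration only.
B, S, K, N = 2, 3, 10, 10
act = torch.randn(B, S, K)       # 3D activation
weight = torch.randn(K, N)       # [M, K] x [K, N], matching the comment above
other = torch.randn(B, S, N)     # tensor fused in via the "sum" binary post-op

# Collapse the leading dims so the matmul sees a plain 2D problem.
input_2d = act.reshape(-1, K)    # [M, K] with M = B * S
other_2d = other.reshape(-1, N)  # [M, N]; when binary_post_op == "sum" this buffer
                                 # is also reused as the output (qout = other.value())
out_2d = input_2d @ weight + other_2d

# Restore the batch/sequence layout, mirroring the new return statement.
out = out_2d.reshape(act.size(0), -1, N)
assert out.shape == (B, S, N)
```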

Tensor q_linear_pointwise_binary_tensor(
@@ -218,47 +222,25 @@ Tensor q_linear_pointwise_binary_tensor(
c10::string_view unary_post_op,
torch::List<std::optional<at::Scalar>> unary_post_op_args,
c10::string_view unary_post_op_algorithm) {
TORCH_CHECK(
act.device() == weight.device() &&
act.device() == weight_scales.device() &&
act.device() == weight_zero_points.device(),
"qlinear xpu: input tensors(act, weight, weight scale, weight zero-points) should be on the same device");
Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();

const int64_t dim = act.dim();
int64_t K = act.size(dim - 1);
int64_t M = act.numel() / K;
// [M, K] x [K, N]
int64_t N = weight.size(1);

std::vector<int64_t> src_dims = {M, K};
std::vector<int64_t> dst_dims = {M, N};
auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));

quantized_matmul(
act.contiguous(),
act_scale.item().toDouble(),
act_zero_point.item().toLong(),
weight.contiguous(),
weight_scales,
weight_zero_points,
b_raw,
qout,
output_scale,
output_zero_point,
output_dtype,
/*other*/ other,
/*other scale*/ other_scale,
/*other zp*/ other_zero_point,
/*binary post op*/ binary_post_op,
/*binary alpha*/ binary_alpha,
unary_post_op,
unary_post_op_args,
unary_post_op_algorithm,
/*m2_trans*/ true);

return qout;
return q_linear_pointwise_binary(
act,
act_scale.item().toDouble(),
act_zero_point.item().toLong(),
weight,
weight_scales,
weight_zero_points,
other,
bias,
output_scale,
output_zero_point,
output_dtype,
other_scale,
other_zero_point,
binary_post_op,
binary_alpha,
unary_post_op,
unary_post_op_args,
unary_post_op_algorithm);
}
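
The tensor overload now just unwraps the 0-dim act_scale / act_zero_point tensors and delegates to the scalar overload above, rather than duplicating the whole matmul setup. A minimal sketch of that unwrapping in Python terms (tensor values here are placeholders):

```python
import torch

# 0-dim tensors, as the *_tensor entry point receives them.
act_scale = torch.tensor(0.02)
act_zero_point = torch.tensor(128)

# .item() extracts a Python scalar, analogous to
# act_scale.item().toDouble() / act_zero_point.item().toLong() in the C++ code.
scale = act_scale.item()                 # float
zero_point = int(act_zero_point.item())  # int
```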

at::Tensor q_linear_prepack_onednn(
92 changes: 92 additions & 0 deletions test/xpu/test_xpu_inductor_quantizer.py
Collaborator: @ZhiweiYan-96, why do we need to add a dedicated test file? I suppose it should reuse other test files. Right?

Collaborator (Author): Thanks for the suggestion; this file is not necessary. I have removed it and added the 3D cases to test_mkldnn_pattern_matcher.py.

@@ -0,0 +1,92 @@
# Owner(s): ["module: intel"]

# This file serves as supplementary tests for the cases in `test/inductor/test_mkldnn_pattern_matcher`.
# It tests issue cases that show up only in XPU mode.
import contextlib
import torch
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_quantization import _generate_qdq_quantized_model
from torch._inductor import config
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
)

@config.patch({"freezing": True})
@config.patch({"force_disable_caches": True})
class TestXPUInductorQuantizer(TestCase):
def _clone_inputs(self, inputs):
def clone(x):
if not isinstance(x, torch.Tensor):
return x
return x.clone()

return tuple(clone(x) for x in inputs)

def _test_common(
self,
mod,
inputs,
matcher_check_fn,
atol=1e-5,
rtol=1.3e-6,
check_autocast=torch.float32,
is_qat=False,
dtype=None,
is_dynamic=False,
quantizer=None,
compile_options={}, # noqa: B006
):
counters.clear()
torch._dynamo.reset()
device_type = "xpu"
if check_autocast == torch.bfloat16:
maybe_autocast = torch.amp.autocast(
device_type=device_type, dtype=torch.bfloat16
)
atol, rtol = 1e-2, 1e-2
elif check_autocast == torch.float16:
maybe_autocast = torch.amp.autocast(
device_type=device_type, dtype=torch.float16
)
atol, rtol = 1e-2, 1e-2
else:
assert check_autocast == torch.float32
maybe_autocast = contextlib.nullcontext()
convert_model = _generate_qdq_quantized_model(
mod, inputs, is_qat, is_dynamic, quantizer
)
with torch.no_grad(), maybe_autocast:
compiled_model = torch.compile(convert_model)
ref = compiled_model(*self._clone_inputs(inputs))
res = mod(*self._clone_inputs(inputs))
relative_err = torch.mean(torch.abs(res - ref) / ref.abs().clamp(1e-6))
self.assertTrue(relative_err < 0.1)
matcher_check_fn()

def test_qlinear_pointwise_binary_3d(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(10, 10)
self.relu = torch.nn.ReLU()
Reviewer: Suggested change: remove the unused `self.relu = torch.nn.ReLU()` line.
Collaborator (Author): Thanks for the suggestion; since this test file has been removed, this can be resolved.

def forward(self, x):
orig = x
out = self.linear(x)
return out + orig

def matcher_check_fn():
self.assertEqual(counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)

mod = Model().xpu()
inputs = (torch.rand(2, 3, 10, device="xpu"),)
self._test_common(mod, inputs, matcher_check_fn)


instantiate_parametrized_tests(TestXPUInductorQuantizer)

if __name__ == "__main__":
run_tests()