[Inductor][CPU] Fuse SmoothQuant int8 linear pattern · pytorch/pytorch@82808a6 · GitHub

Commit 82808a6

[Inductor][CPU] Fuse SmoothQuant int8 linear pattern
ghstack-source-id: 7b2a046
Pull Request resolved: #142036
1 parent b576a8c commit 82808a6
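
The pattern targeted by this fusion is the int8 GEMM sequence produced by Torchao's SmoothQuant flow: reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape, plus add/reshape when a bias is present, as described in the new test below. The following minimal eager-mode sketch is for illustration only; it is not code from this commit, and the helper name smooth_quant_int8_linear is invented:

import torch

def smooth_quant_int8_linear(a_int8, b_int8, a_scale, b_scale, bias=None):
    # a_int8: [..., K] int8 activation, b_int8: [K, N] int8 weight
    # a_scale: activation scale, broadcastable to the GEMM output shape
    out_shape = a_int8.shape[:-1] + (b_int8.size(-1),)
    a_2d = a_int8.reshape(-1, a_int8.size(-1))   # reshape
    c = torch._int_mm(a_2d, b_int8)              # int8 x int8 -> int32 GEMM
    c = c.to(torch.float32)                      # convert_element_type
    c = c * a_scale.expand(c.shape)              # dequantize by activation scale
    c = c * b_scale                              # dequantize by per-channel weight scale
    if bias is not None:
        c = c + bias                             # add (with-bias variant)
    return c.reshape(out_shape)                  # reshape back to the input batch shape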

File tree

3 files changed (+335, -10 lines)

aten/src/ATen/native/quantized/cpu/qlinear.cpp

Lines changed: 14 additions & 9 deletions
@@ -931,8 +931,8 @@ static at::Tensor linear_int8_with_onednn_weight(
     std::string_view& unary_post_op_algorithm) {
   using ideep::tensor;
   const int64_t dim = input.dim();
-  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
-      "qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
+      "qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
   TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
       "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
   TORCH_CHECK(
@@ -1021,7 +1021,8 @@ static at::Tensor linear_int8_with_onednn_weight(
       empty_tensor;
 
   // Create onednn primitive
-  auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
+  auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
+  auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
   auto weights_desc = packed_weight.get_desc();
   auto dst_dtype = dst.get_data_type();
   auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
@@ -1118,12 +1119,14 @@ namespace at::native {
       torch::List<std::optional<at::Scalar>> post_op_args,
       std::string_view post_op_algorithm) {
 #if AT_MKLDNN_ENABLED()
-    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
-        "onednn int8 linear: act scale/zp size should be 1");
+    // act_zero_point.numel() == 0 for symmetric quantization
+    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
+        "onednn int8 linear: act scale/zp size should be 1/<=1");
     static std::optional<at::Tensor> other = std::nullopt;
     static const std::string_view binary_post_op = "none";
+    int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
     return linear_int8_with_onednn_weight(
-        act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
+        act, act_scale.item().toDouble(), act_zp,
         onednn_weight, weight_scales, weight_zero_points,
         bias, output_scale, output_zero_point, output_dtype,
         other, /*other scale*/1.0, /*other zp*/0,
@@ -1154,10 +1157,12 @@ namespace at::native {
       torch::List<std::optional<at::Scalar>> unary_post_op_args,
       std::string_view unary_post_op_algorithm) {
 #if AT_MKLDNN_ENABLED()
-    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
-        "onednn int8 linear: act scale/zp size should be 1");
+    // act_zero_point.numel() == 0 for symmetric quantization
+    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
+        "onednn int8 linear: act scale/zp size should be 1/<=1");
+    int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
     return linear_int8_with_onednn_weight(
-        act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
+        act, act_scale.item().toDouble(), act_zp,
         onednn_weight, weight_scales, weight_zero_points,
         bias, output_scale, output_zero_point, output_dtype,
         other, other_scale, other_zero_point,
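
The changes above let the onednn qlinear path accept int8 (s8) activations in addition to uint8, and an empty activation zero-point tensor, which is what symmetric quantization produces. As a minimal, illustrative sketch (not code from this commit, using a max-abs per-tensor calibration as an assumption) of how such an activation and empty zero point can arise with plain torch ops:

import torch

def quantize_symmetric_int8_per_tensor(x: torch.Tensor):
    # Symmetric quantization centers the range at zero, so no zero point is needed:
    # only a scale is produced, and the zero-point tensor can be empty (numel() == 0).
    scale = x.abs().max() / 127.0
    x_int8 = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    zero_point = torch.empty(0, dtype=torch.int64)  # empty: treated as zp = 0
    return x_int8, scale, zero_point

x = torch.randn(16, 32)
x_int8, scale, zp = quantize_symmetric_int8_per_tensor(x)
assert x_int8.dtype == torch.int8 and zp.numel() == 0  # the new check accepts numel() <= 1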

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 90 additions & 1 deletion
@@ -21,6 +21,7 @@
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    IS_FBCODE,
     IS_LINUX,
     parametrize,
     skipIfNoXPU,
@@ -147,6 +148,7 @@ def _test_common(
         dtype=None,
         is_dynamic=False,
         quantizer=None,
+        compile_options={},  # noqa: B006
     ):
         counters.clear()
         torch._dynamo.reset()
@@ -177,7 +179,7 @@ def _test_common(
         with torch.no_grad(), maybe_autocast:
             clone_inputs = self._clone_inputs(inputs)
             expected = mod(*inputs)
-            actual = torch.compile(mod)(*clone_inputs)
+            actual = torch.compile(mod, **compile_options)(*clone_inputs)
             torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
             matcher_check_fn()
 
@@ -3286,6 +3288,93 @@ def test_linear_dynamic_fp16(self):
     def test_linear_relu_dynamic_fp16(self):
         self._test_linear_dynamic_fp16_helper(use_relu=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    # TODO: investigate options of torch.compile in fbcode
+    @unittest.skipIf(IS_FBCODE, "Failing in fbcode")
+    @parametrize("has_bias", [True, False])
+    @parametrize("dtype", [torch.float, torch.bfloat16])
+    @parametrize("per_channel_quant", [True, False])
+    @parametrize("dynamic", [True, False])
+    def test_smooth_quant_with_int_mm(
+        self, has_bias, dtype, per_channel_quant, dynamic
+    ):
+        r"""
+        This testcase check if we can match the SmoothQuant int8 linear pattern from Torchao.
+        The pattern is:
+        (no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
+        or
+        (with bias) pattern_no_bias -> add -> reshape -> reshape
+        """
+        if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            return
+        M = 16
+        in_feature = 32
+        out_feature = 64
+        q_min, q_max = -32, 31
+
+        class Mod(torch.nn.Module):
+            def __init__(
+                self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
+            ):
+                super().__init__()
+                self.dtype = dtype
+                self.has_bias = has_bias
+                self.b = torch.randint(
+                    q_min, q_max, [in_feature, out_feature], dtype=torch.int8
+                )
+                self.per_channel_quant = per_channel_quant
+                a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
+                a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
+                self.a_scale = (
+                    a_scale_per_channel
+                    if self.per_channel_quant
+                    else a_scale_per_tensor
+                )
+                self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
+                self.b_scale = self.b_scale.to(dtype)
+                self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
+
+            def forward(self, a):
+                out_shape = a.shape[:-1] + (self.b.size(-1),)
+                a_reshaped = a.reshape(-1, a.size(-1))
+                c = torch._int_mm(a_reshaped, self.b)
+                c = c.to(self.dtype)
+                c_shape = c.shape
+                a_scale = self.a_scale.expand(c.shape)
+                c = c * a_scale
+                c = c * self.b_scale
+                if self.has_bias:
+                    c = c.reshape([1, *list(c_shape)])
+                    c = c + self.bias
+                    c = c.reshape(c_shape)
+                c = c.reshape(out_shape)
+                return c
+
+        mod = Mod(dtype, has_bias, per_channel_quant).eval()
+        a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
+
+        def matcher_check_fn():
+            self.assertEqual(
+                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
+            )
+            if dynamic:
+                nodes_count = 10 if has_bias else 7
+            else:
+                nodes_count = 7 if has_bias else 6
+            self.assertEqual(
+                counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
+                nodes_count,
+            )
+
+        self._test_common(
+            mod,
+            (a,),
+            matcher_check_fn=matcher_check_fn,
+            check_autocast=dtype,
+            compile_options={"dynamic": dynamic},
+        )
+
 
 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class TestDynamicPatternMatcher(TestPatternMatcherBase):
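
For context on the core op the matched pattern is built around: torch._int_mm multiplies two int8 matrices and returns an int32 result, which the surrounding pattern then dequantizes with the activation and weight scales. A tiny self-contained example using the same shapes as the test above (illustrative only, not part of the commit):

import torch

# int8 x int8 matrix multiply with int32 accumulation
a = torch.randint(-32, 31, (16, 32), dtype=torch.int8)
b = torch.randint(-32, 31, (32, 64), dtype=torch.int8)
c = torch._int_mm(a, b)
assert c.dtype == torch.int32 and c.shape == (16, 64)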
