|
21 | 21 | )
|
22 | 22 | from torch.testing._internal.common_utils import (
|
23 | 23 | instantiate_parametrized_tests,
|
| 24 | + IS_FBCODE, |
24 | 25 | IS_LINUX,
|
25 | 26 | parametrize,
|
26 | 27 | skipIfNoXPU,
|
@@ -147,6 +148,7 @@ def _test_common(
|
147 | 148 | dtype=None,
|
148 | 149 | is_dynamic=False,
|
149 | 150 | quantizer=None,
|
| 151 | + compile_options={}, # noqa: B006 |
150 | 152 | ):
|
151 | 153 | counters.clear()
|
152 | 154 | torch._dynamo.reset()
|
@@ -177,7 +179,7 @@ def _test_common(
|
177 | 179 | with torch.no_grad(), maybe_autocast:
|
178 | 180 | clone_inputs = self._clone_inputs(inputs)
|
179 | 181 | expected = mod(*inputs)
|
180 | | - actual = torch.compile(mod)(*clone_inputs) |
| 182 | + actual = torch.compile(mod, **compile_options)(*clone_inputs) |
181 | 183 | torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
|
182 | 184 | matcher_check_fn()
|
183 | 185 |
|
@@ -3286,6 +3288,93 @@ def test_linear_dynamic_fp16(self):
|
3286 | 3288 | def test_linear_relu_dynamic_fp16(self):
|
3287 | 3289 | self._test_linear_dynamic_fp16_helper(use_relu=True)
|
3288 | 3290 |
|
| 3291 | + @skipIfNoDynamoSupport |
| 3292 | + @skipIfNoONEDNN |
| 3293 | + # TODO: investigate options of torch.compile in fbcode |
| 3294 | + @unittest.skipIf(IS_FBCODE, "Failing in fbcode") |
| 3295 | + @parametrize("has_bias", [True, False]) |
| 3296 | + @parametrize("dtype", [torch.float, torch.bfloat16]) |
| 3297 | + @parametrize("per_channel_quant", [True, False]) |
| 3298 | + @parametrize("dynamic", [True, False]) |
| 3299 | + def test_smooth_quant_with_int_mm( |
| 3300 | + self, has_bias, dtype, per_channel_quant, dynamic |
| 3301 | + ): |
| 3302 | + r""" |
| 3303 | + This test case checks whether we can match the SmoothQuant int8 linear pattern from Torchao. |
| 3304 | + The pattern is: |
| 3305 | + (no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape |
| 3306 | + or |
| 3307 | + (with bias) pattern_no_bias -> add -> reshape -> reshape |
| 3308 | + """ |
| 3309 | + if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported(): |
| 3310 | + return |
| 3311 | + M = 16 |
| 3312 | + in_feature = 32 |
| 3313 | + out_feature = 64 |
| 3314 | + q_min, q_max = -32, 31 |
| 3315 | + |
| 3316 | + class Mod(torch.nn.Module): |
| 3317 | + def __init__( |
| 3318 | + self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool |
| 3319 | + ): |
| 3320 | + super().__init__() |
| 3321 | + self.dtype = dtype |
| 3322 | + self.has_bias = has_bias |
| 3323 | + self.b = torch.randint( |
| 3324 | + q_min, q_max, [in_feature, out_feature], dtype=torch.int8 |
| 3325 | + ) |
| 3326 | + self.per_channel_quant = per_channel_quant |
| 3327 | + a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01 |
| 3328 | + a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01 |
| 3329 | + self.a_scale = ( |
| 3330 | + a_scale_per_channel |
| 3331 | + if self.per_channel_quant |
| 3332 | + else a_scale_per_tensor |
| 3333 | + ) |
| 3334 | + self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01 |
| 3335 | + self.b_scale = self.b_scale.to(dtype) |
| 3336 | + self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None |
| 3337 | + |
| 3338 | + def forward(self, a): |
| 3339 | + out_shape = a.shape[:-1] + (self.b.size(-1),) |
| 3340 | + a_reshaped = a.reshape(-1, a.size(-1)) |
| 3341 | + c = torch._int_mm(a_reshaped, self.b) |
| 3342 | + c = c.to(self.dtype) |
| 3343 | + c_shape = c.shape |
| 3344 | + a_scale = self.a_scale.expand(c.shape) |
| 3345 | + c = c * a_scale |
| 3346 | + c = c * self.b_scale |
| 3347 | + if self.has_bias: |
| 3348 | + c = c.reshape([1, *list(c_shape)]) |
| 3349 | + c = c + self.bias |
| 3350 | + c = c.reshape(c_shape) |
| 3351 | + c = c.reshape(out_shape) |
| 3352 | + return c |
| 3353 | + |
| 3354 | + mod = Mod(dtype, has_bias, per_channel_quant).eval() |
| 3355 | + a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8) |
| 3356 | + |
| 3357 | + def matcher_check_fn(): |
| 3358 | + self.assertEqual( |
| 3359 | + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 |
| 3360 | + ) |
| 3361 | + if dynamic: |
| 3362 | + nodes_count = 10 if has_bias else 7 |
| 3363 | + else: |
| 3364 | + nodes_count = 7 if has_bias else 6 |
| 3365 | + self.assertEqual( |
| 3366 | + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], |
| 3367 | + nodes_count, |
| 3368 | + ) |
| 3369 | + |
| 3370 | + self._test_common( |
| 3371 | + mod, |
| 3372 | + (a,), |
| 3373 | + matcher_check_fn=matcher_check_fn, |
| 3374 | + check_autocast=dtype, |
| 3375 | + compile_options={"dynamic": dynamic}, |
| 3376 | + ) |
| 3377 | + |
3289 | 3378 |
|
3290 | 3379 | @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
|
3291 | 3380 | class TestDynamicPatternMatcher(TestPatternMatcherBase):
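
For context on the `compile_options` change above: the new keyword is forwarded verbatim to `torch.compile`, and the new test passes `{"dynamic": dynamic}` so the same pattern check runs both with and without dynamic-shape compilation. A minimal sketch of that plumbing, assuming a hypothetical `run_compiled` helper that is not part of the test file:

    import torch

    def run_compiled(mod, inputs, compile_options=None):
        # Options are forwarded verbatim to torch.compile, e.g. {"dynamic": True}
        # to force dynamic-shape compilation (mirroring the `dynamic` parametrization).
        opts = compile_options or {}
        compiled = torch.compile(mod, **opts)
        with torch.no_grad():
            return compiled(*inputs)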
|