[Quant][Inductor] Bug fix: mutation nodes not handled correctly for QLinearPointwiseBinaryPT2E #127592

Closed · wants to merge 12 commits
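Summary: with a binary post-op of "sum", QLinearPointwiseBinaryPT2E accumulates into its `other` input in place and returns that tensor, but the Inductor IR node did not declare the mutation, so the mutated buffer was not handled correctly downstream. The diff below threads a has_mutation flag through the node, implements get_mutation_names() so the scheduler sees the in-place-sum input, and extends the pattern-matcher test to count the expected qlinear_pointwise kernels in the generated code.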
13 changes: 13 additions & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -229,6 +229,7 @@ def _test_code_common(
        inputs,
        include_ops,
        exclude_ops,
        num_include_ops=None,
        atol=1e-5,
        rtol=1.3e-6,
        check_quantization=False,
@@ -245,6 +246,10 @@
            )
            for op in include_ops:
                self.assertIn(op, source_code)
            if num_include_ops is not None:
                assert len(include_ops) == len(num_include_ops)
                for i in range(len(include_ops)):
                    self.assertEqual(source_code.count(include_ops[i]), num_include_ops[i])
            for op in exclude_ops:
                self.assertNotIn(op, source_code)
            if check_dynamic is not None:
@@ -1775,6 +1780,14 @@ def matcher_check_fn():
            matcher_check_fn=matcher_check_fn,
            is_qat=is_qat,
        )
        self._test_code_common(
            mod,
            (v,),
            ["torch.ops.onednn.qlinear_pointwise.default", "torch.ops.onednn.qlinear_pointwise.binary"],
            [],
            num_include_ops=[2, 2],
            check_quantization=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
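The num_include_ops=[2, 2] assertion above implies a model that lowers to two plain qlinear_pointwise.default kernels plus two qlinear_pointwise.binary (sum) kernels. A minimal sketch of a module with that shape, assuming hypothetical names (ResidualMLP, d); the actual mod and v are defined earlier in the test:

import torch

class ResidualMLP(torch.nn.Module):  # hypothetical, not the test's `mod`
    def __init__(self, d=16):
        super().__init__()
        self.fc1 = torch.nn.Linear(d, d)
        self.fc2 = torch.nn.Linear(d, d)
        self.fc3 = torch.nn.Linear(d, d)
        self.fc4 = torch.nn.Linear(d, d)

    def forward(self, x):
        # fc1/fc3 lower to qlinear_pointwise.default; each trailing add
        # fuses into fc2/fc4 as a binary "sum" post-op, i.e.
        # qlinear_pointwise.binary, giving counts of [2, 2].
        y = self.fc2(self.fc1(x)) + x
        return self.fc4(self.fc3(y)) + y

v = torch.randn(4, 16)  # example input matching d=16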
10 changes: 10 additions & 0 deletions torch/_inductor/ir.py
@@ -7196,6 +7196,7 @@ def __init__(
        constant_args=(),
        has_bias=True,
        x_scale_zp_are_tensors=False,
        has_mutation=False,
    ):
        """
        if bias is not None
@@ -7208,7 +7209,9 @@
            fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
        """
        self.has_bias = has_bias
        self.idx_for_inplace_sum = -1
        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
        self.has_mutation = has_mutation
        super().__init__(
            layout,
            inputs,
@@ -7325,6 +7328,12 @@ def codegen(self, wrapper):
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def get_mutation_names(self):
        return [self.inputs[self.idx_for_inplace_sum].get_name()] if self.has_mutation else []

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls,
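get_mutation_names() is the heart of the fix: it reports which input buffer the kernel rewrites in place (the in-place-sum operand, self.inputs[self.idx_for_inplace_sum]). A toy sketch of why that report matters, using made-up Node/schedule helpers rather than Inductor's real scheduler code: without a recorded writer, a later reader of `other` gets no dependency edge on this kernel and could be reordered ahead of the mutation.

class Node:  # illustrative stand-in, not an Inductor class
    def __init__(self, name, reads=(), mutates=()):
        self.name, self.reads, self.mutates = name, reads, mutates

    def get_mutation_names(self):
        return list(self.mutates)

def schedule(nodes):  # hypothetical dependency builder
    deps = {n.name: set() for n in nodes}
    last_writer = {}  # buffer name -> last node that wrote it
    for n in nodes:  # nodes arrive in program order
        for buf in n.reads:
            if buf in last_writer:  # read-after-write edge
                deps[n.name].add(last_writer[buf])
        for buf in n.get_mutation_names():
            last_writer[buf] = n.name  # the mutation rewrites the buffer
    return deps

deps = schedule([
    Node("qlinear_binary", reads=("x", "other"), mutates=("other",)),
    Node("consumer", reads=("other",)),
])
# deps["consumer"] == {"qlinear_binary"}: the consumer must wait for the
# in-place sum; an empty get_mutation_names() would lose this edge.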
@@ -7394,6 +7403,7 @@ def create(
            constant_args=constant_args,
            has_bias=(bias is not None),
            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
            has_mutation=True,
        )
        mark_node_as_mutating(packed, other)
        # Return other since it has been inplace changed.
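For context on the trailing comment: with the binary "sum" post-op, the kernel accumulates the linear output into `other` in place, so the op's result is the mutated input itself, and create hands back `other` rather than a fresh buffer. A rough eager-mode analogue, as a sketch only (the name qlinear_binary_sum_reference and the float signature are assumptions; the real kernel operates on quantized tensors with scales and zero points):

import torch

def qlinear_binary_sum_reference(x, weight, bias, other):
    # Accumulate in place instead of allocating a new output buffer.
    other += torch.nn.functional.linear(x, weight, bias)
    return other  # aliases the mutated input, as the comment above notes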