[Inductor-CPU] int8 WoQ concat linear by sanchitintel · Pull Request #153004 · pytorch/pytorch
Closed · wants to merge 19 commits
84 changes: 84 additions & 0 deletions test/inductor/test_cpu_select_algorithm.py
@@ -1434,6 +1434,90 @@ def forward(self, x, scale):
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)

@inductor_config.patch({"freezing": True, "cpp.enable_concat_linear": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize(
"batch_size",
(
1,
32,
),
)
@parametrize(
"mid_dim",
(
1,
8,
),
)
@parametrize("in_features", (128,))
@parametrize("out_features", (64,))
def test_int8_woq_mm_concat(
self, dtype, batch_size, mid_dim, in_features, out_features
):
def _convert_weight_to_int8pack(w):
scale, zp = _calculate_dynamic_per_channel_qparams(
w.to(torch.float), torch.int8
)
scale = torch.from_numpy(scale)
zp = torch.from_numpy(zp)
w_int8 = torch.ao.quantization.fx._decomposed.quantize_per_channel(
input=w,
scales=scale,
zero_points=zp,
axis=0,
quant_min=-128,
quant_max=127,
dtype=torch.int8,
)
return w_int8, scale.to(torch.bfloat16)

class M(torch.nn.Module):
def __init__(self, w1, w2, w3):
super().__init__()
self.w1 = torch.nn.Parameter(w1, requires_grad=False)
self.w2 = torch.nn.Parameter(w2, requires_grad=False)
self.w3 = torch.nn.Parameter(w3, requires_grad=False)

def forward(self, x, scale1, scale2, scale3):
# Ref: _linear_fp_act_int8_weight_impl in torchao/dtypes/uintx/plain_layout.py
y1 = (
torch.mm(x.reshape(-1, x.shape[-1]), self.w1.t().to(x.dtype))
* scale1
)
y2 = (
torch.mm(x.reshape(-1, x.shape[-1]), self.w2.t().to(x.dtype))
* scale2
)
y3 = (
torch.mm(x.reshape(-1, x.shape[-1]), self.w3.t().to(x.dtype))
* scale3
)
return (
y1.reshape(*x.shape[:-1], y1.shape[-1]),
y2.reshape(*x.shape[:-1], y2.shape[-1]),
y3.reshape(*x.shape[:-1], y3.shape[-1]),
)

counters.clear()
# Currently, the corresponding torch.fx pattern only supports 3D x.
# TODO: add a 2D x case once the corresponding pattern-matcher pattern is registered.
x = torch.rand((batch_size, mid_dim, in_features), dtype=dtype)
w1 = torch.rand((out_features, in_features), dtype=dtype)
w2 = torch.rand((out_features, in_features), dtype=dtype)
w3 = torch.rand((out_features, in_features), dtype=dtype)
w1_int8pack, w1_scales = _convert_weight_to_int8pack(w1)
w2_int8pack, w2_scales = _convert_weight_to_int8pack(w2)
w3_int8pack, w3_scales = _convert_weight_to_int8pack(w3)
mod = M(w1_int8pack, w2_int8pack, w3_int8pack).eval()
self.common(mod, (x, w1_scales, w2_scales, w3_scales))
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
if batch_size * mid_dim >= 16:
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)

@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
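For intuition, here is a minimal standalone eager-mode sketch (not part of the diff) of the equivalence the new test relies on: three per-weight int8-WoQ matmuls produce the same values as a single matmul against the column-concatenated weight, chunked afterwards. Shapes and tolerances below are illustrative assumptions.

import torch

# Illustrative sketch only: three separate WoQ-style matmuls vs. one matmul on
# the column-concatenated weight, followed by chunking the result.
x = torch.rand(32, 8, 128, dtype=torch.bfloat16)
ws = [torch.randint(-128, 128, (64, 128), dtype=torch.int8) for _ in range(3)]
scales = [torch.rand(64, dtype=torch.bfloat16) for _ in range(3)]

x2d = x.reshape(-1, x.shape[-1])
# Reference: one matmul per weight, each scaled by its per-channel scales
ys = [torch.mm(x2d, w.t().to(x.dtype)) * s for w, s in zip(ws, scales)]
# Fused form: concatenate weights along the output dim and scales along dim 0
cat_w = torch.cat([w.t().to(x.dtype) for w in ws], dim=1)
cat_s = torch.cat(scales, dim=0)
y_cat = torch.mm(x2d, cat_w) * cat_s
for y_ref, y_chunk in zip(ys, y_cat.chunk(3, dim=1)):
    # bfloat16 rounding may require slightly looser tolerances
    torch.testing.assert_close(y_ref, y_chunk, rtol=2e-2, atol=1e-2)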
63 changes: 62 additions & 1 deletion torch/_inductor/fx_passes/freezing_patterns.py
@@ -121,10 +121,53 @@ def register_binary_folding_pattern(pattern, extra_check=_return_true):

@functools.cache
def addmm_patterns_init():
"""
addmm-related patterns.
To avoid duplication, the int8 WoQ GEMM pattern without bias is also registered here.
"""
device = next(
(gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu"
)
val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)
scale = functools.partial(torch.empty, (10,), device=device, requires_grad=False)

def check_int8_woq_concat_linear_weights(match):
is_cpu = match.kwargs["inp"].meta["val"].is_cpu
if not is_cpu or not config.cpp.enable_concat_linear:
# Currently, this pattern is only supported on CPU
return False

weight_inputs = ["w1", "w2"]
if "w3" in match.kwargs:
weight_inputs.append("w3")

if not all(
match.kwargs[wgt].target == torch.ops.prims.convert_element_type.default
for wgt in weight_inputs
):
return False

if not all(
next(iter(match.kwargs[wgt]._input_nodes.keys())).meta["val"].dtype
is torch.int8
for wgt in weight_inputs
):
return False

if not all(
match.kwargs[wgt].meta["val"].dtype is torch.bfloat16
for wgt in weight_inputs
):
return False

equal_shape_inputs = [weight_inputs]
for equal_shape_group in equal_shape_inputs:
inps = [match.kwargs[name] for name in equal_shape_group]
if not all(
inp.meta["val"].shape == inps[0].meta["val"].shape for inp in inps
):
return False
return True

def check_concat_weights(match):
is_cpu = match.kwargs["inp"].meta["val"].is_cpu
@@ -153,9 +196,27 @@ def check_concat_weights(match):
for inp in inps
):
return False

return True

def int8_woq_fusion_pattern(inp, w1, w2, w3, s1, s2, s3):
return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3)

def int8_woq_fusion_replacement(inp, w1, w2, w3, s1, s2, s3):
cat_w = torch.cat((w1, w2, w3), dim=1)
cat_s = torch.cat((s1, s2, s3), dim=0)
mm = (inp @ cat_w).mul(cat_s)
return mm.chunk(3, dim=1)

register_replacement(
int8_woq_fusion_pattern,
int8_woq_fusion_replacement,
[val(), val(), val(), val(), scale(), scale(), scale()],
fwd_only,
pass_patterns[0],
extra_check=check_int8_woq_concat_linear_weights,
exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"),
)

def matmul_fuse_pattern(inp, w1, w2, w3):
return (inp @ w1, inp @ w2, inp @ w3)

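As a rough illustration of what register_replacement works with here (an assumption about the tracing internals, not code from this PR), the pattern and replacement callables are traced into small aten graphs that the matcher then compares against the FX graph. A standalone sketch using make_fx:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Sketch only: approximate the graphs that fwd_only-style tracing would derive
# from int8_woq_fusion_pattern / int8_woq_fusion_replacement above.
def pattern(inp, w1, w2, w3, s1, s2, s3):
    return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3)

def replacement(inp, w1, w2, w3, s1, s2, s3):
    cat_w = torch.cat((w1, w2, w3), dim=1)
    cat_s = torch.cat((s1, s2, s3), dim=0)
    return ((inp @ cat_w).mul(cat_s)).chunk(3, dim=1)

example_inputs = [torch.rand(10, 10) for _ in range(4)] + [torch.rand(10) for _ in range(3)]
print(make_fx(pattern)(*example_inputs).graph)      # subgraph the matcher searches for
print(make_fx(replacement)(*example_inputs).graph)  # subgraph it swaps in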
2 changes: 2 additions & 0 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -23,6 +23,7 @@
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
from .quantization import (
_register_int8_woq_concat_linear_pattern,
_register_quantization_lowerings,
_register_quantization_weight_pack_pass,
_register_woq_lowerings,
@@ -1420,3 +1421,4 @@ def _mkldnn_weight_pack_init():
_register_weight_pack_pass()
_recover_linear()
_register_quantization_weight_pack_pass()
_register_int8_woq_concat_linear_pattern()
151 changes: 150 additions & 1 deletion torch/_inductor/fx_passes/quantization.py
@@ -12,8 +12,17 @@
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.fx.node import map_arg

from .. import config
from ..lowering import lowerings as L, require_channels_last
from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
from ..pattern_matcher import (
Arg,
CallFunction,
filter_nodes,
KeywordArg,
ListOf,
Match,
stable_topological_sort,
)
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
@@ -1068,6 +1077,53 @@ def _register_quantization_reshape():
)


def _is_valid_concat_linear_int8_woq_optimization_pattern():
def fn(match):
if not config.cpp.enable_concat_linear:
return False
assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales"))
if not all(
hasattr(match.kwargs[key], "meta")
for key in ["x", "w1", "w2", "w3", "scales"]
):
return False
x = match.kwargs["x"].meta["val"]
w1 = match.kwargs["w1"].meta["val"]
w2 = match.kwargs["w2"].meta["val"]
w3 = match.kwargs["w3"].meta["val"]
scales = match.kwargs["scales"].meta["val"]
if len(match.kwargs["scales"].meta["val"].size()) > 1:
return False
num_scales = match.kwargs["scales"].meta["val"].numel()
w1_cols = match.kwargs["w1"].meta["val"].size()[0]
w2_cols = match.kwargs["w2"].meta["val"].size()[0]
w3_cols = match.kwargs["w3"].meta["val"].size()[0]
# Technically, the shapes of the three weights need not be equal,
# but currently, replacement is only enabled when they match.
if w1_cols != w2_cols or w2_cols != w3_cols:
return False
if 3 * w1_cols != num_scales:
return False
return (
# For now, we only support WoQ mm kernels
# with x.dtype=bfloat16 and w.dtype=int8
x.dtype == torch.bfloat16
and w1.dtype == torch.int8
and w2.dtype == torch.int8
and w3.dtype == torch.int8
and scales.dtype == torch.bfloat16
# The _weight_int8pack_mm kernel only supports CPU for now.
# TODO: add CUDA kernel support instead of falling back to mul + sum
and x.device.type == "cpu"
and x.device == w1.device
and w1.device == w2.device
and w2.device == w3.device
and x.device == scales.device
)

return fn


def _is_valid_woq_optimization_pattern():
def fn(match):
assert all(k in match.kwargs for k in ("x", "weight", "scales"))
@@ -1094,6 +1150,73 @@ def fn(match):
return fn


def _register_concat_linear_int8_woq_lowering(
pattern, computation_woq, computation_reshape
):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(),
pass_number=4,
)
def woq(match: Match, *args, **kwargs):
x = kwargs["x"]
w1 = kwargs["w1"]
w2 = kwargs["w2"]
w3 = kwargs["w3"]
scales = kwargs["scales"]
counters["inductor"]["woq_matcher_count"] += 1
counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
out_features = (
w1.meta["val"].size()[0]
+ w2.meta["val"].size()[0]
+ w3.meta["val"].size()[0]
)
origin_x_size = tuple(x.meta["val"].size())
x_shape = [-1, origin_x_size[-1]]
out_shape = list(origin_x_size[:-1] + (out_features,))
mm_node_of_x = None
for candidate in iter(x.users.keys()):
if (
candidate.target == aten.mm.default
and list(candidate._input_nodes)[1].target == aten.cat.default
):
mm_node_of_x = candidate
break
assert mm_node_of_x is not None, "unable to find mm node"
_, cat_wgt_node = mm_node_of_x._input_nodes
scaling_node = next(iter(mm_node_of_x.users.keys()))
user_of_scaling_node = next(iter(scaling_node.users.keys()))
# Some other pass makes changes that leave a node placed after one of its
# uses; that ordering issue would only surface when lint is run. Because
# stable_topological_sort() is run before lint, the error was not being
# discovered. We call stable_topological_sort() here as a workaround.
stable_topological_sort(match.graph)
with match.graph.inserting_before(user_of_scaling_node):
new_cat_node = match.graph.call_function(
aten.cat.default,
args=([w1, w2, w3], 0),
)
x_reshape_node = match.graph.call_function(
computation_reshape, args=(x, x_shape)
)
new_woq_node = match.graph.call_function(
computation_woq,
args=(x_reshape_node, new_cat_node, scales),
)
new_woq_node.meta = copy.copy(x.meta)
output_reshape_node = match.graph.call_function(
computation_reshape, args=(new_woq_node, out_shape)
)
scaling_node.replace_all_uses_with(output_reshape_node)
match.graph.erase_node(scaling_node)
match.graph.erase_node(mm_node_of_x)
match.graph.erase_node(cat_wgt_node)
match.graph.lint()

return woq


def _register_woq_lowering(pattern, computation_woq, computation_reshape):
@register_lowering_pattern(
pattern,
@@ -1214,6 +1337,32 @@ def _register_woq_mm_int8_pattern4():
_register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)


def _register_int8_woq_concat_linear_pattern():
def _create_wgt_node(wgt_node_name: str):
return CallFunction(
prims.convert_element_type.default,
CallFunction(
aten.permute.default,
KeywordArg(wgt_node_name),
Arg(),
),
Arg(),
)

cat_wgt = CallFunction(
aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1
)

_woq_pattern = CallFunction(
aten.mul.Tensor,
CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt),
KeywordArg("scales"),
)
_register_concat_linear_int8_woq_lowering(
_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape
)


def _register_quantization_lowerings():
_register_quantization_unary_lowering()
_register_quantization_binary_lowering()
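For reference (not part of the diff), a small standalone sketch of the computation the rewritten graph performs once this lowering fires, assuming the CPU aten._weight_int8pack_mm kernel referenced above is available in the build; shapes are illustrative assumptions.

import torch

aten = torch.ops.aten

# Sketch only: mirrors the nodes inserted by the lowering above
# (new_cat_node, x_reshape_node, new_woq_node, output_reshape_node).
x = torch.rand(2, 8, 128, dtype=torch.bfloat16)                      # 3D bf16 activation
w1, w2, w3 = (torch.randint(-128, 128, (64, 128), dtype=torch.int8) for _ in range(3))
scales = torch.rand(3 * 64, dtype=torch.bfloat16)                    # concatenated per-channel scales

cat_w = torch.cat((w1, w2, w3), dim=0)                               # [3*N, K] int8 weight
x2d = x.reshape(-1, x.shape[-1])                                     # flatten leading dims
y = aten._weight_int8pack_mm(x2d, cat_w, scales)                     # fused int8-WoQ GEMM
out = y.reshape(*x.shape[:-1], cat_w.shape[0])                       # restore leading dims
print(out.shape)                                                     # torch.Size([2, 8, 192])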