[inductor][cpp][gemm] improve large bs perf with better cache blocking by jgong5 · Pull Request #132729 · pytorch/pytorch · GitHub
[inductor][cpp][gemm] improve large bs perf with better cache blocking #132729


Closed
wants to merge 5 commits into from
33 changes: 32 additions & 1 deletion test/inductor/test_cpu_select_algorithm.py
@@ -27,15 +27,17 @@

try:
try:
from . import test_torchinductor
from . import test_cpu_repro, test_torchinductor
except ImportError:
import test_cpu_repro
import test_torchinductor
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise

check_model = test_torchinductor.check_model
set_num_threads = test_cpu_repro.set_num_threads

aten = torch.ops.aten

@@ -744,6 +746,35 @@ def forward(self, x):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@set_num_threads(1)
@parametrize("batch_size", (1024,))
@parametrize("in_features", (1024,))
@parametrize("out_features", (1024,))
@parametrize("bias", (True, False))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_cache_blocking(
self, batch_size, in_features, out_features, bias, dtype
):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)

def forward(self, x):
return self.linear(x)

counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
114 changes: 78 additions & 36 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -65,6 +65,7 @@
const auto Kt_blocks = K0_blocks;
{%- endif %}
const int64_t Mc_blocks = Mt_blocks;
const int64_t Nc_blocks = 1;
const int64_t Kc_blocks = Kt_blocks;
const int64_t num_Mc_blocks = (M0_blocks + Mc_blocks - 1) / Mc_blocks;
const int64_t num_Nc_blocks = N0_blocks;
@@ -76,9 +77,10 @@
constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
constexpr int64_t num_Mc_blocks = (M0_blocks + Mc_blocks - 1) / Mc_blocks;
constexpr int64_t num_Nc_blocks = N0_blocks;
constexpr int64_t num_Nc_blocks = (N0_blocks + Nc_blocks - 1) / Nc_blocks;
constexpr int64_t num_k_slices = (K0_blocks + Kt_blocks - 1) / Kt_blocks;
{%- endif %}

@@ -124,28 +126,33 @@
const int64_t m_size = m_end - m_start;
{%- if use_local_acc %}
{%- set acc_buf_name = "local_acc_buf" %}
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"], acc_buf_dtype) }}
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "Nc_blocks*N0"], acc_buf_dtype) }}
{%- endif %}
for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * N0;
const int64_t n_end = std::min((nc + 1) * N0, N);
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * N0, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
{%- if use_local_acc %}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
{{ kernel.reinit_buffer_if_null(acc_buf_name) }}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- endif %}
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * K0;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * K0, K);
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
if (kc == k_block_start) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False)|indent(24, false) }}
} else {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }}
for (int64_t nci = nc; nci < nc_block_end; nci++) {
{%- set acc_slice = kernel.slice_nd(acc, [(), ("(nci - nc)*N0", "(nci - nc + 1)*N0")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
if (kc == k_block_start) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
} else {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
}
}
}
{%- if maybe_k_slicing %}
@@ -155,13 +162,8 @@
} else
{%- endif %}
{
{%- if N == PADDED_N %}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- set tile_acc = acc %}
{%- else %}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [(), ("0", "n_end - n_start")]) %}
{%- endif %}
{{ kernel.store_output(
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
)|indent(20, false)
@@ -182,9 +184,9 @@
const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
const int64_t m_size = m_end - m_start;
const int64_t m_offset = m_start - m_start_unsliced;
for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * N0;
const int64_t n_end = std::min((nc + 1) * N0, N);
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * N0, N);
const int64_t n_size = n_end - n_start;
const int64_t mxn_cache_block_id = mc * num_Nc_blocks + nc;
auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
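
As an aside on the template change above: the inner `nci` loop now sweeps `Nc_blocks` register blocks of B per cache block, and the local accumulator covers `Mc x (Nc_blocks * N0)` so it stays resident across the whole K sweep before `store_output`. Below is a standalone Python sketch of that loop order (a simplification, not the generated code: threading, K-slicing, padding handling and the epilogue are omitted, and `blocked_gemm` plus its block sizes are illustrative only):

```python
import torch

def blocked_gemm(A, B, M0=4, N0=16, K0=16, Mc_blocks=2, Nc_blocks=2, Kc_blocks=2):
    M, K = A.shape
    _, N = B.shape
    C = torch.zeros(M, N, dtype=A.dtype)
    num_M0 = (M + M0 - 1) // M0  # number of register blocks along M
    num_N0 = (N + N0 - 1) // N0
    num_K0 = (K + K0 - 1) // K0
    for mc in range(0, num_M0, Mc_blocks):
        m0, m1 = mc * M0, min((mc + Mc_blocks) * M0, M)
        for nc in range(0, num_N0, Nc_blocks):
            n0, n1 = nc * N0, min((nc + Nc_blocks) * N0, N)
            # local accumulator for an Mc x (Nc_blocks * N0) tile ("local_acc_buf")
            acc = torch.zeros(m1 - m0, n1 - n0, dtype=A.dtype)
            for kc in range(0, num_K0, Kc_blocks):
                k0, k1 = kc * K0, min((kc + Kc_blocks) * K0, K)
                # one register block of B columns at a time (the template's nci loop),
                # standing in for micro_gemm.codegen_call on a slice of acc
                for nci in range(nc, min(nc + Nc_blocks, num_N0)):
                    c0, c1 = nci * N0, min((nci + 1) * N0, n1)
                    acc[:, c0 - n0 : c1 - n0] += A[m0:m1, k0:k1] @ B[k0:k1, c0:c1]
            # store_output: write the accumulated tile back (plus epilogue in the template)
            C[m0:m1, n0:n1] = acc
    return C

A, B = torch.randn(128, 80), torch.randn(80, 96)
assert torch.allclose(blocked_gemm(A, B), A @ B, atol=1e-4)
```

In terms of the cache-blocking code below, Strategy 2 trades a larger `Nc_blocks` (and hence a larger accumulator) for keeping the A block reusable in L2 when Kt is too big to hold A (Mc x Kt) there.
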
@@ -350,11 +352,22 @@ def get_cache_blocking(register_blocking, thread_blocking):
N0 = register_blocking.block_n
K0 = register_blocking.block_k

Mc_blocks = thread_blocking.block_m
# Nc_blocks is always 1
Nc_blocks = 1
Kc_blocks = thread_blocking.block_k
Mt_blocks = thread_blocking.block_m
Nt_blocks = thread_blocking.block_n
Kt_blocks = thread_blocking.block_k

if config.cpp.gemm_cache_blocking is not None:
blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
assert len(blockings) == 3
Mc_blocks, Nc_blocks, Kc_blocks = blockings
return (
min(Mc_blocks, Mt_blocks),
min(Nc_blocks, Nt_blocks),
min(Kc_blocks, Kt_blocks),
)

# The ratios below are empirically determined to decide
# the effective sizes of L1 and L2.
# TODO: tune the factor here
L1_limit_factor = 1
L2_limit_factor = 0.5
@@ -365,33 +378,62 @@ def get_cache_blocking(register_blocking, thread_blocking):
assert (
L1_cache_size > 0
), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
L1 = L1_cache_size * L1_limit_factor

L2_cache_size = (
torch._C._cpu._L2_cache_size()
) # per core cache size in Bytes
assert (
L2_cache_size > 0
), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
B_size_limit = L1_cache_size * L1_limit_factor
A_size_limit = L2_cache_size * L2_limit_factor
L2 = L2_cache_size * L2_limit_factor

def get_num_byte(dtype):
return torch.tensor([], dtype=dtype).element_size()

num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())

size_cache_B = K0 * Kc_blocks * N0 * Nc_blocks * num_byte_B

if size_cache_B > B_size_limit:
Kc_blocks = math.floor(
B_size_limit / (K0 * N0 * Nc_blocks * num_byte_B)
)

size_cache_A = M0 * Mc_blocks * K0 * Kc_blocks * num_byte_A
if size_cache_A > A_size_limit:
Mc_blocks = math.floor(
A_size_limit / (M0 * Kc_blocks * K0 * num_byte_A)
)
# NOTE [CPP GEMM Cache Blocking Algorithm]
# Our overall strategy is to:
# 1) Make cache blocks of B reside in L1 and be reused by multiple rows of A, i.e. Mc.
# Here, B is Kc x Nr, where Nr is a single register block. We use the L1 size to
# decide Kc. We want Mc to be large enough to reuse B well.
# 2) Make cache blocks of A reside in L2, which limits Mc. We want to reuse A
# along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.

# Step 1: Decide Kc assuming B block is L1-reside.
size_cache_B = K0 * Kt_blocks * N0 * num_byte_B
Kc_blocks = Kt_blocks
if size_cache_B > L1:
Kc_blocks = math.floor(L1 / (K0 * N0 * num_byte_B))

# Step 2: Decide Mc assuming A block is L2-reside.
min_Mc_ratio = 2 # TODO(jgong5): something to tune?
min_Mc_blocks = math.ceil(min_Mc_ratio * M0 / N0)
assert min_Mc_blocks >= 1
Kt_bytes = Kt_blocks * K0 * num_byte_A
if min_Mc_blocks * M0 * Kt_bytes < L2:
# Strategy 1: A (Mc x Kt) resides in L2 and is reused by all Nt
# when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
# to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
# in L1.
Mc_blocks = min(Mt_blocks, math.floor(L2 / (M0 * Kt_bytes)))
Nc_blocks = 1
else:
# Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, so we reuse
# A (Mc x Kc) in L2 across B (Kc x Nc). C (Mc x Nc) resides in L2.
Mc_blocks = Mt_blocks
Nc_blocks = min(math.ceil(Mc_blocks * M0 / N0), Nt_blocks)
Nc_bytes = Nc_blocks * N0 * 4 # assume C or acc is float32/int32
Kc_bytes = Kc_blocks * K0 * num_byte_A
if Mc_blocks * M0 * (Kc_bytes + Nc_bytes) > L2:
# The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
# assuming Mc == Nc for good data reuse.
M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
if M_max < Mc_blocks * M0:
Collaborator:
Should it be an assert theoretically? But considering we use some approximated calculation, I guess it should be fine.

Collaborator (author):
It is not an assertion. When it is false, we use the default one, which is Mt_blocks.

Mc_blocks = math.floor(M_max / M0)
Nc_blocks = min(math.ceil(Mc_blocks * M0 / N0), Nt_blocks)

return Mc_blocks, Nc_blocks, Kc_blocks
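
For reference, the closed form for `M_max` above follows from the L2 capacity constraint `Mc_blocks * M0 * (Kc_bytes + Nc_bytes) <= L2` when the A and C cache blocks are assumed square in elements (`Mc_blocks * M0 == Nc_blocks * N0 == M`) and the accumulator element is 4 bytes, as stated in the code comment; a sketch of the algebra:

```math
M\,(K_{c,\mathrm{bytes}} + 4M) = L_2
\;\Longrightarrow\; 4M^2 + K_{c,\mathrm{bytes}}\,M - L_2 = 0
\;\Longrightarrow\; M_{\max} = \frac{\sqrt{K_{c,\mathrm{bytes}}^2 + 16\,L_2} - K_{c,\mathrm{bytes}}}{8}
```
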

2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp_template_kernel.py
@@ -140,7 +140,7 @@ def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView:
Slice the given node with a list of ranges (start and end) corresponding to its dims.
The dim is not sliced if the corresponding range is empty.
"""
assert len(ranges) == len(node.get_size())
assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}"
sliced = wrap_with_tensorbox(node)
for dim, _range in enumerate(ranges):
if len(_range) == 0:
6 changes: 6 additions & 0 deletions torch/_inductor/config.py
@@ -744,6 +744,12 @@ class cpp:
# When set to 0, the number of slices is unlimited.
gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1"))

# For perf tuning and debugging purposes, configure the predefined cache blocking for
# the M, N and K dims respectively. The blockings are comma-separated and the unit is
# the number of register blocks.
# For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K, respectively.
gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)
Contributor:
Please add a test that enables this config via @config.patch(gemm_cache_blocking="...")

Collaborator (author):
Thanks. Added.
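
For anyone tuning this knob, a minimal way to exercise it (a sketch, not from the PR: the model, shapes, and the autotune/freezing settings are illustrative assumptions mirroring the test above; setting the env var TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING shown in the diff is an equivalent alternative):

```python
import torch
import torch._inductor.config as inductor_config

# Sketch: take a path similar to the new test (freezing + autotuning on CPU),
# then override the cache blocking as "Mc,Nc,Kc" counted in register blocks.
with inductor_config.patch(
    {
        "freezing": True,
        "max_autotune": True,
        "cpp.gemm_cache_blocking": "4,1,10",
    }
), torch.no_grad():
    mod = torch.nn.Linear(1024, 1024).eval()
    x = torch.randn(1024, 1024)
    y = torch.compile(mod)(x)
```
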



# config specific to codegen/triton.py
class triton: