Support transpose and pack for bit8 (#156065) · pytorch/pytorch@d26ca5d · GitHub

Commit d26ca5d

Valentine233 authored and pytorchmergebot committed
Support transpose and pack for bit8 (#156065)
To be used by CPU INT8 SDPA in torchao: pytorch/ao#2380
Pull Request resolved: #156065
Approved by: https://github.com/mingfeima, https://github.com/ezyang
1 parent 2022588 commit d26ca5d
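For orientation, here is a minimal sketch of how a downstream caller such as the torchao CPU INT8 SDPA kernel might use the new pack_vnni4 helper to prepack an int8 [K, N] block. The function name, buffer handling, and shapes below are illustrative assumptions, not code from this commit or from torchao.

// Hypothetical prepack helper (names and shapes are assumptions): reorder an
// int8 [K, N] block, stored row-major with leading dimension ld_src, into the
// [ceil(K/4), N, 4] VNNI4 layout that an AVX-512 VNNI int8 GEMM can consume.
// pack_vnni4 itself is only compiled for the AVX-512 dispatch level.
#include <ATen/cpu/vec/vec_quant.h>
#include <cstdint>
#include <vector>

std::vector<int8_t> prepack_int8_block(
    const int8_t* src, int64_t K, int64_t N, int64_t ld_src) {
  // Output is padded to a multiple of 4 rows; the padded rows are zero-filled.
  std::vector<int8_t> packed(((K + 3) / 4) * 4 * N);
  at::vec::pack_vnni4(src, packed.data(), ld_src, K, N);
  return packed;
}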

File tree

aten/src/ATen/cpu/vec/vec_quant.h
aten/src/ATen/native/cpu/utils.h
aten/src/ATen/test/vec_test_all_types.cpp
aten/src/ATen/test/vec_test_all_types.h

4 files changed: +230 -2 lines changed

aten/src/ATen/cpu/vec/vec_quant.h

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <c10/util/Exception.h>
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+// Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4)
+template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
+static inline void transpose_pad_4x64_block(
+    const scalar_t* src,
+    scalar_t* dst,
+    int64_t ld_src,
+    int krem = 4,
+    int nrem = 64) {
+#if defined(CPU_CAPABILITY_AVX512)
+  __m512i r[4];
+  // Load with mask if partial
+  if (nrem < 64) {
+    __mmask64 mask = (1ULL << nrem) - 1;
+    for (int i = 0; i < krem; ++i) {
+      r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src);
+    }
+    for (int i = krem; i < 4; ++i) {
+      r[i] = _mm512_setzero_si512();
+    }
+  } else {
+    for (int i = 0; i < krem; ++i) {
+      r[i] = _mm512_loadu_si512(
+          reinterpret_cast<const __m512i*>(src + i * ld_src));
+    }
+    for (int i = krem; i < 4; ++i) {
+      r[i] = _mm512_setzero_si512();
+    }
+  }
+
+  // Transpose 4x64 bytes using unpack and shuffle
+  __m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]);
+  __m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]);
+  __m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]);
+  __m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]);
+
+  __m512i u0 = _mm512_unpacklo_epi16(t0, t2);
+  __m512i u1 = _mm512_unpackhi_epi16(t0, t2);
+  __m512i u2 = _mm512_unpacklo_epi16(t1, t3);
+  __m512i u3 = _mm512_unpackhi_epi16(t1, t3);
+
+  __m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88);
+  __m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd);
+  __m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88);
+  __m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd);
+
+  __m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88);
+  __m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88);
+  __m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd);
+  __m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd);
+
+  // Store output
+  if (nrem < 16) {
+    __mmask64 mask = (1ULL << (nrem * 4)) - 1;
+    _mm512_mask_storeu_epi8(dst, mask, r0);
+  } else if (nrem == 16) {
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+  } else if (nrem < 32) {
+    int n_bytes1 = 64;
+    int n_bytes2 = (nrem * 4) - n_bytes1;
+    __mmask64 mask = (1ULL << n_bytes2) - 1;
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1);
+  } else if (nrem == 32) {
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
+  } else if (nrem < 48) {
+    int n_bytes1 = 64 * 2;
+    int n_bytes2 = (nrem * 4) - n_bytes1;
+    __mmask64 mask = (1ULL << n_bytes2) - 1;
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
+    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2);
+  } else if (nrem == 48) {
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
+  } else if (nrem < 64) {
+    int n_bytes1 = 64 * 3;
+    int n_bytes2 = (nrem * 4) - n_bytes1;
+    __mmask64 mask = (1ULL << n_bytes2) - 1;
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
+    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3);
+  } else {
+    // normal case, nrem == 64
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3);
+  }
+#else
+  TORCH_CHECK(
+      false,
+      "transpose_pad_4x64_block is only supported when AVX-512 is supported")
+#endif
+}
+
+// Reorder [K, N] -> [K/4, N, 4] (VNNI4-style layout for bit8)
+template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
+static inline void pack_vnni4(
+    const scalar_t* src,
+    scalar_t* dst,
+    int64_t ld_src,
+    int64_t K,
+    int64_t N) {
+#if defined(CPU_CAPABILITY_AVX512)
+  int64_t bk = 0;
+  int64_t _K = K / 4 * 4;
+  int64_t _N = N / 64 * 64;
+  for (; bk < _K; bk += 4) {
+    int64_t bn = 0;
+    for (; bn < _N; bn += 64) {
+      transpose_pad_4x64_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src);
+    }
+    int64_t nrem = N - bn;
+    if (nrem > 0) {
+      transpose_pad_4x64_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem);
+    }
+  }
+
+  // Handle leftover K rows (< 4)
+  if (K % 4 != 0) {
+    int krem = K - bk;
+    int64_t bn = 0;
+    for (; bn < _N; bn += 64) {
+      transpose_pad_4x64_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem);
+    }
+    int64_t nrem = N - bn;
+    if (nrem > 0) {
+      transpose_pad_4x64_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem);
+    }
+  }
+#else
+  TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is supported")
+#endif
+}
+
+} // namespace CPU_CAPABILITY
+} // namespace at::vec
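To make the packed layout concrete: element (k, n) of the [K, N] source ends up at flat offset (k / 4) * N * 4 + n * 4 + (k % 4) in the destination, which is the relation the new PackVNNI4 test below also checks. A hedged scalar reference (not part of the commit) that expresses the same mapping, assuming K is a multiple of 4:

// Scalar reference for the VNNI4 reorder above (illustrative only; the AVX-512
// kernel additionally zero-pads a partial last group of rows when K % 4 != 0).
#include <cstdint>

inline void pack_vnni4_ref(
    const int8_t* src, int8_t* dst, int64_t ld_src, int64_t K, int64_t N) {
  for (int64_t k = 0; k < K; ++k) {
    for (int64_t n = 0; n < N; ++n) {
      // Row k lands in group k / 4, at slot k % 4 of the 4-wide inner block.
      dst[(k / 4) * N * 4 + n * 4 + (k % 4)] = src[k * ld_src + n];
    }
  }
}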

aten/src/ATen/native/cpu/utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@ inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64
165165
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
166166
fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
167167
}
168+
169+
template <>
170+
inline void transpose<uint8_t>(int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) {
171+
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
172+
fbgemm::transpose_simd<uint8_t>(M, N, src, ld_src, dst, ld_dst);
173+
}
168174
#endif
169175

170176
template <typename index_t, typename F>
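A short usage sketch for the new uint8_t specialization (the function name below is a placeholder): it mirrors the existing uint16_t path and simply forwards to fbgemm::transpose_simd, so it is only available when PyTorch is built with FBGEMM.

// Hypothetical caller (illustrative): transpose a dense row-major [M, N]
// uint8 tile into a dense row-major [N, M] tile.
#include <ATen/native/cpu/utils.h>
#include <cstdint>
#include <vector>

std::vector<uint8_t> transpose_u8(const uint8_t* src, int64_t M, int64_t N) {
  std::vector<uint8_t> dst(M * N);
  // ld_src = N for a dense source, ld_dst = M for a dense transposed output.
  at::native::utils::transpose<uint8_t>(
      M, N, src, /*ld_src=*/N, dst.data(), /*ld_dst=*/M);
  return dst;
}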

aten/src/ATen/test/vec_test_all_types.cpp

Lines changed: 66 additions & 0 deletions
@@ -61,6 +61,8 @@ namespace {
 template <typename T>
 class QuantizationTests : public ::testing::Test {};
 template <typename T>
+class Quantization8BitTests : public ::testing::Test {};
+template <typename T>
 class Quantization8BitWithTailTests : public ::testing::Test {};
 template <typename T>
 class FunctionalTests : public ::testing::Test {};
@@ -79,6 +81,7 @@ namespace {
 using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
 using ALLTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vlong, vint, vshort, vqint8, vquint8, vqint>;
 using QuantTestedTypes = ::testing::Types<vqint8, vquint8, vqint>;
+using Quantization8BitTestedTypes = ::testing::Types<vqint8, vquint8>;
 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
 using Quantization8BitWithTailTestedTypes =
     ::testing::Types<vqint8, vquint8>;
@@ -116,6 +119,7 @@ namespace {
 TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatReducedFloatTestedTypes);
 TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes);
 TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes);
+TYPED_TEST_SUITE(Quantization8BitTests, Quantization8BitTestedTypes);
 TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes);
 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
 TYPED_TEST_SUITE(
@@ -1496,6 +1500,68 @@ namespace {
       },
       test_case);
   }
+#ifndef _WIN32
+  TYPED_TEST(Quantization8BitTests, Transpose) {
+    using VT = ValueType<TypeParam>;
+    constexpr auto M = 4;
+    constexpr auto N = 64;
+    constexpr auto L = M * N;
+    constexpr auto ld_src = N;
+    constexpr auto ld_dst = M;
+    CACHE_ALIGN VT x[L];
+    CACHE_ALIGN VT y[L];
+    CACHE_ALIGN VT ref[L];
+    auto seed = TestSeed();
+    ValueGen<VT> generator(VT(-100), VT(100), seed);
+    for (const auto i : c10::irange(L)) {
+      x[i] = generator.get();
+    }
+    at::native::utils::transpose<uint8_t>(
+        M, N,
+        reinterpret_cast<uint8_t*>(x), ld_src,
+        reinterpret_cast<uint8_t*>(y), ld_dst);
+    for (int64_t j = 0; j < N; j++) {
+      for (int64_t i = 0; i < M; i++) {
+        ref[j * ld_dst + i] = c10::load(&(x[i * ld_src + j]));
+      }
+    }
+    for (const auto i : c10::irange(L)) {
+      ASSERT_EQ(y[i], ref[i])
+          << "Failure Details:\nTest Seed to reproduce: " << seed;
+    }
+  }
+#endif
+#if defined(CPU_CAPABILITY_AVX512)
+  TYPED_TEST(Quantization8BitTests, PackVNNI4) {
+    using VT = ValueType<TypeParam>;
+    constexpr auto K = 8;
+    constexpr auto N = 128;
+    constexpr auto L = K * N;
+    constexpr auto ld_src = N;
+    CACHE_ALIGN VT x[L];
+    CACHE_ALIGN VT y[L];
+    CACHE_ALIGN VT ref[L];
+    auto seed = TestSeed();
+    ValueGen<VT> generator(VT(-100), VT(100), seed);
+    for (const auto i : c10::irange(L)) {
+      x[i] = generator.get();
+    }
+    at::vec::pack_vnni4(x, y, ld_src, K, N);
+    int64_t _K = K / 4;
+    for (int64_t k = 0; k < _K; k++) {
+      for (int64_t n = 0; n < N; n++) {
+        for (int64_t l = 0; l < 4; l++) {
+          ref[k * N * 4 + n * 4 + l] =
+              c10::load(&(x[k * ld_src * 4 + l * ld_src + n]));
+        }
+      }
+    }
+    for (const auto i : c10::irange(L)) {
+      ASSERT_EQ(y[i], ref[i])
+          << "Failure Details:\nTest Seed to reproduce: " << seed;
+    }
+  }
+#endif
 TYPED_TEST(FunctionalTests, Map) {
   using vec = TypeParam;
   using VT = ValueType<TypeParam>;

aten/src/ATen/test/vec_test_all_types.h

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,7 @@
 #pragma once
-#include <ATen/cpu/vec/vec.h>
 #include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/vec_quant.h>
 #include <c10/util/bit_cast.h>
 #include <c10/util/irange.h>
 #include <gtest/gtest.h>
@@ -21,7 +22,9 @@
 #else
 #define CACHE_LINE 32
 #endif
-
+#ifndef _WIN32
+#include <ATen/native/cpu/utils.h>
+#endif
 #if defined(__GNUC__)
 #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE)))
 #define not_inline __attribute__((noinline))

0 commit comments
