#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <c10/util/Exception.h>

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4)
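// Element-wise, dst[n * 4 + k] = src[k * ld_src + n] for k in [0, krem) and
// n in [0, nrem); source rows k in [krem, 4) contribute zero padding, and only
// the first nrem * 4 output bytes are written.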
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void transpose_pad_4x64_block(
    const scalar_t* src,
    scalar_t* dst,
    int64_t ld_src,
    int krem = 4,
    int nrem = 64) {
#if defined(CPU_CAPABILITY_AVX512)
  __m512i r[4];
  // Load with mask if partial
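  // _mm512_maskz_loadu_epi8 uses zero-masking, so bytes at column indices
  // >= nrem are loaded as zeros and need no separate padding.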
  if (nrem < 64) {
    __mmask64 mask = (1ULL << nrem) - 1;
    for (int i = 0; i < krem; ++i) {
      r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src);
    }
    for (int i = krem; i < 4; ++i) {
      r[i] = _mm512_setzero_si512();
    }
  } else {
    for (int i = 0; i < krem; ++i) {
      r[i] = _mm512_loadu_si512(
          reinterpret_cast<const __m512i*>(src + i * ld_src));
    }
    for (int i = krem; i < 4; ++i) {
      r[i] = _mm512_setzero_si512();
    }
  }

  // Transpose 4x64 bytes using unpack and shuffle
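  // Stage 1 (unpack epi8): interleave the bytes of source rows (0, 1) and
  // (2, 3). Stage 2 (unpack epi16): interleave those results so every 32-bit
  // element holds one output row, i.e. bytes {r[0][n], r[1][n], r[2][n],
  // r[3][n]}. Stage 3 (shuffle_i32x4): reorder the 128-bit lanes so the
  // output rows land in ascending n across r0..r3.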
  __m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]);
  __m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]);
  __m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]);
  __m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]);

  __m512i u0 = _mm512_unpacklo_epi16(t0, t2);
  __m512i u1 = _mm512_unpackhi_epi16(t0, t2);
  __m512i u2 = _mm512_unpacklo_epi16(t1, t3);
  __m512i u3 = _mm512_unpackhi_epi16(t1, t3);

  __m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88);
  __m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd);
  __m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88);
  __m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd);

  __m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88);
  __m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88);
  __m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd);
  __m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd);

  // Store output
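  // Each of r0..r3 holds 16 transposed rows of 4 bytes (64 bytes total):
  // r0 covers n = 0..15, r1 covers 16..31, r2 covers 32..47, r3 covers 48..63.
  // Store full vectors while they fit, then finish with one masked tail store.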
  if (nrem < 16) {
    __mmask64 mask = (1ULL << (nrem * 4)) - 1;
    _mm512_mask_storeu_epi8(dst, mask, r0);
  } else if (nrem == 16) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
  } else if (nrem < 32) {
    int n_bytes1 = 64;
    int n_bytes2 = (nrem * 4) - n_bytes1;
    __mmask64 mask = (1ULL << n_bytes2) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1);
  } else if (nrem == 32) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
  } else if (nrem < 48) {
    int n_bytes1 = 64 * 2;
    int n_bytes2 = (nrem * 4) - n_bytes1;
    __mmask64 mask = (1ULL << n_bytes2) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2);
  } else if (nrem == 48) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
  } else if (nrem < 64) {
    int n_bytes1 = 64 * 3;
    int n_bytes2 = (nrem * 4) - n_bytes1;
    __mmask64 mask = (1ULL << n_bytes2) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3);
  } else {
    // Full case: nrem == 64
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3);
  }
#else
  TORCH_CHECK(
      false,
      "transpose_pad_4x64_block is only supported when AVX-512 is available");
#endif
}

// Reorder [K, N] → [K/4, N, 4] (VNNI4-style layout for 8-bit integer types)
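// i.e. dst[(k / 4) * N * 4 + n * 4 + (k % 4)] = src[k * ld_src + n], with the
// trailing rows of a partial K group zero-padded. Packing four consecutive K
// values next to each N column matches the operand layout consumed by AVX-512
// VNNI 8-bit dot-product instructions such as vpdpbusd.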
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void pack_vnni4(
    const scalar_t* src,
    scalar_t* dst,
    int64_t ld_src,
    int64_t K,
    int64_t N) {
#if defined(CPU_CAPABILITY_AVX512)
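  // Tile the [K, N] input into [4, 64] blocks. Full blocks take the fast path;
  // the N remainder (< 64 columns) and the K remainder (< 4 rows) reuse the
  // same kernel with the nrem/krem arguments.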
  int64_t bk = 0;
  int64_t _K = K / 4 * 4;
  int64_t _N = N / 64 * 64;
  for (; bk < _K; bk += 4) {
    int64_t bn = 0;
    for (; bn < _N; bn += 64) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src);
    }
    int64_t nrem = N - bn;
    if (nrem > 0) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem);
    }
  }

  // Handle leftover K rows (< 4)
  if (K % 4 != 0) {
    int krem = K - bk;
    int64_t bn = 0;
    for (; bn < _N; bn += 64) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem);
    }
    int64_t nrem = N - bn;
    if (nrem > 0) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem);
    }
  }
#else
  TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is available");
#endif
}
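
// Illustrative usage sketch (hypothetical caller, not part of this header):
// given an int8 weight matrix `B` of shape [K, N] with leading dimension `ld`,
// the destination must hold at least round_up(K, 4) * N bytes because partial
// K groups are zero-padded:
//
//   std::vector<int8_t> packed(((K + 3) / 4) * 4 * N);
//   at::vec::pack_vnni4(B, packed.data(), /*ld_src=*/ld, K, N);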

} // namespace CPU_CAPABILITY
} // namespace at::vec