[Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant · pytorch/pytorch@f81e66f · GitHub

Commit f81e66f

[Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant
ghstack-source-id: 2249ca6
Pull Request resolved: #153365
1 parent 2db230c commit f81e66f

File tree

4 files changed: +35 −2 lines changed


aten/src/ATen/cpu/vec/vec512/vec512_convert.h

Lines changed: 20 additions & 0 deletions
@@ -311,6 +311,26 @@ struct VecConvert<float, 1, Float8_e4m3fn, 1> {
   }
 };

+template <>
+struct VecConvert<Float8_e5m2, 1, float, 1> {
+  static inline VectorizedN<Float8_e5m2, 1> apply(const VectorizedN<float, 1>& src_n) {
+    at::vec::Vectorized<float> src = src_n[0];
+    __m128i res128 = cvtfp32_fp8e5m2(src);
+    return at::vec::Vectorized<Float8_e5m2>(_mm512_castsi128_si512(res128));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, Float8_e5m2, 1> {
+  static inline VectorizedN<float, 1> apply(const VectorizedN<Float8_e5m2, 1>& src_n) {
+    // convert the first 16 8-bit Float8_e5m2 values to float
+    at::vec::Vectorized<Float8_e5m2> src = src_n[0];
+    __m512 result;
+    cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result);
+    return at::vec::Vectorized<float>(result);
+  }
+};
+
 #endif

 } // namespace CPU_CAPABILITY
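
For orientation, a minimal eager-mode sketch of the round trip these two AVX512 specializations vectorize; the shape and the print are illustrative assumptions, not taken from this commit:

import torch

# Eager-mode equivalent of the new fp32 <-> fp8 E5M2 conversions
# (cvtfp32_fp8e5m2 / cvtfp8e5m2_fp32 are their vectorized counterparts).
x = torch.randn(1024, dtype=torch.float32)
q = x.to(torch.float8_e5m2)   # fp32 -> fp8 E5M2
d = q.to(torch.float32)       # fp8 E5M2 -> fp32
# E5M2 keeps only 2 mantissa bits, so the round trip is deliberately lossy.
print((x - d).abs().max())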

test/inductor/test_cpu_repro.py

Lines changed: 11 additions & 2 deletions
@@ -1418,10 +1418,15 @@ def fn(
         use_quant_list = [False, True]
         use_tensor_overload_list = [False, True]

-        assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
+        assert dtype in [
+            torch.uint8,
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]
         quant_min = 0 if dtype == torch.uint8 else -128
         quant_max = 255 if dtype == torch.uint8 else 127
-        if dtype == torch.float8_e4m3fn:
+        if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
             quant_min = int(torch.finfo(dtype).min)
             quant_max = int(torch.finfo(dtype).max)
             use_tensor_overload_list = [
@@ -1486,6 +1491,10 @@ def test_dequant_quant_lowering_int8(self):
     def test_dequant_quant_lowering_fp8_e4m3(self):
         self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)

+    @requires_vectorization
+    def test_dequant_quant_lowering_fp8_e5m2(self):
+        self._test_dequant_quant_lowering_helper(torch.float8_e5m2)
+
     def _test_dequant_maxpool2d_lowering_helper(self, dtype):
         def fn(x, scale, zero_point, quant_min, quant_max, dtype):
             x = torch.ops.quantized_decomposed.dequantize_per_tensor(
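
A sketch of the pattern the new test compiles; the scale, zero_point, and shape are illustrative assumptions, while the actual helper _test_dequant_quant_lowering_helper sweeps many more configurations:

import torch

qmin = int(torch.finfo(torch.float8_e5m2).min)  # -57344
qmax = int(torch.finfo(torch.float8_e5m2).max)  # 57344

def fn(x):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, 0.01, 0, qmin, qmax, torch.float8_e5m2
    )
    x = torch.relu(x)
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        x, 0.01, 0, qmin, qmax, torch.float8_e5m2
    )

x = torch.randn(8, 16).to(torch.float8_e5m2)
out = torch.compile(fn)(x)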

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 0 deletions
@@ -155,6 +155,7 @@ def get_export_declaration():
     torch.int32,
     torch.int64,
     torch.float8_e4m3fn,
+    torch.float8_e5m2,
 ]

 MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
@@ -1609,6 +1610,7 @@ def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True):
         torch.int32,
         torch.int64,
         torch.float8_e4m3fn,
+        torch.float8_e5m2,
     ], f"{__name__} does not support {dtype}"
     assert isinstance(x, CppCSEVariable)
     src_dtype = x.dtype
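
With torch.float8_e5m2 added to both dtype allowlists, one hedged way to confirm the CPP backend vectorizes an E5M2 cast is to inspect the generated C++. run_and_get_code is an internal helper and may change, and "at::vec" appearing only in vectorized kernels is an assumption about current codegen:

import torch
from torch._inductor.utils import run_and_get_code

def cast_roundtrip(x):
    return x.to(torch.float8_e5m2).to(torch.float32)

x = torch.randn(1024)
_, codes = run_and_get_code(torch.compile(cast_roundtrip), x)
# Vectorized CPP kernels reference at::vec; scalar fallbacks do not.
print(any("at::vec" in code for code in codes))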

torch/csrc/inductor/cpp_wrapper/common.h

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 // Include some often-used cpp_wrapper headers, for precompiling.
 #include <c10/util/BFloat16.h>
 #include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <torch/csrc/Device.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/utils/pythoncapi_compat.h>
@@ -72,6 +73,7 @@ using namespace torch::aot_inductor;
 using half = at::Half;
 using bfloat16 = at::BFloat16;
 using float8_e4m3fn = at::Float8_e4m3fn;
+using float8_e5m2 = at::Float8_e5m2;

 // Round up to the nearest multiple of 64
 [[maybe_unused]] inline int64_t align(int64_t nbytes) {
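
The include and alias matter for cpp_wrapper builds, where the generated C++ wrapper must be able to spell the dtype. A minimal sketch, assuming the internal cpp_wrapper config flag keeps its current name:

import torch
import torch._inductor.config as inductor_config

# Emit the wrapper itself as C++, which pulls in common.h above.
inductor_config.cpp_wrapper = True

def fn(x):
    return x.to(torch.float8_e5m2).to(torch.float32) + 1.0

out = torch.compile(fn)(torch.randn(64))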
