[Inductor][CPP] Enable vectorized fp8 quant dequant · pytorch/pytorch@cb72615 · GitHub
Commit cb72615
[Inductor][CPP] Enable vectorized fp8 quant dequant
ghstack-source-id: a82e56e
Pull Request resolved: #152418

1 parent 9c7d78a · commit cb72615
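
This change lets the Inductor CPP backend vectorize quantize/dequantize round trips through torch.float8_e4m3fn instead of falling back to scalar code. Below is a minimal sketch of the kind of pattern the new test exercises; the decomposed quant ops and the concrete scale, zero-point, and shape values are illustrative assumptions based on the updated test helper, not something stated in the diff itself.

import torch

def fp8_quant_dequant(x, scale, zero_point, quant_min, quant_max):
    # Quantize to fp8 e4m3, then dequantize back to float32.
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, torch.float8_e4m3fn
    )
    return torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, scale, zero_point, quant_min, quant_max, torch.float8_e4m3fn
    )

# The fp8 quant range comes from torch.finfo, mirroring the updated test helper.
quant_min = int(torch.finfo(torch.float8_e4m3fn).min)
quant_max = int(torch.finfo(torch.float8_e4m3fn).max)

x = torch.randn(1, 7, 7, 9)  # illustrative shape
compiled = torch.compile(fp8_quant_dequant)
out = compiled(x, 0.01, 0, quant_min, quant_max)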

File tree

4 files changed: +35 -1 lines changed

aten/src/ATen/cpu/vec/vec512/vec512_convert.h (+20)

@@ -291,6 +291,26 @@ struct VecConvert<
   }
 };
 
+template <>
+struct VecConvert<Float8_e4m3fn, 1, float, 1> {
+  static inline VectorizedN<Float8_e4m3fn, 1> apply(const VectorizedN<float, 1>& src_n) {
+    at::vec::Vectorized<float> src = src_n[0];
+    __m128i res128 = cvtfp32_fp8e4m3(src);
+    return at::vec::Vectorized<Float8_e4m3fn>(_mm512_castsi128_si512(res128));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, Float8_e4m3fn, 1> {
+  static inline VectorizedN<float, 1> apply(const VectorizedN<Float8_e4m3fn, 1>& src_n) {
+    // cvt first 16x8 bits from Float8_e4m3fn to float
+    at::vec::Vectorized<Float8_e4m3fn> src = src_n[0];
+    __m512 result;
+    cvtfp8e4m3_fp32(_mm512_castsi512_si128(src), result);
+    return at::vec::Vectorized<float>(result);
+  }
+};
+
 #endif
 
 } // namespace CPU_CAPABILITY
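
The two VecConvert specializations above cover one Vectorized<float> worth of lanes (16 floats on AVX512): fp32→fp8 packs the result into the low 128 bits of a 512-bit register, and fp8→fp32 widens the first 16 bytes back out. Semantically they vectorize the plain elementwise dtype casts; a quick Python equivalence sketch (the shape and the wording of the expected error bound are illustrative, not from the diff):

import torch

x = torch.randn(16)
# Scalar reference for what the vectorized fp32 <-> fp8 e4m3 conversion computes.
roundtrip = x.to(torch.float8_e4m3fn).to(torch.float32)
print((x - roundtrip).abs().max())  # bounded by the e4m3 quantization step near |x|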

test/inductor/test_cpu_repro.py (+11 -1)

@@ -1418,9 +1418,15 @@ def fn(
         use_quant_list = [False, True]
         use_tensor_overload_list = [False, True]
 
-        assert dtype in [torch.uint8, torch.int8]
+        assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
         quant_min = 0 if dtype == torch.uint8 else -128
         quant_max = 255 if dtype == torch.uint8 else 127
+        if dtype == torch.float8_e4m3fn:
+            quant_min = int(torch.finfo(dtype).min)
+            quant_max = int(torch.finfo(dtype).max)
+            use_tensor_overload_list = [
+                False,
+            ]
 
         for (
             use_dequant,
@@ -1476,6 +1482,10 @@ def test_dequant_quant_lowering_int8(self):
             torch.int8, dequant_out_dtype=torch.bfloat16
         )
 
+    @requires_vectorization
+    def test_dequant_quant_lowering_fp8_e4m3(self):
+        self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
+
     def _test_dequant_maxpool2d_lowering_helper(self, dtype):
         def fn(x, scale, zero_point, quant_min, quant_max, dtype):
             x = torch.ops.quantized_decomposed.dequantize_per_tensor(
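
Note that for fp8 the test helper derives the quant range from torch.finfo rather than the hard-coded int8/uint8 bounds, and it skips the tensor-overload variants of the quant ops. A quick check of what those bounds evaluate to (the printed values are a property of the e4m3fn format, not stated in the diff):

import torch

info = torch.finfo(torch.float8_e4m3fn)
print(int(info.min), int(info.max))  # -448 448 for e4m3fn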

torch/_inductor/codegen/cpp.py (+2)

@@ -154,6 +154,7 @@ def get_export_declaration():
     torch.int8,
     torch.int32,
     torch.int64,
+    torch.float8_e4m3fn,
 ]
 
 MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
@@ -1607,6 +1608,7 @@ def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True):
             torch.int8,
             torch.int32,
             torch.int64,
+            torch.float8_e4m3fn,
         ], f"{__name__} does not support {dtype}"
        assert isinstance(x, CppCSEVariable)
        src_dtype = x.dtype
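
With torch.float8_e4m3fn added to the vectorizable dtypes and to the to_dtype allow-list, Inductor should emit a vectorized CPU kernel for fp8 cast patterns instead of scalar code. A rough way to observe that, modeled on how test_cpu_repro.py counts vectorized kernels; the helper function and values below are illustrative assumptions, and the counter stays 0 on CPUs without the required vector ISA:

import torch
from torch._inductor import metrics

def dq_q(x, scale):
    # dequant -> relu -> quant: the shape of pattern the lowering targets
    y = x.to(torch.float32) * scale
    y = torch.relu(y)
    return (y / scale).to(torch.float8_e4m3fn)

metrics.reset()
x = torch.randn(16, 32).to(torch.float8_e4m3fn)
torch.compile(dq_q)(x, 0.02)
print(metrics.generated_cpp_vec_kernel_count)  # expected > 0 on AVX512-capable CPUs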

torch/csrc/inductor/cpp_wrapper/common.h (+2)

@@ -11,6 +11,7 @@
 
 // Include some often-used cpp_wrapper headers, for precompiling.
 #include <c10/util/BFloat16.h>
+#include <c10/util/Float8_e4m3fn.h>
 #include <torch/csrc/Device.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/utils/pythoncapi_compat.h>
@@ -70,6 +71,7 @@ using namespace torch::aot_inductor;
 #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
 using half = at::Half;
 using bfloat16 = at::BFloat16;
+using float8_e4m3fn = at::Float8_e4m3fn;
 
 // Round up to the nearest multiple of 64
 [[maybe_unused]] inline int64_t align(int64_t nbytes) {
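
The new include and type alias make the fp8 type available to code generated in Inductor's cpp_wrapper mode, whose common.h is precompiled. A small sketch of exercising that path (torch._inductor.config.cpp_wrapper is an existing config flag; the rest is an illustrative assumption):

import torch
import torch._inductor.config as inductor_config

# Route compilation through the C++ wrapper so common.h (and the new
# Float8_e4m3fn include) is part of the generated wrapper's environment.
inductor_config.cpp_wrapper = True

def cast_roundtrip(x):
    return x.to(torch.float8_e4m3fn).to(torch.float32)

y = torch.compile(cast_roundtrip)(torch.randn(128))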
