[Inductor][CPP] Enable vectorized fp8 quant dequant · pytorch/pytorch@8f42d59

Commit 8f42d59

[Inductor][CPP] Enable vectorized fp8 quant dequant
ghstack-source-id: f961223
Pull Request resolved: #152418
1 parent cbd8419 commit 8f42d59

File tree

3 files changed: +33 -1 lines

aten/src/ATen/cpu/vec/vec512/vec512_convert.h

Lines changed: 20 additions & 0 deletions
@@ -291,6 +291,26 @@ struct VecConvert<
   }
 };
 
+template <>
+struct VecConvert<Float8_e4m3fn, 1, float, 1> {
+  static inline VectorizedN<Float8_e4m3fn, 1> apply(const VectorizedN<float, 1>& src_n) {
+    at::vec::Vectorized<float> src = src_n[0];
+    __m128i res128 = cvtfp32_fp8e4m3(src);
+    return at::vec::Vectorized<Float8_e4m3fn>(_mm512_castsi128_si512(res128));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, Float8_e4m3fn, 1> {
+  static inline VectorizedN<float, 1> apply(const VectorizedN<Float8_e4m3fn, 1>& src_n) {
+    // cvt first 16x8 bits from Float8_e4m3fn to float
+    at::vec::Vectorized<Float8_e4m3fn> src = src_n[0];
+    __m512 result;
+    cvtfp8e4m3_fp32(_mm512_castsi512_si128(src), result);
+    return at::vec::Vectorized<float>(result);
+  }
+};
+
 #endif
 
 } // namespace CPU_CAPABILITY
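
For reference, these two VecConvert specializations are the AVX512 counterparts of the scalar fp32 <-> fp8 (e4m3) casts that eager PyTorch already performs. A minimal eager-mode sketch of the element-wise behavior the vectorized path has to match (plain PyTorch dtype casts, not part of this commit):

import torch

# fp32 -> fp8 (e4m3) -> fp32 round trip; 16 floats is one __m512 worth of lanes
x = torch.randn(16, dtype=torch.float32)
x_fp8 = x.to(torch.float8_e4m3fn)   # same direction as cvtfp32_fp8e4m3 above
x_back = x_fp8.to(torch.float32)    # same direction as cvtfp8e4m3_fp32 above

# e4m3 keeps only 3 mantissa bits, so the round trip is lossy
print((x - x_back).abs().max())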

test/inductor/test_cpu_repro.py

Lines changed: 11 additions & 1 deletion
@@ -1397,9 +1397,15 @@ def fn(
         use_quant_list = [False, True]
         use_tensor_overload_list = [False, True]
 
-        assert dtype in [torch.uint8, torch.int8]
+        assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
         quant_min = 0 if dtype == torch.uint8 else -128
         quant_max = 255 if dtype == torch.uint8 else 127
+        if dtype == torch.float8_e4m3fn:
+            quant_min = int(torch.finfo(dtype).min)
+            quant_max = int(torch.finfo(dtype).max)
+            use_tensor_overload_list = [
+                False,
+            ]
 
         for (
             use_dequant,
@@ -1455,6 +1461,10 @@ def test_dequant_quant_lowering_int8(self):
             torch.int8, dequant_out_dtype=torch.bfloat16
         )
 
+    @requires_vectorization
+    def test_dequant_quant_lowering_fp8_e4m3(self):
+        self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
+
     def _test_dequant_maxpool2d_lowering_helper(self, dtype):
         def fn(x, scale, zero_point, quant_min, quant_max, dtype):
             x = torch.ops.quantized_decomposed.dequantize_per_tensor(
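
The new test_dequant_quant_lowering_fp8_e4m3 simply reruns the existing helper with torch.float8_e4m3fn. A rough sketch of the kind of pattern that helper compiles is below; the exact body, scale, and zero-point values here are illustrative, not copied from test_cpu_repro.py:

import torch

def fn(x, scale, zero_point, quant_min, quant_max, dtype):
    # dequantize fp8 -> fp32, run a pointwise op, then re-quantize to fp8
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    x = torch.relu(x)
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )

dtype = torch.float8_e4m3fn
quant_min = int(torch.finfo(dtype).min)   # mirrors the helper change above
quant_max = int(torch.finfo(dtype).max)
x = torch.randn(1, 3, 224, 224).to(dtype)
compiled = torch.compile(fn)
out = compiled(x, 0.02, 0, quant_min, quant_max, dtype)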

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,7 @@ def get_export_declaration():
     torch.int8,
     torch.int32,
     torch.int64,
+    torch.float8_e4m3fn,
 ]
 
 MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
@@ -1599,6 +1600,7 @@ def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True):
         torch.int8,
         torch.int32,
         torch.int64,
+        torch.float8_e4m3fn,
     ], f"{__name__} does not support {dtype}"
     assert isinstance(x, CppCSEVariable)
     src_dtype = x.dtype
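
With torch.float8_e4m3fn added to VECTORIZABLE_DTYPES and accepted by to_dtype, the CPP backend is allowed to pick the vectorized path for fp8 casts. One way to sanity-check this on an AVX512-capable CPU is Inductor's vec-kernel counter, as used throughout test_cpu_repro.py; this is a sketch under that assumption, and the cast_roundtrip function is illustrative, not from this commit:

import torch
from torch._inductor import metrics

def cast_roundtrip(x):
    # fp32 -> fp8 -> fp32; both casts map onto the VecConvert
    # specializations added in vec512_convert.h
    return x.to(torch.float8_e4m3fn).to(torch.float32)

metrics.reset()
x = torch.randn(1024)
compiled = torch.compile(cast_roundtrip)
torch.testing.assert_close(compiled(x), cast_roundtrip(x))
# expected to be > 0 when the CPP backend vectorized the kernel
print(metrics.generated_cpp_vec_kernel_count)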

0 commit comments

Comments
 (0)
0