Support fp8 output of _scaled_mm for CPU · pytorch/pytorch@75a6bba · GitHub

Commit 75a6bba

Support fp8 output of _scaled_mm for CPU
1 parent 1a722f6 commit 75a6bba

File tree

3 files changed: +59 -0 lines changed

  aten/src/ATen/native/Blas.cpp
  aten/src/ATen/native/mkldnn/Linear.cpp
  test/test_mkldnn.py


aten/src/ATen/native/Blas.cpp

Lines changed: 14 additions & 0 deletions
@@ -246,6 +246,10 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
       mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
 
   TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
+  TORCH_CHECK(
+      !scale_result ||
+          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
+      "scale_result must be a float scalar");
   TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
       " but got ", bias->numel());
 
@@ -262,12 +266,22 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
 
   float input_scale = scale_a.item<float>();
   float weight_scale = scale_b.item<float>();
+  float output_scale = float(1.0);
+  if (scale_result.has_value() &&
+      (*out_dtype == ScalarType::Float8_e4m3fn ||
+       *out_dtype == ScalarType::Float8_e5m2)) {
+    output_scale = scale_result.value().item<float>();
+  }
   auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
   auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
   auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
   if (bias) {
     out_tmp.add_(bias.value());
   }
+  if (*out_dtype == ScalarType::Float8_e4m3fn ||
+      *out_dtype == ScalarType::Float8_e5m2) {
+    out_tmp = at::mul(out_tmp, 1 / output_scale);
+  }
   out_tmp = out_tmp.to(out.scalar_type());
   out.copy_(out_tmp);
   return out;
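For reference, the emulated CPU path above dequantizes both operands with their per-tensor scales, runs the matmul in fp32, adds the optional bias, and, for an fp8 out_dtype, divides by the output scale before the final cast. A minimal Python sketch of that arithmetic (the helper name scaled_mm_reference and its defaults are illustrative, not part of this change):

import torch

def scaled_mm_reference(mat1, mat2, scale_a, scale_b, scale_result=None,
                        out_dtype=torch.float32, bias=None):
    # Dequantize both operands with their per-tensor scales; accumulate in fp32.
    out = torch.mm(mat1.to(torch.float) * scale_a, mat2.to(torch.float) * scale_b)
    if bias is not None:
        out = out + bias
    # For fp8 outputs, divide by the output scale before quantizing to out_dtype.
    if out_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and scale_result is not None:
        out = out / scale_result
    return out.to(out_dtype)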

aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 15 additions & 0 deletions
@@ -487,6 +487,10 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
       mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
 
   TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
+  TORCH_CHECK(
+      !scale_result ||
+          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
+      "scale_result must be a float scalar");
   TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
       " but got ", bias->numel());
 
@@ -504,6 +508,12 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
 
   float input_scale = scale_a.item<float>();
   float weight_scale = scale_b.item<float>();
+  float output_scale = float(1.0);
+  if (scale_result.has_value() &&
+      (*out_dtype == ScalarType::Float8_e4m3fn ||
+       *out_dtype == ScalarType::Float8_e5m2)) {
+    output_scale = scale_result.value().item<float>();
+  }
   auto src = at::native::itensor_view_from_dense(mat1_c);
   auto weight_t = at::native::itensor_view_from_dense(mat2_c);
   bool with_bias = bias.has_value();
@@ -550,6 +560,9 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
   if (weight_scale != 1.0f) {
     op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
   }
+  if (output_scale != 1.0f) {
+    op_attr.set_scales_mask(DNNL_ARG_DST, 0);
+  }
 
   op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
   auto engine = ideep::engine::cpu_engine();
@@ -578,6 +591,8 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
     args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
   }
   args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
+  ideep::tensor dst_scales_t = ideep::tensor(ideep::scale_t(1, output_scale));
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_t});
 
   primitive.execute(ideep::stream::default_stream(), args);
   return out;
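The oneDNN path wires the same output scale through DNNL_ARG_DST. From Python, the new argument is exercised as in the test below; a minimal usage sketch, assuming a CPU build that includes this commit (shapes and scale values are made up for illustration):

import torch

M, N, K = 2, 13, 16
x_fp8 = (torch.randn(M, K) / K).to(torch.float8_e4m3fn)
y_fp8 = (torch.randn(N, K).t() / K).to(torch.float8_e4m3fn)
scale_a = torch.tensor([1.5])    # per-tensor input scale
scale_b = torch.tensor([0.5])    # per-tensor weight scale
scale_out = torch.tensor([2.0])  # per-tensor output scale, used only for fp8 out_dtype

# With an fp8 out_dtype, the fp32 result is divided by scale_out before the cast.
out = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b,
                       scale_result=scale_out, out_dtype=torch.float8_e4m3fn)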

test/test_mkldnn.py

Lines changed: 30 additions & 0 deletions
@@ -1629,6 +1629,36 @@ def test_mkldnn_error_on_zero_stride(self, device):
         with self.assertRaises(ValueError):
             torch.mkldnn_max_pool2d(x, kernel_size=3, stride=0)
 
+    def test_mkldnn_scaled_mm(self, device) -> None:
+        # test with input scale, weight scale and output_scale
+        M, N, K = 2, 13, 16
+        x = torch.randn((M, K), device=device) / K
+        y = torch.randn((N, K), device=device).t() / K
+        options = itertools.product(
+            [torch.float8_e4m3fn, torch.float8_e5m2],
+            [torch.float8_e4m3fn, torch.float8_e5m2],
+            [torch.float8_e4m3fn, torch.float8_e5m2, torch.bfloat16, torch.float16, torch.float32])
+        for x_dtype, y_dtype, out_dtype in options:
+            if out_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+                if x_dtype != out_dtype:
+                    continue
+            x_fp8 = x.to(x_dtype)
+            y_fp8 = y.to(y_dtype)
+            scale_a = torch.randn(1, device=device)
+            scale_b = torch.randn(1, device=device)
+            scale_out = torch.randn(1, device=device)
+            out_fp32 = torch.mm(x_fp8.to(torch.float) * scale_a, y_fp8.to(torch.float) * scale_b)
+            if out_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+                out_emulated = (out_fp32 / scale_out).to(out_dtype)
+            else:
+                out_emulated = out_fp32.to(out_dtype)
+
+            out = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, scale_result=scale_out, out_dtype=out_dtype)
+            if out_dtype is not None:
+                self.assertEqual(out_dtype, out.dtype)
+            self.assertEqual(out_emulated.float(), out.float(), atol=5e-2, rtol=5e-2)
+
+
 
 instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))
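Assuming a local CPU build that includes this commit, the new case can typically be run with something like: pytest test/test_mkldnn.py -k test_mkldnn_scaled_mm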
