[Quant][X86] add an op to compute uint8 batch norm 2d (#152811) · pytorch/pytorch@1a722f6 · GitHub

Commit 1a722f6

Xia-Weiwen authored and pytorchmergebot committed
[Quant][X86] add an op to compute uint8 batch norm 2d (#152811)
**Summary**

This PR adds a new op, `onednn.qbatch_norm2d`, which accepts uint8 inputs on the CPU device (instead of QuantizedCPU). The new op is implemented with AVX512 instructions and provides performance similar to its QuantizedCPU counterpart, `quantized.batch_norm2d`. It also supports output dtypes other than uint8 (fp32, fp16 and bf16 are supported).

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k test_int8_batch_norm_onednn
```

Pull Request resolved: #152811

Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168, https://github.com/jgong5

ghstack dependencies: #152411
1 parent 7e16cb9 commit 1a722f6
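For reference, a minimal usage sketch of the new op. This is only a sketch: the call signature mirrors the schema and the test added in this commit, a PyTorch build with this PR and oneDNN support is assumed, and the shapes, scales and zero points below are illustrative values, not taken from the PR.

```python
import torch

# Illustrative shapes and quantization parameters (not from the PR itself).
x = torch.randn(8, 64, 8, 8)                   # NCHW fp32 input
weight, bias = torch.rand(64), torch.rand(64)  # per-channel gamma / beta
mean, var = torch.rand(64), torch.rand(64)     # running statistics
x_scale, x_zp = 0.1, 1
y_scale, y_zp = 0.2, 2
eps = 1e-3

# Quantize to a plain uint8 tensor on the CPU device (no QuantizedCPU wrapper),
# using the same decomposed-quantization helper as the commit's test.
qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, x_scale, x_zp, 0, 255, torch.uint8
)

# uint8 output: the result is requantized with y_scale / y_zp.
y_u8 = torch.ops.onednn.qbatch_norm2d(
    qx, x_scale, x_zp, weight, bias, mean, var, eps, y_scale, y_zp, torch.uint8
)

# Floating-point outputs skip requantization; output scale / zero point must be 1 and 0.
y_bf16 = torch.ops.onednn.qbatch_norm2d(
    qx, x_scale, x_zp, weight, bias, mean, var, eps, 1.0, 0, torch.bfloat16
)
```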

File tree

5 files changed: +289 -0 lines changed


aten/src/ATen/native/quantized/cpu/Normalization.cpp

Lines changed: 86 additions & 0 deletions
@@ -11,6 +11,7 @@
 #else
 #include <ATen/ops/_empty_affine_quantized.h>
 #include <ATen/ops/empty_like.h>
+#include <ATen/ops/empty.h>
 #include <ATen/ops/quantized_batch_norm_native.h>
 #endif
 
@@ -20,6 +21,7 @@ namespace at::native {
 
 DEFINE_DISPATCH(qbatch_norm_stub);
 DEFINE_DISPATCH(qbatch_norm_relu_stub);
+DEFINE_DISPATCH(qbatch_norm_cpu_stub);
 
 namespace {
 void compute_fused_params(
@@ -376,6 +378,85 @@ Tensor q_batch_norm_impl(
   return qy;
 }
 
+Tensor int8_batch_norm2d_cpu_impl(
+    const Tensor& qx,
+    double qx_scale,
+    int64_t qx_zero_point,
+    const Tensor& weight,
+    const Tensor& bias,
+    const Tensor& mean,
+    const Tensor& var,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point,
+    c10::ScalarType output_dtype) {
+  if (qx.numel() == 0) {
+    auto out = qx.clone();
+    return out;
+  }
+  if (output_dtype != at::kByte) {
+    TORCH_CHECK(output_scale == 1.0 && output_zero_point == 0,
+        "Quantized batch_norm_2d output scale and zero point should be 1 and 0 for "
+        "output_dtype ", output_dtype, ", but got scale = ",
+        output_scale, " and zero point = ", output_zero_point);
+  }
+  int64_t ndim = qx.dim();
+  TORCH_CHECK(ndim == 4, "Int8 batch_norm2d: Expecting the input tensor of rank 4.");
+  const int64_t N = qx.size(0);
+  const int64_t C = qx.size(1);
+  const int64_t H = qx.size(2);
+  const int64_t W = qx.size(3);
+
+  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
+  TORCH_CHECK(bias.numel() == C, "Expect weight size to match C");
+
+  const float* weight_data = weight.template const_data_ptr<float>();
+  const float* bias_data = bias.template const_data_ptr<float>();
+
+  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
+  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");
+
+  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  float* alpha_data = alpha.mutable_data_ptr<float>();
+  float* beta_data = beta.data_ptr<float>();
+
+  const float* mean_data = mean.template const_data_ptr<float>();
+  const float* var_data = var.template const_data_ptr<float>();
+
+  auto oSizes = qx.sizes();
+  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
+  Tensor qy = at::empty(
+      oSizes,
+      at::device(kCPU)
+          .dtype(output_dtype)
+          .memory_format(MemoryFormat::ChannelsLast));
+
+  compute_fused_params(
+      C,
+      weight_data,
+      bias_data,
+      mean_data,
+      var_data,
+      eps,
+      qx_scale,
+      output_scale,
+      alpha_data,
+      beta_data);
+  qbatch_norm_cpu_stub(
+      qx.device().type(),
+      N,
+      C,
+      H * W,
+      qx_zero_point,
+      output_zero_point,
+      qx_nhwc,
+      alpha,
+      beta,
+      qy);
+  return qy;
+}
+
 } // namespace
 
 Tensor quantized_batch_norm(
@@ -396,6 +477,7 @@ Tensor quantized_batch_norm(
       output_zero_point);
 }
 
+
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm"), TORCH_FN(q_batch_norm_impl<false>));
   m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm_relu"), TORCH_FN(q_batch_norm_impl<true>));
@@ -407,4 +489,8 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d_relu"), TORCH_FN(q_batch_norm3d_impl<true>));
 }
 
+TORCH_LIBRARY_IMPL(onednn, CPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("onednn::qbatch_norm2d"), TORCH_FN(int8_batch_norm2d_cpu_impl));
+}
+
 } // namespace at::native
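The implementation above folds the batch-norm statistics and the input/output quantization parameters into per-channel `alpha`/`beta` via `compute_fused_params` (pre-existing code, not shown in this diff), so the kernel only has to evaluate `y = alpha * (q - in_zero_point) + beta` per channel. Below is a hedged NumPy sketch of that folding, consistent with the scalar path of the kernel added later in this commit; the exact form inside `compute_fused_params` is an assumption.

```python
import numpy as np

def fold_bn_params(weight, bias, mean, var, eps, x_scale, out_scale):
    """Assumed folding: makes alpha * (q_u8 - x_zp) + beta equal
    batch_norm(dequant(q_u8)) / out_scale for each channel."""
    inv_std = 1.0 / np.sqrt(var + eps)
    alpha = x_scale * weight * inv_std / out_scale
    beta = (bias - mean * weight * inv_std) / out_scale
    return alpha.astype(np.float32), beta.astype(np.float32)

def qbatch_norm2d_ref(q_nhwc_u8, x_zp, alpha, beta, out_zp, out_uint8=True):
    """Reference for an NHWC uint8 block of shape (..., C), mirroring the
    kernel's y = alpha * (q - x_zp) + beta followed by optional requantization."""
    y = alpha * (q_nhwc_u8.astype(np.float32) - x_zp) + beta
    if out_uint8:
        # Requantize: round, add the output zero point, clamp to [0, 255].
        return np.clip(np.rint(y) + out_zp, 0, 255).astype(np.uint8)
    return y  # float path; out_scale must be 1 and out_zp 0 in that case
```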

aten/src/ATen/native/quantized/cpu/QuantizedOps.h

Lines changed: 12 additions & 0 deletions
@@ -227,6 +227,17 @@ using qbinary_eltwise_cpu_fn = void (*)(
     double /*output_scale*/,
     int64_t /*output_zero_point*/);
 
+using qbatch_norm_cpu_fn = void(*)(
+    int64_t /*N*/,
+    int64_t /*C*/,
+    int64_t /*H * W*/,
+    int64_t /*in_zero_point*/,
+    int64_t /*out_zero_point*/,
+    const Tensor& /*input*/,
+    const Tensor& /*a*/,
+    const Tensor& /*b*/,
+    Tensor& /*output*/);
+
 DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub)
 DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub)
 DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub)
@@ -266,5 +277,6 @@ DECLARE_DISPATCH(qprelu_fn, qprelu_stub)
 DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qmul_tensor_cpu_stub)
 DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_tensor_cpu_stub)
 DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_relu_tensor_cpu_stub)
+DECLARE_DISPATCH(qbatch_norm_cpu_fn, qbatch_norm_cpu_stub)
 
 } // namespace at::native

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 152 additions & 0 deletions
@@ -2521,6 +2521,157 @@ void q_batch_norm_kernel(
   });
 }
 
+template <typename T>
+void q_batch_norm_cpu_kernel_impl(
+    int64_t N,
+    int64_t C,
+    int64_t HxW,
+    int64_t in_zero_point,
+    int64_t out_zero_point,
+    const uint8_t* in_ptr,
+    const float* alpha_ptr,
+    const float* beta_ptr,
+    T* out_ptr) {
+
+  int q_min = 0;
+  int q_max = 255;
+  const int64_t outer_size = N * HxW;
+
+#if defined(CPU_CAPABILITY_AVX512)
+  constexpr int kVLen = 16;
+  static constexpr int num_vecs = sizeof(float) / sizeof(uint8_t);
+  auto in_zp_vec = _mm512_set1_ps((float)in_zero_point);
+  auto fake_scale = _mm512_set1_ps(1.0f);
+  auto scale_neg_zp_premul = _mm512_xor_ps(_mm512_set1_ps(-0.f), in_zp_vec);
+  auto out_zero_point_v = _mm512_set1_epi32((int)out_zero_point);
+  constexpr auto lanes = static_cast<int64_t>(num_vecs * kVLen);
+  __m512i v_q_max = _mm512_set1_epi32(q_max);
+  __m512i v_q_min = _mm512_set1_epi32(q_min);
+
+  auto load_convert_u8_to_f32_512bit = [&](const uint8_t* src, __m512* dst) {
+    // Step 1: Load 512 bits
+    __m512i raw = _mm512_loadu_si512(src);
+
+    // Step 2: Extract two 256-bit chunks
+    __m256i v0 = _mm512_extracti64x4_epi64(raw, 0); // bytes 0–31
+    __m256i v1 = _mm512_extracti64x4_epi64(raw, 1); // bytes 32–63
+
+    // Step 3: Process each 256-bit chunk
+    // --- Expand uint8_t -> uint16_t ---
+    __m256i u16lo0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 0));
+    __m256i u16hi0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 1));
+    __m256i u16lo1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 0));
+    __m256i u16hi1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 1));
+    // --- Expand to uint32_t and convert to float ---
+    dst[0] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo0));
+    dst[1] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi0));
+    dst[2] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo1));
+    dst[3] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi1));
+  };
+
+  auto load_convert_u8_to_f32_128bit = [&](const uint8_t* src) {
+    // --- Load and expand uint8_t -> uint16_t ---
+    __m256i v_u16 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)src));
+    // --- Expand to uint32_t and convert to float ---
+    return _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(v_u16));
+  };
+
+  auto store_output = [&](__m512 out, T* out_addr) {
+    if constexpr (std::is_same<T, float>::value) {
+      _mm512_storeu_ps(out_addr, out);
+    } else if constexpr (std::is_same<T, at::BFloat16>::value) {
+      __m256i out_bf16 = cvtfp32_bf16(out);
+      _mm256_storeu_si256((__m256i*)out_addr, out_bf16);
+    } else if constexpr (std::is_same<T, at::Half>::value) {
+      __m256i out_f16 = cvtfp32_fp16(out);
+      _mm256_storeu_si256((__m256i*)out_addr, out_f16);
+    } else { // T == uint8, requantization needed
+      __m512i out_i32 = _mm512_cvtps_epi32(out);
+      out_i32 = _mm512_add_epi32(out_i32, out_zero_point_v);
+      out_i32 = _mm512_min_epi32(out_i32, v_q_max);
+      out_i32 = _mm512_max_epi32(out_i32, v_q_min);
+      __m128i out_i8 = _mm512_cvtepi32_epi8(out_i32);
+      _mm_storeu_si128((__m128i*)out_addr, out_i8);
+    }
+  };
+#endif
+
+  at::parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
+    for (const auto i : c10::irange(begin, end)) {
+      auto* X_ptr = in_ptr + i * C;
+      auto* Y_ptr = out_ptr + i * C;
+      int64_t ch = 0;
+
+#if defined(CPU_CAPABILITY_AVX512)
+      __m512 vals_dq[num_vecs];
+      for(; ch + lanes <= C; ch += lanes) {
+        // load 64 values of input then dequantize them
+        load_convert_u8_to_f32_512bit(X_ptr + ch, vals_dq);
+        for (const auto idx : c10::irange(num_vecs)) {
+          vals_dq[idx] = _mm512_fmadd_ps(fake_scale, vals_dq[idx], scale_neg_zp_premul);
+          auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
+          auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
+          vals_dq[idx] = _mm512_fmadd_ps(alpha_v, vals_dq[idx], beta_v);
+          store_output(vals_dq[idx], Y_ptr + ch + idx * kVLen);
+        }
+      }
+
+      // for channel between 16 and 64
+      int64_t elem_size = C - ch;
+      if (elem_size >= kVLen) {
+        int64_t vec_num = elem_size / kVLen;
+        for (const auto idx : c10::irange(vec_num)) {
+          __m512 val_dq = load_convert_u8_to_f32_128bit(X_ptr + ch + idx * kVLen);
+          val_dq = _mm512_fmadd_ps(fake_scale, val_dq, scale_neg_zp_premul);
+          auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
+          auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
+          val_dq = _mm512_fmadd_ps(alpha_v, val_dq, beta_v);
+          store_output(val_dq, Y_ptr + ch + idx * kVLen);
+        }
+        ch += vec_num * kVLen;
+      }
+#endif
+      // for channels less than 16
+      for (; ch < C; ++ch) {
+        float y_val_f = alpha_ptr[ch] * (X_ptr[ch] - in_zero_point) +
+            beta_ptr[ch];
+        if constexpr (std::is_same<T, float>::value) {
+          Y_ptr[ch] = y_val_f;
+        } else if constexpr (std::is_same<T, at::BFloat16>::value) {
+          Y_ptr[ch] = (at::BFloat16)y_val_f;
+        } else if constexpr (std::is_same<T, at::Half>::value) {
+          Y_ptr[ch] = (at::Half)y_val_f;
+        } else { // T == uint8, requantization needed
+          long quantized_down = out_zero_point + lrintf(y_val_f);
+          Y_ptr[ch] = std::min<long>(
+              std::max<long>(quantized_down, q_min), q_max);
+        }
+      }
+    }
+  });
+}
+
+void q_batch_norm_cpu_kernel(
+    int64_t N,
+    int64_t C,
+    int64_t HxW,
+    int64_t in_zero_point,
+    int64_t out_zero_point,
+    const Tensor& input,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& output) {
+  auto in_ptr = input.const_data_ptr<uint8_t>();
+  float* alpha_ptr = a.data_ptr<float>();
+  float* beta_ptr = b.data_ptr<float>();
+  AT_DISPATCH_FLOATING_TYPES_AND3(
+    at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Byte, output.scalar_type(), "int8_batch_norm2d_cpu", [&] {
+      auto out_ptr = output.data_ptr<scalar_t>();
+      q_batch_norm_cpu_kernel_impl<scalar_t>(
+          N, C, HxW, in_zero_point, out_zero_point, in_ptr, alpha_ptr, beta_ptr, out_ptr);
+  });
+}
+
 void _fake_quantize_tensor_helper(
     Tensor& output,
     Tensor& mask,
@@ -4587,5 +4738,6 @@ REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel)
 ALSO_REGISTER_AVX512_DISPATCH(qmul_tensor_cpu_stub, &qmul_tensor_cpu_kernel)
 ALSO_REGISTER_AVX512_DISPATCH(qadd_tensor_cpu_stub, &qadd_tensor_cpu_kernel<false>)
 ALSO_REGISTER_AVX512_DISPATCH(qadd_relu_tensor_cpu_stub, &qadd_tensor_cpu_kernel<true>)
+ALSO_REGISTER_AVX512_DISPATCH(qbatch_norm_cpu_stub, &q_batch_norm_cpu_kernel)
 } // namespace at::native
 // NOLINTEND(*-c-arrays)
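The AVX512 path above works row by row over the N*H*W positions of a channels-last tensor: each 512-bit load pulls 64 uint8 channel values, widens them to four vectors of 16 floats, applies the fused `alpha`/`beta` affine with FMA, and stores either floats (fp32/bf16/fp16) or requantized uint8; leftovers are handled 16 channels at a time and then scalar. A structural Python sketch of that channel blocking follows (an illustration of the loop shape only, not the intrinsics; the fp16/bf16 store variants are omitted):

```python
import numpy as np

def bn_row_blocked(x_row_u8, alpha, beta, in_zp, out_zp, out_uint8):
    """Process one (n, h*w) row of C uint8 values the way the kernel tiles it:
    64-channel blocks, then 16-channel blocks, then a scalar tail."""
    C = x_row_u8.shape[0]
    y = np.empty(C, dtype=np.float32)
    ch = 0
    while ch + 64 <= C:          # one 512-bit uint8 load -> 4 x 16 floats
        blk = x_row_u8[ch:ch + 64].astype(np.float32)
        y[ch:ch + 64] = alpha[ch:ch + 64] * (blk - in_zp) + beta[ch:ch + 64]
        ch += 64
    while ch + 16 <= C:          # one 128-bit uint8 load -> 16 floats
        blk = x_row_u8[ch:ch + 16].astype(np.float32)
        y[ch:ch + 16] = alpha[ch:ch + 16] * (blk - in_zp) + beta[ch:ch + 16]
        ch += 16
    for c in range(ch, C):       # scalar tail for the remaining channels
        y[c] = alpha[c] * (float(x_row_u8[c]) - in_zp) + beta[c]
    if out_uint8:                # requantize: round, add zero point, clamp
        return np.clip(np.rint(y) + out_zp, 0, 255).astype(np.uint8)
    return y
```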

aten/src/ATen/native/quantized/library.cpp

Lines changed: 2 additions & 0 deletions
@@ -283,4 +283,6 @@ TORCH_LIBRARY(onednn, m) {
   // int8 add
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qadd.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qadd_relu.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
+  // int8 batch_norm2d
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qbatch_norm2d(Tensor qx, float qx_scale, int qx_zero_point, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
 }

test/quantization/core/test_quantized_op.py

Lines changed: 37 additions & 0 deletions
@@ -3201,6 +3201,43 @@ def test_int8_add_onednn(self, relu_fused):
         c = torch.ops.onednn.qadd.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
         self.assertEqual(c, c_ref)
 
+    @skipIfNoONEDNN
+    def test_int8_batch_norm_onednn(self):
+        # hypothesis too slow for this test, create test cases manually
+        channel_len_list = (8, 64, 100, 120, 128)
+        output_dtype_list = [torch.uint8, torch.float, torch.bfloat16, torch.half]
+        x_scale, x_zero_point = 0.1, 1
+        cases = itertools.product(channel_len_list, output_dtype_list)
+        for channels, out_dtype in cases:
+            shapes = [8, channels, 8, 8]
+            y_scale, y_zero_point = (0.2, 2) if out_dtype == torch.uint8 else (1, 0)
+
+            x = torch.randn(shapes, dtype=torch.float32)
+            mean = torch.rand(channels).float()
+            var = torch.rand(channels).float()
+            weight = torch.rand(channels).float()
+            bias = torch.rand(channels).float()
+            eps = 0.001
+            qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+                x, x_scale, x_zero_point, 0, 255, torch.uint8
+            )
+            y = torch.ops.onednn.qbatch_norm2d(
+                qx, x_scale, x_zero_point, weight, bias, mean, var, eps, y_scale, y_zero_point, out_dtype
+            )
+
+            dqx = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                qx, x_scale, x_zero_point, 0, 255, torch.uint8
+            )
+            y_ref = F.batch_norm(dqx, weight=weight, bias=bias,
+                                 running_mean=mean, running_var=var, training=False,
+                                 momentum=0, eps=eps)
+            if out_dtype == torch.uint8:
+                y_ref = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+                    y_ref, y_scale, y_zero_point, 0, 255, torch.uint8
+                )
+            y_ref = y_ref.to(out_dtype)
+            self.assertEqual(y, y_ref, msg=f"{y} vs {y_ref}")
+
 
 class TestDynamicQuantizedOps(TestCase):
     """Tests the correctness of the dynamic quantized linear and linear_relu op."""
