@@ -2521,6 +2521,157 @@ void q_batch_norm_kernel(
});
}

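+ // Per-channel affine transform on quantized uint8 input:
+ //   y[c] = alpha[c] * (x[c] - in_zero_point) + beta[c]
+ // For uint8 output, y is requantized by adding out_zero_point and clamping to
+ // [0, 255]. alpha/beta are assumed to be precomputed per channel by the caller
+ // (batch-norm statistics and quantization scales folded in).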
+ template <typename T>
+ void q_batch_norm_cpu_kernel_impl(
+     int64_t N,
+     int64_t C,
+     int64_t HxW,
+     int64_t in_zero_point,
+     int64_t out_zero_point,
+     const uint8_t* in_ptr,
+     const float* alpha_ptr,
+     const float* beta_ptr,
+     T* out_ptr) {
+
+   int q_min = 0;
+   int q_max = 255;
+   const int64_t outer_size = N * HxW;
+
+ #if defined(CPU_CAPABILITY_AVX512)
+   constexpr int kVLen = 16;
+   static constexpr int num_vecs = sizeof(float) / sizeof(uint8_t);
+   auto in_zp_vec = _mm512_set1_ps((float)in_zero_point);
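+   // "Fake" dequantization parameters: fake_scale is 1.0f and
+   // scale_neg_zp_premul is -in_zero_point (sign bit flipped via XOR with -0.f),
+   // so the fmadd in the loops below computes x - in_zero_point. The real input
+   // scale is assumed to be folded into alpha by the caller.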
+   auto fake_scale = _mm512_set1_ps(1.0f);
+   auto scale_neg_zp_premul = _mm512_xor_ps(_mm512_set1_ps(-0.f), in_zp_vec);
+   auto out_zero_point_v = _mm512_set1_epi32((int)out_zero_point);
+   constexpr auto lanes = static_cast<int64_t>(num_vecs * kVLen);
+   __m512i v_q_max = _mm512_set1_epi32(q_max);
+   __m512i v_q_min = _mm512_set1_epi32(q_min);
+
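+   // Widens 64 uint8 values into four __m512 vectors of 16 floats each
+   // (num_vecs * kVLen lanes per iteration of the main loop below).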
+   auto load_convert_u8_to_f32_512bit = [&](const uint8_t* src, __m512* dst) {
+     // Step 1: Load 512 bits
+     __m512i raw = _mm512_loadu_si512(src);
+
+     // Step 2: Extract two 256-bit chunks
+     __m256i v0 = _mm512_extracti64x4_epi64(raw, 0); // bytes 0–31
+     __m256i v1 = _mm512_extracti64x4_epi64(raw, 1); // bytes 32–63
+
+     // Step 3: Process each 256-bit chunk
+     // --- Expand uint8_t -> uint16_t ---
+     __m256i u16lo0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 0));
+     __m256i u16hi0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 1));
+     __m256i u16lo1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 0));
+     __m256i u16hi1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 1));
+     // --- Expand to uint32_t and convert to float ---
+     dst[0] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo0));
+     dst[1] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi0));
+     dst[2] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo1));
+     dst[3] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi1));
+   };
+
+   auto load_convert_u8_to_f32_128bit = [&](const uint8_t* src) {
+     // --- Load and expand uint8_t -> uint16_t ---
+     __m256i v_u16 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)src));
+     // --- Expand to uint32_t and convert to float ---
+     return _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(v_u16));
+   };
+
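+   // Stores one vector of 16 results: directly for float, narrowed for
+   // bf16/fp16, or requantized for uint8 (add out_zero_point, clamp to
+   // [q_min, q_max], narrow to 8 bits).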
+   auto store_output = [&](__m512 out, T* out_addr) {
+     if constexpr (std::is_same<T, float>::value) {
+       _mm512_storeu_ps(out_addr, out);
+     } else if constexpr (std::is_same<T, at::BFloat16>::value) {
+       __m256i out_bf16 = cvtfp32_bf16(out);
+       _mm256_storeu_si256((__m256i*)out_addr, out_bf16);
+     } else if constexpr (std::is_same<T, at::Half>::value) {
+       __m256i out_f16 = cvtfp32_fp16(out);
+       _mm256_storeu_si256((__m256i*)out_addr, out_f16);
+     } else { // T == uint8, requantization needed
+       __m512i out_i32 = _mm512_cvtps_epi32(out);
+       out_i32 = _mm512_add_epi32(out_i32, out_zero_point_v);
+       out_i32 = _mm512_min_epi32(out_i32, v_q_max);
+       out_i32 = _mm512_max_epi32(out_i32, v_q_min);
+       __m128i out_i8 = _mm512_cvtepi32_epi8(out_i32);
+       _mm_storeu_si128((__m128i*)out_addr, out_i8);
+     }
+   };
+ #endif
+
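+   // Parallelize over the N * HxW outer dimension; each row holds C contiguous
+   // channel values (assumes channels-last memory layout), so per-channel
+   // alpha/beta loads are contiguous.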
+   at::parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
+     for (const auto i : c10::irange(begin, end)) {
+       auto* X_ptr = in_ptr + i * C;
+       auto* Y_ptr = out_ptr + i * C;
+       int64_t ch = 0;
+
+ #if defined(CPU_CAPABILITY_AVX512)
+       __m512 vals_dq[num_vecs];
+       for (; ch + lanes <= C; ch += lanes) {
+         // load 64 values of input then dequantize them
+         load_convert_u8_to_f32_512bit(X_ptr + ch, vals_dq);
+         for (const auto idx : c10::irange(num_vecs)) {
+           vals_dq[idx] = _mm512_fmadd_ps(fake_scale, vals_dq[idx], scale_neg_zp_premul);
+           auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
+           auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
+           vals_dq[idx] = _mm512_fmadd_ps(alpha_v, vals_dq[idx], beta_v);
+           store_output(vals_dq[idx], Y_ptr + ch + idx * kVLen);
+         }
+       }
+
+       // for remaining channels between 16 and 63
+       int64_t elem_size = C - ch;
+       if (elem_size >= kVLen) {
+         int64_t vec_num = elem_size / kVLen;
+         for (const auto idx : c10::irange(vec_num)) {
+           __m512 val_dq = load_convert_u8_to_f32_128bit(X_ptr + ch + idx * kVLen);
+           val_dq = _mm512_fmadd_ps(fake_scale, val_dq, scale_neg_zp_premul);
+           auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
+           auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
+           val_dq = _mm512_fmadd_ps(alpha_v, val_dq, beta_v);
+           store_output(val_dq, Y_ptr + ch + idx * kVLen);
+         }
+         ch += vec_num * kVLen;
+       }
+ #endif
+       // for channels less than 16
+       for (; ch < C; ++ch) {
+         float y_val_f = alpha_ptr[ch] * (X_ptr[ch] - in_zero_point) +
+             beta_ptr[ch];
+         if constexpr (std::is_same<T, float>::value) {
+           Y_ptr[ch] = y_val_f;
+         } else if constexpr (std::is_same<T, at::BFloat16>::value) {
+           Y_ptr[ch] = (at::BFloat16)y_val_f;
+         } else if constexpr (std::is_same<T, at::Half>::value) {
+           Y_ptr[ch] = (at::Half)y_val_f;
+         } else { // T == uint8, requantization needed
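+           // lrintf rounds to nearest (even) under the default FP environment,
+           // matching _mm512_cvtps_epi32 in the vectorized store_output path.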
+           long quantized_down = out_zero_point + lrintf(y_val_f);
+           Y_ptr[ch] = std::min<long>(
+               std::max<long>(quantized_down, q_min), q_max);
+         }
+       }
+     }
+   });
+ }
+
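+ // Dispatches on the output dtype (float, BFloat16, Half, or uint8/Byte);
+ // the input tensor is always uint8.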
+ void q_batch_norm_cpu_kernel(
+     int64_t N,
+     int64_t C,
+     int64_t HxW,
+     int64_t in_zero_point,
+     int64_t out_zero_point,
+     const Tensor& input,
+     const Tensor& a,
+     const Tensor& b,
+     Tensor& output) {
+   auto in_ptr = input.const_data_ptr<uint8_t>();
+   float* alpha_ptr = a.data_ptr<float>();
+   float* beta_ptr = b.data_ptr<float>();
+   AT_DISPATCH_FLOATING_TYPES_AND3(
+       at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Byte,
+       output.scalar_type(), "int8_batch_norm2d_cpu", [&] {
+         auto out_ptr = output.data_ptr<scalar_t>();
+         q_batch_norm_cpu_kernel_impl<scalar_t>(
+             N, C, HxW, in_zero_point, out_zero_point, in_ptr, alpha_ptr, beta_ptr, out_ptr);
+       });
+ }
+
void _fake_quantize_tensor_helper(
    Tensor& output,
    Tensor& mask,
@@ -4587,5 +4738,6 @@ REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel)
ALSO_REGISTER_AVX512_DISPATCH(qmul_tensor_cpu_stub, &qmul_tensor_cpu_kernel)
ALSO_REGISTER_AVX512_DISPATCH(qadd_tensor_cpu_stub, &qadd_tensor_cpu_kernel<false>)
ALSO_REGISTER_AVX512_DISPATCH(qadd_relu_tensor_cpu_stub, &qadd_tensor_cpu_kernel<true>)
+ ALSO_REGISTER_AVX512_DISPATCH(qbatch_norm_cpu_stub, &q_batch_norm_cpu_kernel)
} // namespace at::native
// NOLINTEND(*-c-arrays)