| 1 | +#include <ATen/Tensor.h> |
| 2 | +#include <ATen/core/Tensor.h> |
| 3 | +#include <c10/core/ScalarType.h> |
| 4 | + |
| 5 | +#include <ATen/native/mkldnn/xpu/detail/Attr.h> |
| 6 | +#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h> |
| 7 | + |
| 8 | +#include <oneapi/dnnl/dnnl.hpp> |
| 9 | + |
| 10 | +namespace at::native::onednn { |
| 11 | + |
| 12 | +at::Tensor broadcast_bias2D( |
| 13 | + at::Tensor& dst, |
| 14 | + at::Tensor& bias, |
| 15 | + int64_t m, |
| 16 | + int64_t n) { |
| 17 | + switch (bias.dim()) { |
| 18 | + case 1: |
| 19 | + TORCH_CHECK( |
| 20 | + bias.size(0) == n || bias.size(0) == 1, |
| 21 | + "matmul supports [n] or [1] when bias dim is 1, but b.size() is:", |
| 22 | + bias.size(0)); |
| 23 | + break; |
| 24 | + case 2: |
| 25 | + if ((bias.size(0) == m && bias.size(1) == n) || |
| 26 | + (bias.size(0) == m && bias.size(1) == 1) || |
| 27 | +          (bias.size(0) == 1 && bias.size(1) == n)) |
| 28 | + return bias; // No need to broadcast |
| 29 | + TORCH_CHECK( |
| 30 | + bias.size(0) == 1 && bias.size(1) == 1, |
| 31 | + "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...") |
| 32 | + break; |
| 33 | + case 0: |
| 34 | + TORCH_CHECK( |
| 35 | + bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ..."); |
| 36 | + break; |
| 37 | + default: |
| 38 | + TORCH_CHECK(0, "unsupported bias dim in matmul ..."); |
| 39 | + } |
| 40 | + bias = bias.expand({1, n}).contiguous(); |
| 41 | + return bias; |
| 42 | +} |
| 43 | + |
| 44 | +at::Tensor broadcast_bias3D( |
| 45 | + at::Tensor& dst, |
| 46 | + at::Tensor bias, |
| 47 | + int64_t mb, |
| 48 | + int64_t m, |
| 49 | + int64_t n) { |
| 50 | + switch (bias.dim()) { |
| 51 | + case 1: |
| 52 | + TORCH_CHECK( |
| 53 | + bias.size(0) == n || bias.size(0) == 1, |
| 54 | + "matmul supports [n] or [1] when bias dim is 1, but b.size() is:", |
| 55 | + bias.size(0)); |
| 56 | + break; |
| 57 | + case 3: |
| 58 | + TORCH_CHECK( |
| 59 | + are_expandable({mb, m, n}, bias.sizes()), |
| 60 | + "matmul bias must be expandable to:", |
| 61 | + dst.sizes(), |
| 62 | + " but got:", |
| 63 | + bias.sizes()); |
| 64 | + break; |
| 65 | + case 0: |
| 66 | + TORCH_CHECK( |
| 67 | + bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ..."); |
| 68 | + break; |
| 69 | + default: |
| 70 | + TORCH_CHECK(0, "unsupported bias dim in matmul ..."); |
| 71 | + } |
| 72 | + bias = bias.expand({mb, m, n}).contiguous(); |
| 73 | + return bias; |
| 74 | +} |
| 75 | + |
| 76 | +at::Tensor broadcast_bias( |
| 77 | + at::Tensor& dst, |
| 78 | + at::Tensor bias, |
| 79 | + int64_t mb, |
| 80 | + int64_t m, |
| 81 | + int64_t n) { |
| 82 | + if (dst.dim() == 2) { |
| 83 | + return broadcast_bias2D(dst, bias, m, n); |
| 84 | + } else { |
| 85 | + return broadcast_bias3D(dst, bias, mb, m, n); |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +void quantized_matmul( |
| 90 | + at::Tensor mat1, // act |
| 91 | + double input_scale, |
| 92 | + int64_t input_zero_point, |
| 93 | + at::Tensor mat2, // weight |
| 94 | + at::Tensor& weight_scales, |
| 95 | + at::Tensor& weight_zero_points, |
| 96 | + at::Tensor& bias, |
| 97 | + at::Tensor result, // output |
| 98 | + double output_scale, |
| 99 | + int64_t output_zero_point, |
| 100 | + std::optional<c10::ScalarType> output_dtype, |
| 101 | + std::optional<at::Tensor> other, // extra input for binary-post-op |
| 102 | + double other_scale, |
| 103 | + int64_t other_zero_point, |
| 104 | + const c10::string_view& binary_post_op, |
| 105 | + double binary_alpha, |
| 106 | + const c10::string_view& unary_post_op, |
| 107 | + torch::List<std::optional<at::Scalar>>& unary_post_op_args, |
| 108 | + c10::string_view unary_post_op_algorithm, |
| 109 | + bool m2_trans) { |
| 110 | + // [Note] Quantized Matrix Multiplication at XPU |
| 111 | + // The following code integrates oneDNN quantized gemm. The quantization |
| 112 | + // config we support: |
| 113 | + // activation: s8&u8; per tensor calibrated; symmetric&asymmetric |
| 114 | + // weight: s8; per_tensor/per_channel calibrated; symmetric |
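| | +  // The Attr is seeded with the destination requantization parameters: the |
| | +  // reciprocal of the PyTorch output scale plus the output zero point. The |
| | +  // requested binary/unary post-ops are appended onto it below. |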
| 115 | + auto attr = Attr(1.0 / output_scale, output_zero_point); |
| 116 | + construct_attr_by_post_op( |
| 117 | + binary_post_op, |
| 118 | + binary_alpha, |
| 119 | + input_scale, |
| 120 | + input_zero_point, |
| 121 | + unary_post_op, |
| 122 | + unary_post_op_args, |
| 123 | + unary_post_op_algorithm, |
| 124 | + attr); |
| 125 | + |
| 126 | + size_t dims = result.dim(); |
| 127 | + at::Device cur_device = at::Device(at::kXPU, c10::xpu::current_device()); |
| 128 | + auto engine = GpuEngineManager::Instance().get_engine(cur_device); |
| 129 | + auto stream = GpuStreamManager::Instance().get_stream(); |
| 130 | + |
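| | +  // oneDNN matmul accepts only certain stride layouts; fall back to dense |
| | +  // contiguous copies when the given strides cannot be used directly. |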
| 131 | + at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous(); |
| 132 | + at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous(); |
| 133 | + at::Tensor dst = |
| 134 | + is_onednn_matmul_strides(result, true) ? result : result.contiguous(); |
| 135 | + |
| 136 | + int64_t m = dst.size(-2); |
| 137 | + int64_t n = dst.size(-1); |
| 138 | <
10000
td data-grid-cell-id="diff-47972c1ed4d0fe0a5cd26a6d22cf4a5087a3d5621d1ad55f149c5c552dfc1d4e-empty-138-2" data-line-anchor="diff-47972c1ed4d0fe0a5cd26a6d22cf4a5087a3d5621d1ad55f149c5c552dfc1d4eR138" data-selected="false" role="gridcell" style="background-color:var(--diffBlob-additionLine-bgColor, var(--diffBlob-addition-bgColor-line));padding-right:24px" tabindex="-1" valign="top" class="focusable-grid-cell diff-text-cell right-side-diff-cell left-side">+ int64_t k = m1.size(-1);
| 139 | + int64_t mb = 1; |
| 140 | + |
| 141 | + if (dims == 3) { |
| 142 | + mb = dst.size(0); |
| 143 | + TORCH_CHECK( |
| 144 | + mb == m1.size(0) && mb == m2.size(0), |
| 145 | + "batch size mismatch, dst mb: ", |
| 146 | + mb, |
| 147 | + "m1 mb", |
| 148 | + m1.size(0), |
| 149 | + " m2 mb: ", |
| 150 | + m2.size(0)); |
| 151 | + } |
| 152 | + |
| 153 | + bool with_bias = false; |
| 154 | + at::Tensor b = bias; |
| 155 | +  if (b.defined()) { |
| 156 | +    with_bias = true; |
| 157 | +    b = broadcast_bias(dst, b, mb, m, n); |
| 158 | +    // Keep the broadcast bias contiguous so oneDNN does not reorder it twice. |
| 159 | +    b = b.contiguous(); |
| 160 | +  } |
| 161 | + |
| 162 | + auto m1_usr_dt = get_onednn_dtype(m1); |
| 163 | + auto m2_usr_dt = get_onednn_dtype(m2); |
| 164 | + auto dst_usr_dt = get_onednn_dtype(dst); |
| 165 | + |
| 166 | + auto m1_dt = m1_usr_dt; |
| 167 | + auto m2_dt = m2_usr_dt; |
| 168 | + auto dst_dt = dst_usr_dt; |
| 169 | + dnnl::memory::data_type bias_dt; |
| 170 | + |
| 171 | + dnnl::memory::desc m1_md, m1_usr_md; |
| 172 | + dnnl::memory::desc m2_md, m2_usr_md; |
| 173 | + dnnl::memory::desc dst_md, dst_usr_md; |
| 174 | + dnnl::memory::desc b_md; |
| 175 | + |
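| | +  // Describe logical shapes and strides for src/weights/dst. When the weight |
| | +  // is not pre-transposed, swap its innermost strides so oneDNN still sees a |
| | +  // k x n (or mb x k x n) view of the same buffer. |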
| 176 | + dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims; |
| 177 | + dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides; |
| 178 | + if (dims == 2) { |
| 179 | + m1_dims = {m, k}; |
| 180 | + m2_dims = {k, n}; // (n, 1) (1, n) |
| 181 | + dst_dims = {m, n}; |
| 182 | + |
| 183 | + m1_strides = {m1.stride(0), m1.stride(1)}; |
| 184 | + if (m2_trans) { |
| 185 | + m2_strides = {m2.stride(0), m2.stride(1)}; |
| 186 | + } else { |
| 187 | + m2_strides = {m2.stride(1), m2.stride(0)}; |
| 188 | + } |
| 189 | + dst_strides = {dst.stride(0), dst.stride(1)}; |
| 190 | + } else { |
| 191 | + m1_dims = {mb, m, k}; |
| 192 | + m2_dims = {mb, k, n}; |
| 193 | + dst_dims = {mb, m, n}; |
| 194 | + |
| 195 | + m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)}; |
| 196 | + if (m2_trans) { |
| 197 | + m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)}; |
| 198 | + } else { |
| 199 | + m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)}; |
| 200 | + } |
| 201 | + dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)}; |
| 202 | + } |
| 203 | + |
| 204 | + if (with_bias) { |
| 205 | + bias_dims = get_onednn_dims(b); |
| 206 | + bias_dt = get_onednn_dtype(b); |
| 207 | + bias_strides = get_onednn_strides(b); |
| 208 | + } |
| 209 | + |
| 210 | + std::unordered_map<int, dnnl::memory> args; |
| 211 | + |
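| | +  // Extract the fused post-ops; activation/weight scales and zero points are |
| | +  // not baked into the primitive but bound as runtime arguments below. |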
| 212 | + dnnl::post_ops po; |
| 213 | + po = attr.extract_post_ops( |
| 214 | + dst, |
| 215 | + true, |
| 216 | + dst.scalar_type() == at::kByte || dst.scalar_type() == at::kChar); |
| 217 | + bool m1_need_zp = (input_zero_point != 0); |
| 218 | + bool wgh_is_per_channel = weight_scales.numel() > 1; |
| 219 | + |
| 220 | + dnnl::matmul matmul_p; |
| 221 | + dnnl::matmul::primitive_desc matmul_pd; |
| 222 | + |
| 223 | + m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides); |
| 224 | + m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides); |
| 225 | + dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides); |
| 226 | + dnnl::primitive_attr pattr; |
| 227 | + pattr.set_post_ops(po); |
| 228 | + pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); |
| 229 | + |
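| | +  // Weight scale mask: 0 means a single per-tensor scale, (1 << 1) means one |
| | +  // scale per element along dim 1 (the n/output-channel dim of a k x n weight). |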
| 230 | + at::Tensor m2_sc; |
| 231 | + if (!wgh_is_per_channel) { |
| 232 | + pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); |
| 233 | + } else { |
| 234 | + pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1); |
| 235 | + } |
| 236 | + |
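| | +  // The activation uses a single per-tensor scale (mask 0) and, when its zero |
| | +  // point is non-zero, a matching runtime zero point. |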
| 237 | + at::Tensor m1_sc; |
| 238 | + dnnl::memory::desc m1_sc_md = dnnl::memory::desc( |
| 239 | + {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x); |
| 240 | + int mask_ac = 0; |
| 241 | + pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac); |
| 242 | + if (m1_need_zp) { |
| 243 | + pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac); |
| 244 | + } |
| 245 | + |
| 246 | + if (with_bias) { |
| 247 | + b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides); |
| 248 | + matmul_pd = |
| 249 | + dnnl::matmul::primitive_desc(engine, m1_md, m2_md, b_md, dst_md, pattr); |
| 250 | + } else { |
| 251 | + matmul_pd = |
| 252 | + dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr); |
| 253 | + } |
| 254 | + |
| 255 | + matmul_p = dnnl::matmul(matmul_pd); |
| 256 | + |
| 257 | + m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides); |
| 258 | + m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides); |
| 259 | + dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides); |
| 260 | + |
| 261 | + auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr()); |
| 262 | + auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr()); |
| 263 | + auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr()); |
| 264 | + |
| 265 | + auto expected_m1_md = matmul_pd.src_desc(); |
| 266 | + auto expected_m2_md = matmul_pd.weights_desc(); |
| 267 | + auto expected_dst_md = matmul_pd.dst_desc(); |
| 268 | + |
| 269 | + dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m; |
| 270 | + at::Tensor m1_, m2_, dst_; |
| 271 | + |
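| | +  // Scratchpad is user-managed (see scratchpad_mode::user above): allocate a |
| | +  // byte tensor of the requested size and pass it via DNNL_ARG_SCRATCHPAD. |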
| 272 | +  size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size(); |
| 273 | +  at::Tensor scratchpad_tensor = at::empty( |
| 274 | +      {static_cast<int64_t>(scratchpad_size)}, m1.options().dtype(at::kByte), c10::nullopt); |
| 275 | + auto scratchpad_memory = make_onednn_memory( |
| 276 | + matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr()); |
| 277 | + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory}); |
| 278 | + |
| 279 | + if (attr.with_binary()) |
| 280 | + attr.construct_post_binary(matmul_pd, args); |
| 281 | + |
| 282 | + args.insert({DNNL_ARG_SRC, m1_m}); |
| 283 | + args.insert({DNNL_ARG_WEIGHTS, m2_m}); |
| 284 | + args.insert({DNNL_ARG_DST, dst_m}); |
| 285 | + if (b.defined()) { |
| 286 | + auto b_m = make_onednn_memory(b_md, engine, b.data_ptr()); |
| 287 | + args.insert({DNNL_ARG_BIAS, b_m}); |
| 288 | + } |
| 289 | + |
| 290 | +  // Bind weight scales as a runtime f32 memory argument |
| 291 | + weight_scales = weight_scales.to(at::kFloat); |
| 292 | + dnnl::memory m2_sc_m, m2_zp_m; |
| 293 | + dnnl::memory::desc m2_sc_md = dnnl::memory::desc( |
| 294 | + get_onednn_dims(weight_scales), |
| 295 | + dnnl::memory::data_type::f32, |
| 296 | + dnnl::memory::format_tag::x); |
| 297 | + m2_sc_m = make_onednn_memory(m2_sc_md, engine, weight_scales.data_ptr()); |
| 298 | + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, m2_sc_m}); |
| 299 | + |
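| | +  // The activation scale/zero point are host scalars; wrap them in small |
| | +  // oneDNN memories so they can be bound as runtime arguments. |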
| 300 | + dnnl::memory m1_sc_m, m1_zp_m; |
| 301 | + Tensor m1_sc_tensor, m1_zp_tensor; |
| 302 | + m1_sc_m = dnnl_memory_from_host_scalar( |
| 303 | + static_cast<float>(input_scale), m1_sc_tensor, engine); |
| 304 | + m1_zp_m = dnnl_memory_from_host_scalar( |
| 305 | + static_cast<int32_t>(input_zero_point), m1_zp_tensor, engine); |
| 306 | + |
| 307 | + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc_m}); |
| 308 | + if (m1_need_zp) { |
| 309 | + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m}); |
| 310 | + } |
| 311 | + |
| 312 | + auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args); |
| 313 | + |
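| | +  // If a contiguous temporary was used as the destination, copy it back into |
| | +  // the caller-provided result tensor. |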
| 314 | + if (!dst.is_same(result)) |
| 315 | + result.copy_(dst); |
| 316 | +} |
| 317 | + |
| 318 | +} // namespace at::native::onednn |