8000 [Intel GPU] Enable mkdnn._linear_pointwise at XPU backend by ZhiweiYan-96 · Pull Request #140365 · pytorch/pytorch · GitHub

Closed · wants to merge 28 commits
145 changes: 145 additions & 0 deletions aten/src/ATen/native/mkldnn/xpu/FusionUtils.cpp
Collaborator
@ZhiweiYan-96, these literals like hardswish, silu, etc. should be defined as constant variables.

@ZhiweiYan-96 (Collaborator, Author) · May 26, 2025
Thanks for the reminder. I am investigating an elegant way to remove the verbose if-else branches; the final solution is not ready yet.
I tried a lambda function plus hash map (e.g. silu -> lambda func(silu)), but the argument list is large and the code becomes more verbose.
I also tried dispatching through templates (e.g. silu -> handle<silu_func>()), but hit syntax errors from ambiguous function calls rooted in the complex argument list with std::optional.
Please allow me more time to carefully refactor this.

@ZhiweiYan-96 (Collaborator, Author) · May 26, 2025
@EikanWang @guangyey @etaf
Could you please help review the design? The overall logic is:

  1. Choose a handler function according to category; there are three categories (no argument / needs scalars / needs algorithm).
  2. Inside each handler, map the attr string to an attribute-insertion function.

Pros:

  1. No more verbose if-else branches; cleaner code.
  2. Users of this code can add new attribute-insertion functions to the unordered_map following the current pattern, without touching any if-else logic, so extensibility is better (see the sketch after this list).
  3. Classifying the attributes into three categories reduces the complexity of argument passing in the std::function signatures stored in the unordered_map.

Cons:

  1. Selecting the category still needs an if-else.
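To make pro 2 concrete, here is a minimal, self-contained sketch of the pattern (compilable on its own). `MiniAttr` is an illustrative stand-in for `onednn::Attr`, and the `mish` entry is a hypothetical extension showing that adding an op is one new map entry:

```cpp
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string_view>
#include <unordered_map>

// Illustrative stand-in for onednn::Attr; only the call shape matters here.
struct MiniAttr {
  MiniAttr& append_post_eltwise(float scale, float alpha, float beta,
                                std::string_view kind) {
    std::cout << "post-op " << kind << " scale=" << scale << " alpha=" << alpha
              << " beta=" << beta << '\n';
    return *this;
  }
};

// The "argument-less" category: each entry appends its post-op with fixed
// constants, so every map value shares one signature.
MiniAttr& handle_argument_less(std::string_view unary, MiniAttr& attr) {
  using Handler = std::function<MiniAttr&(MiniAttr&)>;
  static const std::unordered_map<std::string_view, Handler> unary_map = {
      {"relu",
       [](MiniAttr& a) -> MiniAttr& {
         return a.append_post_eltwise(1.0f, 0.0f, 0.0f, "relu");
       }},
      // Extending is one new entry; no if-else chain to modify.
      {"mish",
       [](MiniAttr& a) -> MiniAttr& {
         return a.append_post_eltwise(1.0f, 0.0f, 0.0f, "mish");
       }},
  };
  auto it = unary_map.find(unary);
  if (it == unary_map.end()) {
    throw std::invalid_argument("unsupported unary post-op");
  }
  return it->second(attr);
}

int main() {
  MiniAttr attr;
  handle_argument_less("mish", attr);  // prints: post-op mish scale=1 ...
}
```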

@@ -0,0 +1,145 @@
#include <ATen/native/mkldnn/xpu/FusionUtils.h>

using namespace at::native::onednn;

namespace at::native::xpu {

onednn::Attr& handle_argument_less(std::string_view unary, onednn::Attr& attr) {
static const std::unordered_map<
std::string_view,
std::function<onednn::Attr&(onednn::Attr&)>>
unary_map = {
{"relu",
[](onednn::Attr& attr) -> onednn::Attr& {
return attr.append_post_eltwise(
1.0f, 0.0f, 0.0f, attr.kind_with_relu);
}},
{"sigmoid",
[](onednn::Attr& attr) -> onednn::Attr& {
return attr.append_post_eltwise(
1.0f, 0.0f, 0.0f, attr.kind_with_sigmoid);
}},
{"tanh",
[](onednn::Attr& attr) -> onednn::Attr& {
return attr.append_post_eltwise(
1.0f, 0.0f, 0.0f, attr.kind_with_tanh);
}},
{"hardswish",
[](onednn::Attr& attr) -> onednn::Attr& {
return attr.append_post_eltwise(
1.0f, 1.0f / 6.0f, 1.0f / 2.0f, attr.kind_with_hardswish);
}},
{"swish",
[](onednn::Attr& attr) -> onednn::Attr& {
return attr.append_post_eltwise(
1.0f, 1.0f, 0.0f, attr.kind_with_swish);
}},
{"hardsigmoid",
[](onednn::Attr& attr) -> onednn::Attr& {
return attr.append_post_eltwise(
1.0f, 1.0f / 6.0f, 1.0f / 2.0f, attr.kind_with_hardsigmoid);
}},
{"none", [](onednn::Attr& attr) -> onednn::Attr& { return attr; }}};

if (unary_map.find(unary) != unary_map.end()) {
return unary_map.at(unary)(attr);
}
TORCH_CHECK(
false,
"Unary attr ",
unary,
" is not supported for conv/linear post unary fusion");
}

onednn::Attr& handle_need_scalars(
std::string_view unary,
onednn::Attr& attr,
torch::List<std::optional<at::Scalar>> scalars) {
static const std::unordered_map<
std::string_view,
std::function<onednn::Attr&(
onednn::Attr&, torch::List<std::optional<at::Scalar>>)>>
unary_map = {
{"leaky_relu",
[](onednn::Attr& attr,
torch::List<std::optional<at::Scalar>> scalars) -> onednn::Attr& {
auto alpha =
scalars[0].get().toOptional<at::Scalar>().value().to<float>();
return attr.append_post_eltwise(
1.0f, alpha, 0.f, attr.kind_with_relu);
}},
{"hardtanh",
[](onednn::Attr& attr,
torch::List<std::optional<at::Scalar>> scalars) -> onednn::Attr& {
auto alpha =
scalars[0].get().toOptional<at::Scalar>().value().to<float>();
auto beta =
scalars[1].get().toOptional<at::Scalar>().value().to<float>();
return attr.append_post_eltwise(
1.0f, alpha, beta, attr.kind_with_clip);
}}};

if (unary_map.find(unary) != unary_map.end()) {
return unary_map.at(unary)(attr, scalars);
}
TORCH_CHECK(
false,
"Unary attr ",
unary,
" is not supported for conv/linear post unary fusion");
}

onednn::Attr& handle_need_algorithm(
std::string_view unary,
onednn::Attr& attr,
std::optional<std::string_view> algorithm) {
TORCH_CHECK(
unary == "gelu",
"GELU is the only unary operation that requires an algorithm currently");
TORCH_CHECK(
algorithm.has_value(),
"GELU algorithm is not specified, please specify it as 'none' or 'tanh'");
dnnl::algorithm gelu_type;
if (algorithm.value() == "none") {
gelu_type = attr.kind_with_gelu_erf;
} else if (algorithm.value() == "tanh") {
gelu_type = attr.kind_with_gelu_tanh;
} else {
TORCH_CHECK(false, "Unsupported GELU algorithm: ", algorithm.value());
}
return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, gelu_type);
}

onednn::Attr& construct_unary_attr(
onednn::Attr& attr,
std::string_view unary,
torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
// Define sets for unary operations based on their argument requirements.
// Category `argument_less`: stateless operations
// Category `need_scalars`: require alpha/beta
// Category `need_algorithm`: require an algorithm specification; currently
// only gelu. If further unary operations are required, they can be added to
// these sets, or new sets can be added for new categories.
static const std::set<std::string_view> argument_less = {
"relu", "sigmoid", "tanh", "hardswish", "swish", "hardsigmoid"};
static const std::set<std::string_view> need_scalars = {
"leaky_relu", "hardtanh"};
static const std::set<std::string_view> need_algorithm = {"gelu"};

if (argument_less.find(unary) != argument_less.end()) {
return handle_argument_less(unary, attr);
} else if (need_scalars.find(unary) != need_scalars.end()) {
return handle_need_scalars(unary, attr, scalars);
} else if (need_algorithm.find(unary) != need_algorithm.end()) {
return handle_need_algorithm(unary, attr, algorithm);
} else {
TORCH_CHECK(
false,
"Unary attr ",
unary,
" is not supported for conv/linear post unary fusion");
}
}

} // namespace at::native::xpu
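For reviewers, a hedged usage sketch of the three categories above; it assumes a PyTorch XPU build with `FusionUtils.h` on the include path and is not runnable standalone:

```cpp
#include <ATen/native/mkldnn/xpu/FusionUtils.h>

void unary_attr_examples() {
  torch::List<std::optional<at::Scalar>> no_scalars;

  // Argument-less category: relu needs neither scalars nor an algorithm.
  at::native::onednn::Attr relu_attr;
  at::native::xpu::construct_unary_attr(relu_attr, "relu", no_scalars, std::nullopt);

  // Needs-scalars category: leaky_relu reads its negative slope from scalars[0].
  torch::List<std::optional<at::Scalar>> leaky_scalars;
  leaky_scalars.push_back(std::optional<at::Scalar>(0.01));
  at::native::onednn::Attr leaky_attr;
  at::native::xpu::construct_unary_attr(
      leaky_attr, "leaky_relu", leaky_scalars, std::nullopt);

  // Needs-algorithm category: gelu requires "none" (erf) or "tanh".
  at::native::onednn::Attr gelu_attr;
  at::native::xpu::construct_unary_attr(
      gelu_attr, "gelu", no_scalars, std::optional<std::string_view>("tanh"));
}
```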
53 changes: 53 additions & 0 deletions aten/src/ATen/native/mkldnn/xpu/FusionUtils.h
@@ -0,0 +1,53 @@
#pragma once
#include <detail/oneDNN.h>

//
// This header file provides utility functions for constructing and managing
// oneDNN attributes used in fusion operations on XPU devices. These utilities
// include functions for creating unary and binary post-operations attributes,
// as well as mapping string representations of operations to oneDNN attributes.
//

namespace at::native::xpu {
at::native::onednn::Attr& unary_attr_with_arg(
onednn::Attr& attr,
std::string_view unary,
torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm);

at::native::onednn::Attr& string_to_unary_attr(
onednn::Attr& attr,
std::string_view unary);

at::native::onednn::Attr& construct_unary_attr(
onednn::Attr& attr,
std::string_view unary,
torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm);

template <bool is_matmul = false>
onednn::Attr& construct_binary_attr(
onednn::Attr& attr,
std::string_view binary,
const Tensor& other) {
if (binary == "mul") {
attr.append_post_binary<is_matmul>(attr.kind_with_binary_mul, other);
} else if (binary == "sub") {
attr.append_post_binary<is_matmul>(attr.kind_with_binary_sub, other);
} else if (binary == "div") {
attr.append_post_binary<is_matmul>(attr.kind_with_binary_div, other);
} else if (binary == "add") {
attr.append_post_binary<is_matmul>(attr.kind_with_binary_add, other);
} else if (binary == "sum") {
attr.append_post_sum(1.f, 1.f, 0);
} else {
TORCH_CHECK(
binary == "none",
"Binary attr ",
binary,
"is not supported for conv/linear post binary fusion");
}
return attr;
}

} // namespace at::native::xpu
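A correspondingly hedged sketch for the binary helper above; again this assumes an XPU build, and `other` stands for any tensor the caller wants fused as a post-op operand:

```cpp
#include <ATen/native/mkldnn/xpu/FusionUtils.h>

void binary_attr_example(at::native::onednn::Attr& attr, const at::Tensor& other) {
  // is_matmul = true selects the matmul flavor of append_post_binary,
  // matching how linear_pointwise_binary in Linear.cpp uses it.
  at::native::xpu::construct_binary_attr</*is_matmul=*/true>(attr, "add", other);
}
```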
110 changes: 110 additions & 0 deletions aten/src/ATen/native/mkldnn/xpu/Linear.cpp
@@ -0,0 +1,110 @@
#include <ATen/DeviceGuard.h>
#include <torch/library.h>

#include <FusionUtils.h>

namespace at::native::xpu {

std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
collapse_in_out_dim(const at::Tensor& input, int64_t dim, const at::Tensor& weight) {
// dim collapse, e.g. [B, M, K] -> [BM, K]
std::vector<int64_t> input_reshaped_size = (dim == 2)
? std::vector<int64_t>{input.size(0), input.size(1)}
: std::vector<int64_t>{
input.numel() / (input.size(input.dim() - 1)),
input.size(input.dim() - 1)};
// [B, M, K] -> [B, M]
std::vector<int64_t> output_size(
input.sizes().begin(), input.sizes().end() - 1);
// [B, M, N]
output_size.push_back(weight.size(0));

// [BM, N]
std::vector<int64_t> output_reshaped_size{
input_reshaped_size[0], weight.size(0)};
return {input_reshaped_size, output_size, output_reshaped_size};
}

Tensor linear_pointwise(
const Tensor& input_t, // [M, K] or [B, M, K]
const Tensor& weight_t, // [N, K]
const std::optional<Tensor>& bias_opt,
std::string_view attr,
torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
onednn::Attr att;
const OptionalDeviceGuard device_guard(device_of(input_t));
att = construct_unary_attr(att, attr, scalars, algorithm);
auto input = input_t.contiguous();

const int64_t dim = input.dim();

auto [input_reshaped_size, output_size, output_reshaped_size] =
collapse_in_out_dim(input, dim, weight_t);
Tensor output = at::empty(output_size, input.options());
Tensor input_reshaped = input;
if (dim != 2) {
output = output.reshape(output_reshaped_size);
input_reshaped = input_reshaped.reshape(input_reshaped_size);
}

auto bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
at::native::onednn::matmul(
output, input_reshaped, weight_t, bias, /*m2_trans*/ false, att);

if (dim != 2) {
output = output.reshape(output_size);
}

return output;
}

Tensor linear_pointwise_binary(
const Tensor& input_t,
const Tensor& other_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
std::string_view binary_attr) {
const OptionalDeviceGuard device_guard(device_of(input_t));
onednn::Attr attr;
attr = construct_binary_attr<true>(attr, binary_attr, other_t);
auto input = input_t.contiguous();

const int64_t dim = input.dim();

// dim collapse
auto [input_reshaped_size, output_size, output_reshaped_size] =
collapse_in_out_dim(input, dim, weight_t);
Tensor output = at::empty(output_size, input.options());
Tensor input_reshaped = input;

if (dim != 2) {
// input [m, k], weight [n, k], output [m, n]
output = output.reshape(output_reshaped_size);
input_reshaped = input_reshaped.reshape(input_reshaped_size);
} else {
TORCH_CHECK(
output.dim() == other_t.dim(),
"linear_binary_run expects the dimension of output and other tensor to be the same");
}

auto bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
at::native::onednn::matmul(
output, input_reshaped, weight_t, bias, /*m2_trans*/ false, attr);

if (dim != 2) {
output = output.reshape(output_size);
}
return output;
}

TORCH_LIBRARY_IMPL(mkldnn, XPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise"),
TORCH_FN(linear_pointwise));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise.binary"),
TORCH_FN(linear_pointwise_binary));
}

} // namespace at::native::xpu
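To show what the registration buys, here is a hedged sketch of an eager-mode call. `fused_linear_relu` is an illustrative helper (not part of this PR), and the forward declaration simply mirrors the definition above, since the PR exposes no public header for it:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

namespace at::native::xpu {
// Forward declaration mirroring linear_pointwise above (illustrative only).
Tensor linear_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<std::string_view> algorithm);
} // namespace at::native::xpu

// relu(linear(x, w, b)) computed as one oneDNN matmul with a relu post-op.
at::Tensor fused_linear_relu(
    const at::Tensor& x, // [M, K] or [B, M, K], an XPU tensor
    const at::Tensor& w, // [N, K], the nn.Linear weight layout
    const std::optional<at::Tensor>& b) {
  torch::List<std::optional<at::Scalar>> scalars;
  return at::native::xpu::linear_pointwise(x, w, b, "relu", scalars, std::nullopt);
}
```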
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp
@@ -20,6 +20,10 @@ sycl::event matmul(
bool m2_trans,
Attr attr,
const std::vector<sycl::event>& deps) {
// m2_trans indicates whether mat2 is transposed from the nn.Linear
// perspective: m2_trans == true means mat2 has [k, n] layout, while
// m2_trans == false means mat2 has [n, k] layout, i.e., the default layout
// of nn.Linear's weight.
int64_t dims = result.dim();
TORCH_CHECK(
dims == 2 || dims == 3,