[Intel GPU] Enable mkdnn._linear_pointwise at XPU backend #140365
Closed
Changes from all commits (28 commits)
@@ -0,0 +1,145 @@
#include <ATen/native/mkldnn/xpu/FusionUtils.h>

#include <functional>
#include <set>
#include <string_view>
#include <unordered_map>

using namespace at::native::onednn;

namespace at::native::xpu {

// Handles unary post-ops that take no extra arguments (category
// `argument_less`): the op name alone determines the eltwise kind and its
// fixed alpha/beta parameters.
onednn::Attr& handle_argument_less(std::string_view unary, onednn::Attr& attr) {
  static const std::unordered_map<
      std::string_view,
      std::function<onednn::Attr&(onednn::Attr&)>>
      unary_map = {
          {"relu",
           [](onednn::Attr& attr) -> onednn::Attr& {
             return attr.append_post_eltwise(
                 1.0f, 0.0f, 0.0f, attr.kind_with_relu);
           }},
          {"sigmoid",
           [](onednn::Attr& attr) -> onednn::Attr& {
             return attr.append_post_eltwise(
                 1.0f, 0.0f, 0.0f, attr.kind_with_sigmoid);
           }},
          {"tanh",
           [](onednn::Attr& attr) -> onednn::Attr& {
             return attr.append_post_eltwise(
                 1.0f, 0.0f, 0.0f, attr.kind_with_tanh);
           }},
          {"hardswish",
           [](onednn::Attr& attr) -> onednn::Attr& {
             return attr.append_post_eltwise(
                 1.0f, 1.0f / 6.0f, 1.0f / 2.0f, attr.kind_with_hardswish);
           }},
          {"swish",
           [](onednn::Attr& attr) -> onednn::Attr& {
             return attr.append_post_eltwise(
                 1.0f, 1.0f, 0.0f, attr.kind_with_swish);
           }},
          {"hardsigmoid",
           [](onednn::Attr& attr) -> onednn::Attr& {
             return attr.append_post_eltwise(
                 1.0f, 1.0f / 6.0f, 1.0f / 2.0f, attr.kind_with_hardsigmoid);
           }},
          {"none", [](onednn::Attr& attr) -> onednn::Attr& { return attr; }}};

  if (unary_map.find(unary) != unary_map.end()) {
    return unary_map.at(unary)(attr);
  }
  TORCH_CHECK(
      false,
      "Unary attr ",
      unary,
      " is not supported for conv/linear post unary fusion");
}

// Handles unary post-ops that consume scalar arguments (category
// `need_scalars`), e.g. leaky_relu's negative slope or hardtanh's clip bounds.
onednn::Attr& handle_need_scalars(
    std::string_view unary,
    onednn::Attr& attr,
    torch::List<std::optional<at::Scalar>> scalars) {
  static const std::unordered_map<
      std::string_view,
      std::function<onednn::Attr&(
          onednn::Attr&, torch::List<std::optional<at::Scalar>>)>>
      unary_map = {
          {"leaky_relu",
           [](onednn::Attr& attr,
              torch::List<std::optional<at::Scalar>> scalars) -> onednn::Attr& {
             auto alpha =
                 scalars[0].get().toOptional<at::Scalar>().value().to<float>();
             return attr.append_post_eltwise(
                 1.0f, alpha, 0.f, attr.kind_with_relu);
           }},
          {"hardtanh",
           [](onednn::Attr& attr,
              torch::List<std::optional<at::Scalar>> scalars) -> onednn::Attr& {
             auto alpha =
                 scalars[0].get().toOptional<at::Scalar>().value().to<float>();
             auto beta =
                 scalars[1].get().toOptional<at::Scalar>().value().to<float>();
             return attr.append_post_eltwise(
                 1.0f, alpha, beta, attr.kind_with_clip);
           }}};

  if (unary_map.find(unary) != unary_map.end()) {
    return unary_map.at(unary)(attr, scalars);
  }
  TORCH_CHECK(
      false,
      "Unary attr ",
      unary,
      " is not supported for conv/linear post unary fusion");
}

// Handles unary post-ops that require an algorithm string (category
// `need_algorithm`); currently only gelu, which selects erf vs. tanh.
onednn::Attr& handle_need_algorithm(
    std::string_view unary,
    onednn::Attr& attr,
    std::optional<std::string_view> algorithm) {
  TORCH_CHECK(
      unary == "gelu",
      "GELU is currently the only unary operation that requires an algorithm");
  TORCH_CHECK(
      algorithm.has_value(),
      "GELU algorithm is not specified, please specify it as 'none' or 'tanh'");
  enum dnnl::algorithm gelu_type;
  if (algorithm.value() == "none") {
    gelu_type = attr.kind_with_gelu_erf;
  } else {
    TORCH_CHECK(
        algorithm.value() == "tanh",
        "Unsupported GELU algorithm: ",
        algorithm.value(),
        ", expected 'none' or 'tanh'");
    gelu_type = attr.kind_with_gelu_tanh;
  }
  return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, gelu_type);
}

onednn::Attr& construct_unary_attr(
    onednn::Attr& attr,
    std::string_view unary,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<std::string_view> algorithm) {
  // Define sets for unary operations based on their argument requirements.
  // Category `argument_less`: stateless operations (and "none", a no-op).
  // Category `need_scalars`: require alpha/beta scalars.
  // Category `need_algorithm`: require an algorithm string; only gelu now.
  // If further unary operations are required, they can be added to these sets,
  // or new sets can be added for new categories.
  static const std::set<std::string_view> argument_less = {
      "relu", "sigmoid", "tanh", "hardswish", "swish", "hardsigmoid", "none"};
  static const std::set<std::string_view> need_scalars = {
      "leaky_relu", "hardtanh"};
  static const std::set<std::string_view> need_algorithm = {"gelu"};

  if (argument_less.find(unary) != argument_less.end()) {
    return handle_argument_less(unary, attr);
  } else if (need_scalars.find(unary) != need_scalars.end()) {
    return handle_need_scalars(unary, attr, scalars);
  } else if (need_algorithm.find(unary) != need_algorithm.end()) {
    return handle_need_algorithm(unary, attr, algorithm);
  } else {
    TORCH_CHECK(
        false,
        "Unary attr ",
        unary,
        " is not supported for conv/linear post unary fusion");
  }
}

} // namespace at::native::xpu
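The hardswish and hardsigmoid entries above pass alpha = 1/6 and beta = 1/2 to oneDNN's eltwise post-op. As a reference-only sketch of what those constants encode (assuming oneDNN's hardsigmoid formula clamp(alpha*x + beta, 0, 1), which matches PyTorch's definitions):

#include <algorithm>

// Reference math for the alpha = 1/6, beta = 1/2 constants used above:
//   hardsigmoid(x) = clamp(x / 6 + 1 / 2, 0, 1)
//   hardswish(x)   = x * hardsigmoid(x)  (== x * relu6(x + 3) / 6)
float hardsigmoid_ref(float x) {
  return std::clamp(x / 6.0f + 0.5f, 0.0f, 1.0f);
}

float hardswish_ref(float x) {
  return x * hardsigmoid_ref(x);
}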
@@ -0,0 +1,53 @@
#pragma once
#include <detail/oneDNN.h>

//
// This header provides utility functions for constructing and managing
// oneDNN attributes used in fusion operations on XPU devices. These
// utilities include functions for creating unary and binary post-operation
// attributes, as well as mapping string representations of operations to
// oneDNN attributes.
//

namespace at::native::xpu {

onednn::Attr& handle_argument_less(std::string_view unary, onednn::Attr& attr);

onednn::Attr& handle_need_scalars(
    std::string_view unary,
    onednn::Attr& attr,
    torch::List<std::optional<at::Scalar>> scalars);

onednn::Attr& handle_need_algorithm(
    std::string_view unary,
    onednn::Attr& attr,
    std::optional<std::string_view> algorithm);

onednn::Attr& construct_unary_attr(
    onednn::Attr& attr,
    std::string_view unary,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<std::string_view> algorithm);

// Appends a binary post-op selected by `binary` ("mul", "sub", "div", "add",
// "sum", or "none") to `attr`. "sum" accumulates into the existing output
// values (oneDNN sum post-op); the others consume `other` as a second input.
template <bool is_matmul = false>
onednn::Attr& construct_binary_attr(
    onednn::Attr& attr,
    std::string_view binary,
    const Tensor& other) {
  if (binary == "mul") {
    attr.append_post_binary<is_matmul>(attr.kind_with_binary_mul, other);
  } else if (binary == "sub") {
    attr.append_post_binary<is_matmul>(attr.kind_with_binary_sub, other);
  } else if (binary == "div") {
    attr.append_post_binary<is_matmul>(attr.kind_with_binary_div, other);
  } else if (binary == "add") {
    attr.append_post_binary<is_matmul>(attr.kind_with_binary_add, other);
  } else if (binary == "sum") {
    attr.append_post_sum(1.f, 1.f, 0);
  } else {
    TORCH_CHECK(
        binary == "none",
        "Binary attr ",
        binary,
        " is not supported for conv/linear post binary fusion");
  }

  return attr;
}

} // namespace at::native::xpu
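As a usage illustration of the helper above (a minimal sketch; `other` stands for any tensor matching the fused output, and the `true` template argument selects the matmul path as in the linear kernels below):

// Sketch only: fuse an elementwise add of `other` into a matmul epilogue.
// Passing "sum" instead would accumulate into the existing output values.
at::native::onednn::Attr attr;
attr = at::native::xpu::construct_binary_attr</*is_matmul=*/true>(
    attr, "add", other);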
@@ -0,0 +1,110 @@
#include <ATen/DeviceGuard.h>
#include <torch/library.h>

#include <FusionUtils.h>

namespace at::native::xpu {

// Collapses leading dims for matmul, e.g. input [B, M, K] -> [BM, K], and
// derives the matching output shapes: the full output [B, M, N] and its
// collapsed form [BM, N], where N = weight.size(0).
std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
collapse_in_out_dim(
    const at::Tensor& input,
    int64_t dim,
    const at::Tensor& weight) {
  // dim collapse, e.g. [B, M, K] -> [BM, K]
  std::vector<int64_t> input_reshaped_size = (dim == 2)
      ? std::vector<int64_t>{input.size(0), input.size(1)}
      : std::vector<int64_t>{
            input.numel() / (input.size(input.dim() - 1)),
            input.size(input.dim() - 1)};
  // [B, M, K] -> [B, M]
  std::vector<int64_t> output_size(
      input.sizes().begin(), input.sizes().end() - 1);
  // [B, M, N]
  output_size.push_back(weight.size(0));

  // [BM, N]
  std::vector<int64_t> output_reshaped_size{
      input_reshaped_size[0], weight.size(0)};
  return {input_reshaped_size, output_size, output_reshaped_size};
}

Tensor linear_pointwise(
    const Tensor& input_t, // [M, K] or [B, M, K]
    const Tensor& weight_t, // [N, K]
    const std::optional<Tensor>& bias_opt,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<std::string_view> algorithm) {
  onednn::Attr att;
  const OptionalDeviceGuard device_guard(device_of(input_t));
  att = construct_unary_attr(att, attr, scalars, algorithm);
  auto input = input_t.contiguous();

  const int64_t dim = input.dim();

  auto [input_reshaped_size, output_size, output_reshaped_size] =
      collapse_in_out_dim(input, dim, weight_t);
  Tensor output = at::empty(output_size, input.options());
  Tensor input_reshaped = input;
  if (dim != 2) {
    output = output.reshape(output_reshaped_size);
    input_reshaped = input_reshaped.reshape(input_reshaped_size);
  }

  auto bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
  at::native::onednn::matmul(
      output, input_reshaped, weight_t, bias, /*m2_trans*/ false, att);

  if (dim != 2) {
    output = output.reshape(output_size);
  }

  return output;
}

Tensor linear_pointwise_binary(
    const Tensor& input_t,
    const Tensor& other_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    std::string_view binary_attr) {
  const OptionalDeviceGuard device_guard(device_of(input_t));
  onednn::Attr attr;
  attr = construct_binary_attr<true>(attr, binary_attr, other_t);
  auto input = input_t.contiguous();

  const int64_t dim = input.dim();

  // dim collapse
  auto [input_reshaped_size, output_size, output_reshaped_size] =
      collapse_in_out_dim(input, dim, weight_t);
  Tensor output = at::empty(output_size, input.options());
  Tensor input_reshaped = input;

  if (dim != 2) {
    // input [m, k], weight [n, k], output [m, n]
    output = output.reshape(output_reshaped_size);
    input_reshaped = input_reshaped.reshape(input_reshaped_size);
  } else {
    TORCH_CHECK(
        output.dim() == other_t.dim(),
        "linear_binary_run expects the dimension of output and other tensor to be the same");
  }

  auto bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
  at::native::onednn::matmul(
      output, input_reshaped, weight_t, bias, /*m2_trans*/ false, attr);

  if (dim != 2) {
    output = output.reshape(output_size);
  }
  return output;
}

TORCH_LIBRARY_IMPL(mkldnn, XPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise"),
      TORCH_FN(linear_pointwise));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise.binary"),
      TORCH_FN(linear_pointwise_binary));
}

} // namespace at::native::xpu
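For reference, a minimal calling sketch of the unary kernel above (hypothetical shapes; assumes a PyTorch build with XPU support):

// Sketch: a 3-D input is collapsed to 2-D, multiplied against the [N, K]
// weight with ReLU fused into the epilogue, then reshaped back.
at::Tensor x = at::randn({4, 8, 16}, at::device(at::kXPU)); // [B, M, K]
at::Tensor w = at::randn({32, 16}, at::device(at::kXPU));   // [N, K]
at::Tensor b = at::randn({32}, at::device(at::kXPU));       // [N]
torch::List<std::optional<at::Scalar>> no_scalars;
at::Tensor y = at::native::xpu::linear_pointwise(
    x, w, b, "relu", no_scalars, std::nullopt);             // y: [4, 8, 32]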
@ZhiweiYan-96, literals like hardswish, silu, etc. should be defined as named constants rather than repeated inline.
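For illustration only, the suggestion might look like this hedged sketch (hypothetical names, not part of the PR):

#include <string_view>

// Hypothetical named constants replacing the repeated string literals:
inline constexpr std::string_view kReLU = "relu";
inline constexpr std::string_view kSwish = "swish"; // aka silu
inline constexpr std::string_view kHardswish = "hardswish";
// ...the maps and sets would then reference kReLU etc. instead of "relu".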
Thanks for the reminder. I am investigating an elegant way to remove the verbose if-else branches; the final solution is not ready yet. I tried a lambda function + hash map (e.g. silu -> lambda func(silu)), but the argument list is big and the code became more verbose. I also tried dispatching through templates (e.g. silu -> handle<silu_func>()), but hit syntax errors from ambiguous function calls rooted in the complex argument list with std::optional. Please allow me more time to carefully refactor this.
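The lambda + hash map approach described above, as a standalone minimal sketch (the Attr type here is a stand-in, not the real onednn::Attr):

#include <functional>
#include <stdexcept>
#include <string_view>
#include <unordered_map>

struct Attr {}; // stand-in for onednn::Attr in this sketch

using Handler = std::function<Attr&(Attr&)>;

// "name -> lambda" dispatch: each entry mutates the attribute in place and
// returns it by reference, so calls can be chained.
Attr& dispatch_unary(std::string_view name, Attr& attr) {
  static const std::unordered_map<std::string_view, Handler> table = {
      {"silu", [](Attr& a) -> Attr& { /* append silu post-op */ return a; }},
      {"relu", [](Attr& a) -> Attr& { /* append relu post-op */ return a; }},
  };
  auto it = table.find(name);
  if (it == table.end()) {
    throw std::invalid_argument("unsupported unary post-op");
  }
  return it->second(attr);
}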
@EikanWang @guangyey @etaf Could you please help review the design? The overall logic is to group unary post-ops by their argument requirements; the process is to look the op name up in a per-category static unordered_map and invoke the stored handler (see the sketch after this list).
Pros:
- No long if-else branch anymore; cleaner code.
- New ops can be added to the unordered_map following the current code and do not need to touch if-else logic; extensibility is better.
- Handlers in each unordered_map share a uniform std::function signature.
Cons:
- The top-level category dispatch still uses if-else.
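As an extensibility illustration (hedged: mish is a hypothetical new op not added by this PR, and kind_with_mish is assumed by analogy with the existing kinds), supporting one more argument-less unary would only mean this fragment of unary_map in handle_argument_less, plus the matching name in the argument_less set:

// Hypothetical new entry in handle_argument_less's unary_map; no if-else
// changes are required, only the map element and the matching set member.
{"mish",
 [](onednn::Attr& attr) -> onednn::Attr& {
   return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_mish);
 }},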