-
-
Notifications
You must be signed in to change notification settings - Fork 56.2k
GSoC Add ONNX Support for GatherElements #24092
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
86bd690
4fee29e
421b3ce
ab3b5a3
b99e1ab
4ff95a2
873bfd4
611b32f
501ebaf
72ab576
3bf4944
d949dbd
e2fec12
8d3be8e
2dd7b21
6c3a8a1
ac3ea42
e8e1637
83031aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// This file is part of OpenCV project. | ||
// It is subject to the license terms in the LICENSE file found in the top-level directory | ||
// of this distribution and at http://opencv.org/license.html. | ||
|
||
#include "../precomp.hpp" | ||
#include <opencv2/dnn/shape_utils.hpp> | ||
|
||
namespace cv { namespace dnn { | ||
|
||
static inline int calculateOffset(int outer_dim, const MatShape &shape_indices, int axis_skip, const MatStep &step_data) { | ||
int offset = 0; | ||
for (int axis = static_cast<int>(shape_indices.size()) - 2; axis >= 0; axis--) { | ||
int dim = shape_indices[axis]; | ||
if (axis != axis_skip) { | ||
offset += (outer_dim % dim) * step_data[axis]; | ||
} | ||
outer_dim /= dim; | ||
} | ||
return offset; | ||
} | ||
|
||
class GatherElementsLayerImpl CV_FINAL : public GatherElementsLayer | ||
{ | ||
public: | ||
GatherElementsLayerImpl(const LayerParams& params) | ||
{ | ||
setParamsFrom(params); | ||
axis = params.get<int>("axis", 0); | ||
} | ||
|
||
virtual bool supportBackend(int backendId) CV_OVERRIDE | ||
{ | ||
return backendId == DNN_BACKEND_OPENCV; | ||
} | ||
|
||
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs, | ||
const int requiredOutputs, | ||
std::vector<MatShape> &outputs, | ||
std::vector<MatShape> &internals) const CV_OVERRIDE | ||
{ | ||
CV_CheckEQ(inputs.size(), 2ull, "GatherElements: requires two inputs"); | ||
|
||
const auto &data = inputs[0]; | ||
const auto &indices = inputs[1]; | ||
CV_CheckEQ(data.size(), indices.size(), "GatherElements: data and indices should have the same dimension"); | ||
|
||
int normalized_axis = normalize_axis(axis, static_cast<int>(data.size())); | ||
CV_CheckGE(normalized_axis, 0, "GatherElements: axis out of range"); | ||
CV_CheckLT(normalized_axis, static_cast<int>(data.size()), "GatherElements: axis out of range"); | ||
for (size_t i = 0; i < data.size(); i++) { | ||
if (i != normalized_axis) { | ||
CV_CheckEQ(data[i], indices[i], "GatherElements: shape mismatched"); | ||
} | ||
} | ||
|
||
outputs.assign(1, inputs[1]); // shape of output is same as indices | ||
return false; | ||
} | ||
|
||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { | ||
std::vector<Mat> inputs; | ||
inputs_arr.getMatVector(inputs); | ||
|
||
const auto &data = inputs[0]; | ||
axis = normalize_axis(axis, data.dims); | ||
} | ||
|
||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE | ||
{ | ||
CV_TRACE_FUNCTION(); | ||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); | ||
|
||
std::vector<Mat> inputs, outputs; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Implementation should handle FP16 datatype too
https://pullrequest.opencv.org/buildbot/builders/4_x-lin64/builds/100283 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I will do this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #24427 |
||
inputs_arr.getMatVector(inputs); | ||
outputs_arr.getMatVector(outputs); | ||
|
||
const Mat& data = inputs[0]; | ||
const Mat& indices = inputs[1]; | ||
Mat& out = outputs[0]; | ||
|
||
typeDispatch(outputs[0].type(), data, indices, out); | ||
} | ||
|
||
template <typename T> | ||
void forward_impl(const Mat& data_, const Mat& indices_, Mat& out_) | ||
{ | ||
const auto *ptr_data = data_.ptr<const T>(); | ||
const auto *ptr_indices = indices_.ptr<const T>(); | ||
auto *ptr_out = out_.ptr<T>(); | ||
|
||
const auto shape_data = shape(data_); | ||
const auto &step_data = data_.step; | ||
const auto shape_indices = shape(indices_); | ||
|
||
int inner_most_dim = shape_indices.back(); | ||
int axis_dim = shape_data[axis]; | ||
size_t axis_step = static_cast<size_t>(step_data[axis] / sizeof(T)); | ||
|
||
bool innermost_axis = axis == static_cast<int>(shape_data.size() - 1); | ||
|
||
auto fn = [&](const Range &r) { | ||
for (int i = r.start; i < r.end; i++) { | ||
auto *data = ptr_data + static_cast<size_t>(calculateOffset(i, shape_indices, axis, step_data) / sizeof(T)); | ||
auto *indices = ptr_indices + i * inner_most_dim; | ||
auto *out = ptr_out + i * inner_most_dim; | ||
|
||
if (innermost_axis) { | ||
for (int j = 0; j < inner_most_dim; j++) { | ||
int index = static_cast<int>((indices[j] + axis_dim)) % axis_dim; // TODO: Check out-of-range index | ||
out[j] = data[index]; | ||
} | ||
} else { | ||
for (int j = 0; j < inner_most_dim; j++) { | ||
int index = static_cast<int>(indices[j] + axis_dim) % axis_dim; // TODO: Check out-of-range index | ||
out[j] = data[index * axis_step + j]; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
int outer_dims = total(shape_indices, 0, shape_indices.size() - 1); | ||
double nstripes = static_cast<size_t>(outer_dims * inner_most_dim * (1 / 1024.0)); | ||
parallel_for_(Range(0, outer_dims), fn, nstripes); | ||
} | ||
|
||
template<typename... Args> | ||
inline void typeDispatch(const int type, Args&&... args) | ||
{ | ||
switch (type) | ||
{ | ||
case CV_8U: | ||
forward_impl<uint8_t>(std::forward<Args>(args)...); | ||
break; | ||
case CV_32S: | ||
forward_impl<int32_t>(std::forward<Args>(args)...); | ||
break; | ||
case CV_32F: | ||
forward_impl<float>(std::forward<Args>(args)...); | ||
break; | ||
default: | ||
CV_Error(cv::Error::BadDepth, "DNN/GatherElements: Unsupported type."); | ||
}; | ||
} | ||
|
||
private: | ||
int axis; | ||
}; | ||
|
||
Ptr<GatherElementsLayer> GatherElementsLayer::create(const LayerParams& params) | ||
{ | ||
return makePtr<GatherElementsLayerImpl>(params); | ||
} | ||
|
||
}} // namespace cv::dnn |
Uh oh!
There was an error while loading. Please reload this page.