🚀 Feature

Generalize `binary_kernel_reduce_vec` to support the more advanced reductions that can't currently be vectorized, e.g. `torch.norm`.
Motivation
As mentioned by @ngimel in #39516 (comment), non-vectorized reductions are much slower than their vectorized counterparts. Currently, the following CPU reductions are not vectorized:

- `norm`
- `std_var`
- `mean` (this kernel isn't actually used because `sum().div_()` is faster, see "[ATen] mean operator is unvectorized on CPU" #16617)
- `argmin`/`argmax`
Except for `argmin`/`argmax`, I think these could all be efficiently vectorized, if not for the fact that `binary_kernel_reduce_vec` doesn't support generalized reduction ops the way `binary_kernel_reduce` does. All of these require separate `reduce`, `combine` and `project` operations, whereas the `_vec` variant is limiting because it requires `reduce` and `combine` to be the same operation and `project` to be a no-op.
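
To make that concrete, here is a simplified sketch (modelled on the reduction ops in ATen's `SharedReduceOps.h`, with details such as the index argument elided) of why `torch.norm` needs all three operations to be distinct:

```cpp
#include <cmath>

// Simplified p-norm reduction ops: reduce accumulates |x|^p, combine
// merges partial sums with plain addition, and project applies the
// final 1/p power. So reduce != combine and project is not a no-op,
// which fits binary_kernel_reduce but not binary_kernel_reduce_vec.
template <typename acc_t>
struct NormOpsSketch {
  acc_t p;

  // Fold one input element into the accumulator
  acc_t reduce(acc_t acc, acc_t data) const {
    return acc + std::pow(std::abs(data), p);
  }
  // Merge two partial accumulators
  acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }
  // Finalize the accumulator into the result
  acc_t project(acc_t acc) const {
    return std::pow(acc, acc_t(1) / p);
  }
};
```

By contrast, `binary_kernel_reduce_vec` accepts only a scalar/vector lambda pair implementing a single associative operation (e.g. `a + b` for sum), which it uses for both `reduce` and `combine`, with no projection step at the end.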
Pitch
I propose that `binary_kernel_reduce_vec` be generalized to support reduction operators similar to `binary_kernel_reduce`. This would mean that each individual operation (`reduce`, `combine` and `project`) would need to be overloaded for both scalar and vector types. Additionally, there needs to be an operation to convert from a vector accumulator back to a scalar accumulator in order to perform inner reductions. I was thinking `accumulator_to_scalar`, but welcome other names.
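
For illustration, here is a minimal sketch of what `accumulator_to_scalar` might look like for a sum-style reduction (the name comes from this proposal, not an existing API; this assumes the `Vec256` interface of the time, i.e. `store()` and a constexpr `size()`):

```cpp
#include <ATen/cpu/vec256/vec256.h>  // Vec256 (header location at the time)

using at::vec256::Vec256;

// Sketch: fold the SIMD lanes of a vector accumulator into one scalar
// accumulator by storing the lanes to a buffer and reducing it with the
// scalar op (here +, for a sum-style reduction).
template <typename acc_t>
acc_t accumulator_to_scalar_sketch(Vec256<acc_t> vec_acc) {
  acc_t buf[Vec256<acc_t>::size()];
  vec_acc.store(buf);
  acc_t acc = buf[0];
  for (int i = 1; i < Vec256<acc_t>::size(); ++i) {
    acc = acc + buf[i];  // a general op would call combine(acc, buf[i])
  }
  return acc;
}
```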
In addition, I think it would be useful to have a customizable `multi_row_reduce` operation, similar to the `multi_row_sum` operation I used in #39516. This would allow cascade-sum to use the generic machinery, and would also allow `torch.norm` and `torch.mean` to use the cascade-sum algorithm for improved numerical accuracy.
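
As a point of reference, the scalar half of such an operation might look like the following sketch (a sum-like `multi_row_reduce`, loosely following the shape of `multi_row_sum` from #39516; the function name and the fixed `ilp_factor` are illustrative):

```cpp
#include <array>
#include <cstdint>

// Sketch of a scalar multi-row reduction for a sum-like op: maintain
// ilp_factor independent accumulators, one per row, so the adds can
// overlap in the pipeline. The partial results are the inputs that a
// cascade-sum combines, and are merged by the caller via combine().
template <typename acc_t, int ilp_factor = 4>
std::array<acc_t, ilp_factor> multi_row_reduce_sketch(
    const char* in_data,
    int64_t row_stride,
    int64_t col_stride,
    int64_t size) {
  std::array<acc_t, ilp_factor> acc;
  acc.fill(acc_t(0));  // identity of the sum-like op
  for (int64_t i = 0; i < size; ++i) {
    for (int64_t k = 0; k < ilp_factor; ++k) {
      auto* ptr = in_data + i * col_stride + k * row_stride;
      acc[k] = acc[k] + *reinterpret_cast<const acc_t*>(ptr);
    }
  }
  return acc;
}
```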
Compared to `binary_kernel_reduce`, many more functions would need to be defined for each reduction operation. However, for the simple cases these are mostly boilerplate which could be generated by a helper function from a lambda pair `op`, `vop`, as is done currently (a sketch of such a helper follows the struct below). In total, a full reduction operation would need to define each of the following functions:
```cpp
template <typename input, typename output, typename T>
struct ReductionOps {
  using scalar_t = input;
  using vec_t = Vec256<scalar_t>;
  using acc_t = T;  // accumulation type
  using vacc_t = Vec256<acc_t>;
  using result_t = output;
  using vresult_t = Vec256<result_t>;

  // Number of rows reduced simultaneously by multi_row_reduce
  static constexpr int ilp_factor = 4;

  // Identity element of the reduction, scalar and vector forms
  acc_t identity();
  vacc_t videntity();

  // Fold one input element into the accumulator
  acc_t reduce(acc_t acc, scalar_t data);
  vacc_t reduce(vacc_t acc, vec_t data);

  // Fold the SIMD lanes of a vector accumulator into a scalar accumulator
  acc_t accumulator_to_scalar(vacc_t vec_acc) const;

  // Reduce ilp_factor rows at once; scalar and vector forms (named like
  // identity/videntity, since the two differ only in return type)
  std::array<acc_t, ilp_factor> multi_row_reduce(
      const char* C10_RESTRICT in_data,
      const int64_t row_stride,
      const int64_t col_stride,
      const int64_t size) const;
  std::array<vacc_t, ilp_factor> vmulti_row_reduce(
      const char* C10_RESTRICT in_data,
      const int64_t row_stride,
      const int64_t col_stride,
      const int64_t size) const;

  // Merge two partial accumulators
  acc_t combine(acc_t a, acc_t b);
  vacc_t combine(vacc_t a, vacc_t b);

  // Convert the final accumulator into the result type
  result_t project(acc_t a);
  vresult_t project(vacc_t a);
};
```
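
To illustrate the boilerplate-generating helper mentioned above, simple reductions could be adapted roughly like this (`SimpleOps` and `make_simple_ops` are hypothetical names):

```cpp
// Hypothetical helper: derive the full set of operations from a single
// (op, vop) lambda pair, for reductions where reduce == combine, the
// accumulator is the input type, and project is the identity.
template <typename scalar_t, typename func_t, typename vec_func_t>
struct SimpleOps {
  using vec_t = Vec256<scalar_t>;
  func_t op;       // scalar_t(scalar_t, scalar_t)
  vec_func_t vop;  // vec_t(vec_t, vec_t)

  scalar_t reduce(scalar_t acc, scalar_t data) const { return op(acc, data); }
  vec_t reduce(vec_t acc, vec_t data) const { return vop(acc, data); }
  scalar_t combine(scalar_t a, scalar_t b) const { return op(a, b); }
  vec_t combine(vec_t a, vec_t b) const { return vop(a, b); }
  scalar_t project(scalar_t a) const { return a; }
  vec_t project(vec_t a) const { return a; }
  // identity/videntity, accumulator_to_scalar and the multi-row reductions
  // could be filled in generically in the same way.
};

template <typename scalar_t, typename func_t, typename vec_func_t>
SimpleOps<scalar_t, func_t, vec_func_t> make_simple_ops(func_t op, vec_func_t vop) {
  return {op, vop};
}
```

Existing call sites that pass an `op`/`vop` pair to `binary_kernel_reduce_vec` could then be ported to the generalized machinery without rewriting every kernel.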