 #include <ATen/native/TopKImpl.h>
 #include <c10/core/WrapDimMinimal.h>
 #include <c10/util/irange.h>
+
 #ifdef USE_FBGEMM
 #include <fbgemm/Utils.h>
 #endif
 
+#if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2))
+#define XSS_COMPILE_TIME_SUPPORTED
+#define XSS_USE_OPENMP
+#include <src/x86simdsort-static-incl.h>
+#endif
+
 namespace at::native {
 
 namespace {
@@ -117,6 +124,7 @@ static void parallel_sort1d_kernel(
     std::vector<int64_t> tmp_vals(elements);
     const scalar_t* sorted_keys = nullptr;
     const int64_t* sorted_vals = nullptr;
+
     std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
         keys,
         vals,
@@ -165,6 +173,107 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
   }
 }
 
+#if defined(XSS_COMPILE_TIME_SUPPORTED)
+
+#define AT_DISPATCH_CASE_XSS_TYPES(...)                  \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
+
+#define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__))
+
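+// Gate for the x86-simd-sort path: it is not a stable sort, and only the
+// dtypes covered by AT_DISPATCH_XSS_TYPES above are supported.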
+static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) {
+  // xss_sort is not a stable sort
+  if (stable) return false;
+
+  auto type = values.scalar_type();
+  if (!(type == ScalarType::Long || type == ScalarType::Int || type == ScalarType::Double || type == ScalarType::Float)) return false;
+
+  return true;
+}
+
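+// Sort `values` and `indices` along `dim` by running
+// x86simdsortStatic::keyvalue_qsort on each 1-D slice; a TensorIterator with
+// the sort dimension squashed walks over the slices.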
+static void xss_sort_kernel(
+    const TensorBase& values,
+    const TensorBase& indices,
+    int64_t dim,
+    bool descending) {
+  auto iter = TensorIteratorConfig()
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .declare_static_shape(values.sizes(), /*squash_dims=*/dim)
+    .add_output(values)
+    .add_output(indices)
+    .build();
+
+  using index_t = int64_t;
+
+  AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] {
+
+    auto values_dim_stride = values.stride(dim);
+    auto indices_dim_stride = indices.stride(dim);
+    auto dim_size = values.size(dim);
+
+    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+      auto* values_data_bytes = data[0];
+      auto* indices_data_bytes = data[1];
+
+      if (values_data_bytes == nullptr || indices_data_bytes == nullptr) {
+        return;
+      }
+
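+      // Fast path: both outputs are contiguous along the sort dimension, so
+      // each slice can be sorted in place.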
+      if (values_dim_stride == 1 && indices_dim_stride == 1) {
+        for (const auto i C10_UNUSED : c10::irange(n)) {
+          x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
+              reinterpret_cast<scalar_t*>(values_data_bytes),
+              reinterpret_cast<index_t*>(indices_data_bytes),
+              dim_size,
+              true,
+              descending);
+
+          values_data_bytes += strides[0];
+          indices_data_bytes += strides[1];
+        }
+      } else {
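+        // Strided path: gather each slice into contiguous temporaries, sort
+        // those, then scatter the results back through the accessors.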
+        std::vector<scalar_t> tmp_values(dim_size);
+        std::vector<index_t> tmp_indices(dim_size);
+
+        for (const auto i : c10::irange(n)) {
+          TensorAccessor<scalar_t, 1> mode_values_acc(
+              reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
+              &dim_size, &values_dim_stride);
+          TensorAccessor<index_t, 1> mode_indices_acc(
+              reinterpret_cast<index_t*>(data[1] + i * strides[1]),
+              &dim_size, &indices_dim_stride);
+
+          for (const auto j : c10::irange(dim_size)) {
+            tmp_values[j] = mode_values_acc[j];
+            tmp_indices[j] = j;
+          }
+
+          x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
+              tmp_values.data(),
+              tmp_indices.data(),
+              dim_size,
+              true,
+              descending);
+
+          for (const auto j : c10::irange(dim_size)) {
+            mode_values_acc[j] = tmp_values[j];
+            mode_indices_acc[j] = tmp_indices[j];
+          }
+        }
+      }
+    };
+
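+    // Parallelize over the slices; the grain size is scaled down by the slice
+    // length so longer per-slice sorts get fewer slices per task.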
+    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
+    iter.for_each(loop, /*grain_size=*/grain_size);
+
+  });
+}
+#endif
+
 static void sort_kernel(
     const TensorBase& self,
     const TensorBase& values,
@@ -179,6 +288,14 @@ static void sort_kernel(
     // https://github.com/pytorch/pytorch/issues/91420
     return;
   }
+
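+// Try the x86-simd-sort path first when it was compiled in and the request is
+// compatible; otherwise fall through to the existing paths below.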
+#if defined(XSS_COMPILE_TIME_SUPPORTED)
+  if (can_use_xss_sort(values, indices, dim, stable)) {
+    xss_sort_kernel(values, indices, dim, descending);
+    return;
+  }
+#endif
+
 #ifdef USE_FBGEMM
   if (can_use_radix_sort(values, descending)) {
     parallel_sort1d_kernel(values, indices);
@@ -230,6 +347,7 @@ static void topk_kernel(
     int64_t dim,
     bool largest,
     bool sorted) {
+
   auto sizes = self.sizes();
   auto iter = TensorIteratorConfig()
     .check_all_same_dtype(false)
@@ -264,7 +382,7 @@ static void topk_kernel(
 
 } // anonymous namespace
 
-REGISTER_DISPATCH(sort_stub, &sort_kernel)
-REGISTER_DISPATCH(topk_stub, &topk_kernel)
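+// ALSO_REGISTER_AVX512_DISPATCH (see DispatchStub.h) registers the AVX512
+// build of these kernels as well, rather than reusing the AVX2 kernel on
+// AVX512-capable CPUs.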
+ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel)
+ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel)
 
 } // at::native