Adds support for accelerating sorting with x86-simd-sort · pytorch/pytorch@354b43a · GitHub
[go: up one dir, main page]

Skip to content

Commit 354b43a

Browse files
committed
Adds support for accelerating sorting with x86-simd-sort
1 parent 6fc63b4 commit 354b43a

File tree

6 files changed

+146
-2
lines changed

6 files changed

+146
-2
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,6 @@
131131
path = third_party/composable_kernel
132132
url = https://github.com/ROCm/composable_kernel.git
133133
branch = develop
134+
[submodule "third_party/x86-simd-sort"]
135+
path = third_party/x86-simd-sort
136+
url = https://github.com/intel/x86-simd-sort.git

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ else()
262262
cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF)
263263
endif()
264264
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
265+
option(USE_X86_SIMD_SORT "Use x86-simd-sort to accelerate sorting and topk for AVX2/AVX512" ON)
265266
option(USE_KINETO "Use Kineto profiling library" ON)
266267
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
267268
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
@@ -903,6 +904,10 @@ if(USE_FBGEMM)
903904
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
904905
endif()
905906

907+
if(USE_X86_SIMD_SORT)
908+
string(APPEND CMAKE_CXX_FLAGS " -DUSE_X86_SIMD_SORT")
909+
endif()
910+
906911
if(USE_PYTORCH_QNNPACK)
907912
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
908913
endif()

aten/src/ATen/native/cpu/SortingKernel.cpp

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616
#include <ATen/native/TopKImpl.h>
1717
#include <c10/core/WrapDimMinimal.h>
1818
#include <c10/util/irange.h>
19+
1920
#ifdef USE_FBGEMM
2021
#include <fbgemm/Utils.h>
2122
#endif
2223

24+
#if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2))
25+
#define XSS_COMPILE_TIME_SUPPORTED
26+
#define XSS_USE_OPENMP
27+
#include <src/x86simdsort-static-incl.h>
28+
#endif
29+
2330
namespace at::native {
2431

2532
namespace {
@@ -117,6 +124,7 @@ static void parallel_sort1d_kernel(
117124
std::vector<int64_t> tmp_vals(elements);
118125
const scalar_t* sorted_keys = nullptr;
119126
const int64_t* sorted_vals = nullptr;
127+
120128
std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
121129
keys,
122130
vals,
@@ -165,6 +173,107 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
165173
}
166174
}
167175

176+
#if defined(XSS_COMPILE_TIME_SUPPORTED)

// Restrict dispatch to the scalar types x86-simd-sort ships key-value
// qsort kernels for: 32/64-bit signed integers and single/double floats.
#define AT_DISPATCH_CASE_XSS_TYPES(...)                   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__))

// Returns true when the x86-simd-sort fast path may replace the generic
// sort for this call. Two conditions gate it: the caller must not have
// requested a stable sort (xss qsort is unstable), and the value dtype
// must be one of the four types covered by AT_DISPATCH_XSS_TYPES above.
static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) {
  // xss_sort is not a stable sort.
  if (stable) {
    return false;
  }
  switch (values.scalar_type()) {
    case ScalarType::Long:
    case ScalarType::Int:
    case ScalarType::Double:
    case ScalarType::Float:
      return true;
    default:
      return false;
  }
}
196+
197+
static void xss_sort_kernel(
198+
const TensorBase& values,
199+
const TensorBase& indices,
200+
int64_t dim,
201+
bool descending) {
202+
auto iter = TensorIteratorConfig()
203+
.check_all_same_dtype(false)
204+
.resize_outputs(false)
205+
.declare_static_shape(values.sizes(), /*squash_dims=*/dim)
206+
.add_output(values)
207+
.add_output(indices)
208+
.build();
209+
210+
using index_t = int64_t;
211+
212+
AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] {
213+
214+
auto values_dim_stride = values.stride(dim);
215+
auto indices_dim_stride = indices.stride(dim);
216+
auto dim_size = values.size(dim);
217+
218+
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
219+
auto* values_data_bytes = data[0];
220+
auto* indices_data_bytes = data[1];
221+
222+
if(values_data_bytes==nullptr || indices_data_bytes==nullptr){
223+
return;
224+
}
225+
226+
if (values_dim_stride == 1 && indices_dim_stride == 1){
227+
for (const auto i C10_UNUSED : c10::irange(n)) {
228+
x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
229+
reinterpret_cast<scalar_t*>(values_data_bytes),
230+
reinterpret_cast<index_t*>(indices_data_bytes),
231+
dim_size,
232+
true,
233+
descending);
234+
235+
values_data_bytes += strides[0];
236+
indices_data_bytes += strides[1];
237+
}
238+
}else{
239+
std::vector<scalar_t> tmp_values(dim_size);
240+
std::vector<index_t> tmp_indices(dim_size);
241+
242+
for (const auto i : c10::irange(n)) {
243+
TensorAccessor<scalar_t, 1> mode_values_acc(
244+
reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
245+
&dim_size, &values_dim_stride);
246+
TensorAccessor<index_t, 1> mode_indices_acc(
247+
reinterpret_cast<index_t*>(data[1] + i * strides[1]),
248+
&dim_size, &indices_dim_stride);
249+
250+
for (const auto j : c10::irange(dim_size)) {
251+
tmp_values[j] = mode_values_acc[j];
252+
tmp_indices[j] = j;
253+
}
254+
255+
x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
256+
tmp_values.data(),
257+
tmp_indices.data(),
258+
dim_size,
259+
true,
260+
descending);
261+
262+
for (const auto j : c10::irange(dim_size)) {
263+
mode_values_acc[j] = tmp_values[j];
264+
mode_indices_acc[j] = tmp_indices[j];
265+
}
266+
}
267+
}
268+
};
269+
270+
int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
271+
iter.for_each(loop, /*grain_size=*/grain_size);
272+
273+
});
274+
}
275+
#endif
276+
168277
static void sort_kernel(
169278
const TensorBase& self,
170279
const TensorBase& values,
@@ -179,6 +288,14 @@ static void sort_kernel(
179288
// https://github.com/pytorch/pytorch/issues/91420
180289
return;
181290
}
291+
292+
#if defined(XSS_COMPILE_TIME_SUPPORTED)
293+
if (can_use_xss_sort(values, indices, dim, stable)){
294+
xss_sort_kernel(values, indices, dim, descending);
295+
return;
296+
}
297+
#endif
298+
182299
#ifdef USE_FBGEMM
183300
if (can_use_radix_sort(values, descending)) {
184301
parallel_sort1d_kernel(values, indices);
@@ -230,6 +347,7 @@ static void topk_kernel(
230347
int64_t dim,
231348
bool largest,
232349
bool sorted) {
350+
233351
auto sizes = self.sizes();
234352
auto iter = TensorIteratorConfig()
235353
.check_all_same_dtype(false)
@@ -264,7 +382,7 @@ static void topk_kernel(
264382

265383
} // anonymous namespace
266384

267-
REGISTER_DISPATCH(sort_stub, &sort_kernel)
268-
REGISTER_DISPATCH(topk_stub, &topk_kernel)
385+
ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel)
386+
ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel)
269387

270388
} //at::native

cmake/Dependencies.cmake

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,22 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX)
13011301
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS})
13021302
endif()
13031303

1304+
# --[ x86-simd-sort integration
1305+
if(USE_X86_SIMD_SORT)
1306+
if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
1307+
message(WARNING
1308+
"x64 operating system is required for x86-simd-sort. "
1309+
"Not compiling with x86-simd-sort. "
1310+
"Turn this warning off by USE_X86_SIMD_SORT=OFF.")
1311+
set(USE_X86_SIMD_SORT OFF)
1312+
endif()
1313+
1314+
if(USE_X86_SIMD_SORT)
1315+
set(XSS_SIMD_SORT_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/x86-simd-sort)
1316+
include_directories(SYSTEM ${XSS_SIMD_SORT_INCLUDE_DIR})
1317+
endif()
1318+
endif()
1319+
13041320
# --[ ATen checks
13051321
set(USE_LAPACK 0)
13061322

cmake/Summary.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ function(caffe2_print_configuration_summary)
133133
endif()
134134
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
135135
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
136+
message(STATUS " USE_X86_SIMD_SORT : ${USE_X86_SIMD_SORT}")
136137
message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
137138
message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}")
138139
message(STATUS " USE_KINETO : ${USE_KINETO}")

third_party/x86-simd-sort

Submodule x86-simd-sort added at 9a1b616

0 commit comments

Comments
(0)