[ATen][CUDA] Implement 128 bit vectorization v2 (#145746) · pytorch/pytorch@e84bf88 · GitHub

Commit e84bf88

Aidyn-A authored and Skylion007 committed
[ATen][CUDA] Implement 128 bit vectorization v2 (#145746)
This is a re-based version of my previous PR #141959. Description from the original PR:

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on an NVIDIA H100 (Hopper).

<details>
<summary>The benchmark code used</summary>

```Python
import time

import torch
from torch.profiler import profile, ProfilerActivity


def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")
    shapes = []
    for p in range(24, 30):
        shape = 1 << p
        shapes.append(shape)
    for shape in shapes:
        # warm-up iterations
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)
        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        x = torch.randn(shape, device=device, dtype=dtype)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6
        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)


def main():
    ops = [
        torch.relu,
        torch.sigmoid,
        torch.tanh,
        torch.nn.functional.gelu,
        torch.sin,
        torch.exp,
    ]
    dtypes = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
```

</details>

<details>
<summary>Results</summary>

| op | dtype | size | time after (s) | time before (s) | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |

</details>

Pull Request resolved: #145746
Approved by: https://github.com/eqy, https://github.com/ngimel

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
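For orientation, here is a minimal standalone sketch (not the ATen kernel itself) of what moving to 128-bit accesses means for a 2-byte dtype: each thread issues one 16-byte load and one 16-byte store, covering 8 half elements at a time. The `aligned_vec` struct and `relu_vec8_kernel` names are illustrative only, and the sketch assumes the element count is a multiple of 8 with 16-byte-aligned pointers.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Illustrative only: a hand-written 128-bit (8 x half) ReLU. The real ATen
// kernels pick the vector width at runtime and handle unaligned tails.
template <typename T, int N>
struct alignas(sizeof(T) * N) aligned_vec {
  T val[N];
};

__global__ void relu_vec8_kernel(const __half* in, __half* out, int64_t n_vec) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i >= n_vec) {
    return;
  }
  using vec_t = aligned_vec<__half, 8>;             // 8 * 2 bytes = 128 bits
  vec_t v = reinterpret_cast<const vec_t*>(in)[i];  // one 128-bit load
#pragma unroll
  for (int k = 0; k < 8; ++k) {
    v.val[k] = __hgt(v.val[k], __float2half(0.0f)) ? v.val[k] : __float2half(0.0f);
  }
  reinterpret_cast<vec_t*>(out)[i] = v;             // one 128-bit store
}
```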
1 parent eeb5e1b commit e84bf88

8 files changed: 77 additions, 21 deletions

aten/src/ATen/native/cuda/CUDAJitLoops.cuh

Lines changed: 16 additions & 4 deletions
```diff
@@ -49,8 +49,8 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
-#ifdef USE_ROCM
   at::cuda::jit::NvrtcFunction vec8;
+#ifdef USE_ROCM
   at::cuda::jit::NvrtcFunction vec16;
 #endif

@@ -131,18 +131,30 @@ void launch_jitted_vectorized_kernel(
   int vec_size = at::cuda::jit::can_vectorize_up_to(
       desc, c10::ArrayRef<char*>(data.data(), data.size()));

+#ifndef USE_ROCM
+  const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
+  const int optimal_vec_size = 16 / static_cast<int>(input_size);
+  vec_size = std::min<int>(optimal_vec_size, vec_size);
+  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
+  // that causes some numerical mismatches with uint8 on sm80 and sm90.
+  // TODO: Revisit this after CUDA 12.8 update.
+  if (input_size < 2) {
+    vec_size = std::min<int>(vec_size, 4);
+  }
+#endif
+
   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;

 #ifdef USE_ROCM
   if (vec_size == 16) {
     fn_ptr = &fn_cache.vec16;
-  } else if (vec_size == 8) {
-    fn_ptr = &fn_cache.vec8;
   } else
 #endif
-  if (vec_size == 4) {
+  if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
+  } else if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
     fn_ptr = &fn_cache.vec2;
```

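Read in isolation, the non-ROCm branch added above amounts to the following selection rule. This is a hypothetical standalone helper, not an ATen function; `aligned_vec_size` stands in for whatever `can_vectorize_up_to` reported for the actual pointers.

```cpp
#include <algorithm>

// Restatement of the CUDA (non-ROCm) vec-size cap: at most 128 bits per
// thread, and 1-byte dtypes stay at vec4 until the NVCC issue is resolved.
int pick_jit_vec_size(int itemsize, int aligned_vec_size) {
  int vec_size = std::min(16 / itemsize, aligned_vec_size);  // <= 16 bytes/thread
  if (itemsize < 2) {
    vec_size = std::min(vec_size, 4);  // skip vec8/vec16 for 1-byte types
  }
  return vec_size;
}
// float32 -> at most vec4 (16 B), half/bfloat16 -> at most vec8 (16 B),
// uint8/bool -> capped at vec4 for now.
```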
aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 25 additions & 2 deletions
```diff
@@ -61,6 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }

+#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -71,6 +72,16 @@ constexpr auto elems_per_thread(){
     return 4;
   }
 }
+#else
+template <int io_sizes>
+constexpr auto elems_per_thread(){
+  if constexpr (io_sizes == 1) {
+    return 16;
+  } else {
+    return 8;
+  }
+}
+#endif

 template <int io_sizes>
 constexpr auto io_block_work_size() {
@@ -191,21 +202,33 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
+#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
-
+#else
+  using cpp_type = typename function_traits<func_t>::result_type;
+  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
+  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
+  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
+  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
+  // that causes some numerical mismatches with uint8 on sm80 and sm90.
+  // TODO: Revisit this after CUDA 12.8 update.
+  if constexpr (sizeof(cpp_type) < 2) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
+#endif
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:
       vectorized_elementwise_kernel<16, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
-#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
```

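Assuming `C10_WARP_SIZE == 32` (so `num_threads()` is 128 on the CUDA path), the grid-size arithmetic in `launch_vectorized_kernel` works out as in the small sketch below; `grid_for` and `elems_per_thread_cuda` are illustrative stand-ins, not the ATen functions.

```cpp
#include <cstdint>

// Illustrative mirror of the CUDA-path work partitioning above.
constexpr int num_threads_cuda = 32 * 4;  // C10_WARP_SIZE * 4, assuming warp size 32

constexpr int elems_per_thread_cuda(int io_size) {
  return io_size == 1 ? 16 : 8;  // 1-byte dtypes: 16 elems/thread, otherwise 8
}

int64_t grid_for(int64_t N, int io_size) {
  const int64_t block_work =
      static_cast<int64_t>(num_threads_cuda) * elems_per_thread_cuda(io_size);
  return (N + block_work - 1) / block_work;  // ceil-divide N by per-block work
}
// Example: a half-precision unary op over N = 1 << 26 elements gives
// block_work = 128 * 8 = 1024 and grid = 65536 blocks.
```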
aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 4 additions & 1 deletion
```diff
@@ -217,8 +217,11 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
     // make sure we don't break assumption that we can't have > 16 elements / thread
     TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
 #else
+    const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
+    vec_size = std::min<int>(optimal_vec_size, vec_size);
+
     // make sure we don't break assumption that we can't have > 4 elements / thread
-    TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
+    TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
 #endif
   }

```

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 5 additions & 1 deletion
```diff
@@ -351,15 +351,19 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-#ifdef USE_ROCM
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
+#ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
   if (type_size == 1 && (address % vec16_alignment == 0)) {
     return 16;
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
+#else
+  if (address % vec8_alignment == 0) {
+    return 8;
+  } else
 #endif
   if (address % vec4_alignment == 0) {
     return 4;
```

aten/src/ATen/native/cuda/jit_utils.cpp

Lines changed: 8 additions & 3 deletions
```diff
@@ -932,7 +932,6 @@ void initializeCudaContext() {
   }
 }

-#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
@@ -952,7 +951,6 @@ int calc_io_size(

   return 0;
 }
-#endif

 int calc_thread_work_size(
     const int nInputs,
@@ -971,7 +969,14 @@ int calc_thread_work_size(
   }
   return io_size;
 #else
-  return JIT_THREAD_WORK_SIZE;
+  auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type);
+  TORCH_INTERNAL_ASSERT(io_size > 0);
+  if (io_size == 1) {
+    return 16;
+  } else {
+    return 8;
+  }
+  return io_size;
 #endif
 }

```

aten/src/ATen/native/cuda/jit_utils.h

Lines changed: 8 additions & 2 deletions
```diff
@@ -60,6 +60,10 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
   if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
     return 8;
   }
+#else
+  if (ip % (8 * default_alignment) == 0) {
+    return 8;
+  }
 #endif
   if (ip % (4 * default_alignment) == 0) {
     return 4;
@@ -88,15 +92,17 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*
 }

 //FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
+#ifdef USE_ROCM
 #define JIT_THREAD_WORK_SIZE 4
+#else
+#define JIT_THREAD_WORK_SIZE 8
+#endif

-#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
     const c10::ScalarType& inputs_type,
     const c10::ScalarType& result_type);
-#endif

 int calc_thread_work_size(
     const int nInputs,
```

aten/src/ATen/native/cuda/thread_constants.h

Lines changed: 4 additions & 1 deletion
```diff
@@ -12,11 +12,14 @@
 constexpr int num_threads() {
   return 256;
 }
+
+constexpr int thread_work_size() { return 4; }
 #else
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
+
+constexpr int thread_work_size() { return 8; }
 #endif

-constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
```

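With these constants, and again assuming `C10_WARP_SIZE == 32`, the per-block element count on the non-ROCm side doubles from 512 to 1024. A quick check of that arithmetic:

```cpp
// Assumes C10_WARP_SIZE == 32; values mirror the non-ROCm branch above.
constexpr int num_threads_cuda = 32 * 4;                    // 128 threads per block
constexpr int thread_work_size_cuda = 8;                    // was 4 before this change
constexpr int block_work_size_cuda = thread_work_size_cuda * num_threads_cuda;
static_assert(block_work_size_cuda == 1024, "each block now covers 1024 elements");
```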
aten/src/ATen/test/cuda_vectorized_test.cu

Lines changed: 7 additions & 7 deletions
```diff
@@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) {
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);

-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);

   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);

-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
```
