8000 use higher threshold · pytorch/pytorch@64642db · GitHub
[go: up one dir, main page]

Skip to content

Commit 64642db

Browse files
committed
use higher threshold
1 parent df3f93e commit 64642db

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

aten/src/ATen/core/DistributionsHelper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct uniform_int_from_to_distribution {
4040

4141
template <typename RNG>
4242
C10_HOST_DEVICE inline T operator()(RNG generator) {
43-
if (range_ >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
43+
if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
4444
{
4545
return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
4646
} else {

aten/src/ATen/native/cuda/DistributionTemplates.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ namespace cuda {
280280
template<typename RNG>
281281
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
282282
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
283-
if (range >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
283+
if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
284284
{
285285
// define lambda to mod with range and add base
286286
auto random_func = [range, base] __device__ (uint64_t rand) {

aten/src/ATen/test/rng_test.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ void test_random_from_to(const at::Device& device) {
138138
from_case_covered = true;
139139
}
140140
// this is leaking details of implementation into test
141-
// we are starting to use random64() at 2^25 to minimize skew due to %
142-
if (range < (1ULL << 25)) {
141+
// we are starting to use random64() at 2^28 to minimize skew due to %
142+
if (range < (1ULL << 28)) {
143143
exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
144144
} else {
145145
exp = static_cast<T>(static_cast<int64_t>((val % range + from)));

0 commit comments

Comments
 (0)
0