Revert "fix randint distribution for large max (#143787)" · pytorch/pytorch@3571476
Revert "fix randint distribution for large max (#143787)"
This reverts commit 8059d56. Reverted #143787 on behalf of https://github.com/wdvr due to failing internal tests, to be fixed first (see the comment on #143787).
1 parent f6801ba · commit 3571476

7 files changed: +11 −45 lines changed

aten/src/ATen/core/DistributionsHelper.h

Lines changed: 5 additions & 1 deletion

```diff
@@ -40,7 +40,11 @@ struct uniform_int_from_to_distribution {
 
   template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
-    if (range_ >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
+    if ((
+        std::is_same_v<T, int64_t> ||
+        std::is_same_v<T, double> ||
+        std::is_same_v<T, float> ||
+        std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
     {
       return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
     } else {
```
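The deleted comment records the rationale that the revert removes: folding a fixed-width random draw into a range with `%` over-represents the low outputs, and the skew grows with the range. A minimal sketch of that bias in plain Python (illustrative only, not PyTorch internals):

```python
# Modulo bias: folding a uniform 32-bit draw into [0, range_) with `%`
# makes every output below 2**32 % range_ reachable from two source values,
# and the rest from only one.
range_ = 3 * 2**30          # 0.75 * 2**32, the n_max used by the removed tests
n_32 = 2**32

doubled = n_32 % range_     # 2**30 outputs are hit by two source values each
print(doubled)              # 1073741824
print((2 / n_32) / (1 / n_32))   # 2.0: low outputs are twice as likely

# Below range_ = 2**25 the relative skew is bounded by roughly range_ / 2**32,
# i.e. under 1% -- the "approx 1% skew" threshold in the deleted line.
print(2**25 / 2**32)        # 0.0078125
```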

aten/src/ATen/native/cuda/DistributionTemplates.h

Lines changed: 5 additions & 1 deletion

```diff
@@ -280,7 +280,11 @@ namespace cuda {
 template<typename RNG>
 void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
   AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
-    if (range >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
+    if ((
+        std::is_same_v<scalar_t, int64_t> ||
+        std::is_same_v<scalar_t, double> ||
+        std::is_same_v<scalar_t, float> ||
+        std::is_same_v<scalar_t, at::BFloat16>) && range >= 1ULL << 32)
     {
       // define lambda to mod with range and add base
       auto random_func = [range, base] __device__ (uint64_t rand) {
```
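The CUDA path mirrors the CPU gate; past the cutoff it consumes a full 64-bit draw, and the lambda shown in the hunk folds it with the range and offsets by the base. A hedged Python rendering of that transform (the name `mod_and_shift` is mine, not the kernel's):

```python
def mod_and_shift(rand_u64: int, range_: int, base: int) -> int:
    # What the __device__ lambda does with each raw draw: reduce it into
    # [0, range_) with `%`, then shift it into [base, base + range_).
    return rand_u64 % range_ + base

# torch.randint(low=-5, high=5, ...) corresponds to range_=10, base=-5
print(mod_and_shift(0xDEADBEEFDEADBEEF, 10, -5))  # a value in [-5, 5)
```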

aten/src/ATen/test/rng_test.h

Lines changed: 1 addition & 3 deletions

```diff
@@ -137,9 +137,7 @@ void test_random_from_to(const at::Device& device) {
         range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
         from_case_covered = true;
       }
-      // this is leaking details of implementation into test
-      // we are starting to use random64() at 2^25 to minimize skew due to %
-      if (range < (1ULL << 25)) {
+      if (range < (1ULL << 32)) {
         exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
       } else {
         exp = static_cast<T>(static_cast<int64_t>((val % range + from)));
```
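The test deliberately mirrors the implementation's branch: below the cutoff the generator reduces only the low 32 bits of the raw draw, above it the full 64-bit value, so the expected value must be computed the same way. A small Python mirror of that expectation (function name is hypothetical):

```python
def expected_value(val: int, range_: int, from_: int, threshold: int = 2**32) -> int:
    # Mirrors the test's branch: narrow ranges reduce only the low 32 bits
    # of the raw draw (static_cast<uint32_t>(val)); wide ranges use all 64.
    if range_ < threshold:
        return (val & 0xFFFFFFFF) % range_ + from_
    return val % range_ + from_
```

The deleted comment ("leaking details of implementation into test") names the cost of this style: whenever the cutoff moves, as it does here from 2^25 back to 2^32, the test has to move with it.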

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 20 deletions

```diff
@@ -8551,26 +8551,6 @@ def fn(x):
         self.assertGreater(c0.max(), 2**40)
         self.assertLess(c0.max(), 2**50)
 
-    def test_randint_distribution(self):
-        @torch.compile(fullgraph=True)
-        def fn(n_argsmax, size):
-            return torch.randint(n_max, (size,), device=self.device)
-
-        def bin(index, max_size):
-            return index // (max_size // n_bins)
-
-        size = 1_000_000
-        n_max = int(0.75 * 2**32)
-        n_bins = 8
-
-        res = fn(n_max, size)
-        bins = bin(res, n_max).float().cpu()
-        hist, _ = bins.histogram(8, range=(0, n_bins))
-        expected_bin = res.shape[0] / 8
-        expected_error = math.sqrt(expected_bin) / expected_bin * 3
-        error = (hist - expected_bin).abs().max() / expected_bin
-        self.assertTrue(error < expected_error)
-
     @config.patch(fallback_random=True)
     def test_like_rands(self):
         def fn(x):
```
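The tolerance in the removed test is plain binomial arithmetic: each of the 8 bins should receive N/8 of the N = 1,000,000 draws, with standard deviation approximately sqrt(N/8) (for p = 1/8, sqrt(Np(1-p)) is close to sqrt(Np)), and the test allowed 3 sigma of relative deviation. Worked out with the test's own constants:

```python
import math

size = 1_000_000
n_bins = 8

expected_bin = size / n_bins               # 125000.0 draws expected per bin
sigma = math.sqrt(expected_bin)            # ~353.6, approximate binomial std
expected_error = sigma / expected_bin * 3  # 3-sigma relative tolerance
print(round(expected_error, 5))            # 0.00849: bins may deviate ~0.85%

# The 2x modulo skew sketched earlier would far exceed this bound, which is
# the regression these deleted tests were written to catch.
```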

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -231,7 +231,6 @@ def run(*ex, **kwargs):
     "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
-    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
     "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
     "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_single_elem_dynamic_shapes": TestFailure(("cpu",)),
```

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -59,7 +59,6 @@
         ("cpu", "cuda", "xpu")
     ),
     "test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
-    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
 }
 
 if TEST_WITH_ROCM:
```

test/test_tensor_creation_ops.py

Lines changed: 0 additions & 18 deletions

```diff
@@ -3495,24 +3495,6 @@ def seed(generator):
         self.assertTrue((res1 < 6).all().item())
         self.assertTrue((res1 >= 0).all().item())
 
-
-    def test_randint_distribution(self, device):
-        size = 1_000_000
-        n_max = int(0.75 * 2 ** 32)
-        n_bins = 8
-
-        def bin(index, max_size):
-            return index // (max_size // n_bins)
-        res = torch.randint(n_max, (size,), device=device)
-        # histogram implemented for float only
-        bins = bin(res, n_max).float().cpu()
-        hist, _ = bins.histogram(8, range=(0, n_bins))
-        expected_bin = res.shape[0] / 8
-        expected_error = math.sqrt(expected_bin) / expected_bin * 3
-        error = (hist - expected_bin).abs().max() / expected_bin
-        self.assertTrue(error < expected_error)
-
-
     @dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
             torch.complex32, torch.complex64, torch.complex128)
     def test_randn(self, device, dtype):
```