fix randint distribution for large max (#143787) · pytorch/pytorch@ab1f627

Commit ab1f627

ngimel authored and pytorchmergebot committed
fix randint distribution for large max (#143787)
Fixes #ISSUE_NUMBER

Similar to #143682: for large maximum values we were sampling integers via `%`, which does not produce a uniform distribution. Here we limit the max skew to approx 1% (random32 is used for max values `<= 2**32 / 128`).

This comes with a significant perf penalty, especially on cuda, but it's a pretty bad bug, so we'll have to figure out what can be done to improve it. `torch.compile` has always produced correct results for this, and its performance is also significantly better than current eager (eager is ~660 GB/s on H100, torch.compile ~1200 GB/s), so we have to figure out why torch.compile is faster. `__launch_bounds__` slightly regress perf, so perhaps we can figure out how to specify them better, but that accounts for only 20-30 GB/s, so the big difference is still unexplained.

Pull Request resolved: #143787
Approved by: https://github.com/eqy
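To make the skew claim concrete, here is a back-of-the-envelope check (my own illustration, not code from this PR): reducing a uniform 32-bit draw with `%` over-represents the first `2**32 % range` outputs by exactly one RNG value each, so the worst-case relative skew is roughly `range / 2**32`.

```python
# Illustrative sketch of the modulo-bias bound; not PyTorch source.
def max_modulo_skew(range_, bits=32):
    n = 1 << bits                        # distinct RNG outputs
    lo = n // range_                     # hits for the less likely outputs
    hi = lo + (1 if n % range_ else 0)   # hits for the over-represented outputs
    return (hi - lo) / lo                # worst-case relative skew

print(f"{max_modulo_skew((1 << 28) - 1):.2%}")  # ~6.25%, just under the new 2**28 cutoff
print(f"{max_modulo_skew((1 << 25) - 1):.2%}")  # ~0.78%, just under 2**32 / 128
```

At or above the cutoff the code switches to a 64-bit draw, where the same bound becomes `range / 2**64` and is vanishingly small for the ranges that previously hit the 32-bit path.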
1 parent 0e1675a commit ab1f627

7 files changed: +45 -11 lines

aten/src/ATen/core/DistributionsHelper.h

Lines changed: 1 addition & 5 deletions
@@ -40,11 +40,7 @@ struct uniform_int_from_to_distribution {
 
   template <typename RNG>
   C10_HOST_DEVICE inline T operator()(RNG generator) {
-    if ((
-        std::is_same_v<T, int64_t> ||
-        std::is_same_v<T, double> ||
-        std::is_same_v<T, float> ||
-        std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
+    if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
     {
       return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
     } else {

aten/src/ATen/native/cuda/DistributionTemplates.h

Lines changed: 1 addition & 5 deletions
@@ -280,11 +280,7 @@ namespace cuda {
 template<typename RNG>
 void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
   AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
-    if ((
-        std::is_same_v<scalar_t, int64_t> ||
-        std::is_same_v<scalar_t, double> ||
-        std::is_same_v<scalar_t, float> ||
-        std::is_same_v<scalar_t, at::BFloat16>) && range >= 1ULL << 32)
+    if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
     {
       // define lambda to mod with range and add base
       auto random_func = [range, base] __device__ (uint64_t rand) {
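Both the CPU and CUDA branches feed the 64-bit draw through the same mod-with-range-and-add-base transform named by the lambda above. A rough Python model of that step (my paraphrase; the real `transformation::uniform_int_from_to` lives in ATen):

```python
def uniform_int_from_to(rand_bits: int, range_: int, base: int) -> int:
    # Model of the mod-with-range-and-add-base step: fold the raw draw
    # into [0, range_) and shift to [base, base + range_).
    return base + (rand_bits % range_)

# e.g. a 64-bit draw mapped into [100, 110)
assert uniform_int_from_to(2**63 + 5, 10, 100) == 103
```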

aten/src/ATen/test/rng_test.h

Lines changed: 3 additions & 1 deletion
@@ -137,7 +137,9 @@ void test_random_from_to(const at::Device& device) {
         range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
         from_case_covered = true;
       }
-      if (range < (1ULL << 32)) {
+      // this is leaking details of implementation into test
+      // we are starting to use random64() at 2^28 to minimize skew due to %
+      if (range < (1ULL << 28)) {
         exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
       } else {
         exp = static_cast<T>(static_cast<int64_t>((val % range + from)));
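The oracle in this test tracks the engine selection: below the 2^28 cutoff the expected value comes from the draw truncated to 32 bits (the random32() path), otherwise from the full 64-bit draw. Restated compactly in Python (illustration only, mirroring the casts in the C++ above):

```python
def expected_value(val: int, range_: int, from_: int) -> int:
    # below the cutoff: static_cast<uint32_t>(val) % range + from
    if range_ < 1 << 28:
        return (val & 0xFFFFFFFF) % range_ + from_
    # at or above it: the full 64-bit draw is used
    return val % range_ + from_
```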

test/inductor/test_torchinductor.py

Lines changed: 20 additions & 0 deletions
@@ -8657,6 +8657,26 @@ def fn(x):
         self.assertGreater(c0.max(), 2**40)
         self.assertLess(c0.max(), 2**50)
 
+    def test_randint_distribution(self):
+        @torch.compile(fullgraph=True)
+        def fn(n_max, size):
+            return torch.randint(n_max, (size,), device=self.device)
+
+        def bin(index, max_size):
+            return index // (max_size // n_bins)
+
+        size = 1_000_000
+        n_max = int(0.75 * 2**32)
+        n_bins = 8
+
+        res = fn(n_max, size)
+        bins = bin(res, n_max).float().cpu()
+        hist, _ = bins.histogram(8, range=(0, n_bins))
+        expected_bin = res.shape[0] / 8
+        expected_error = math.sqrt(expected_bin) / expected_bin * 3
+        error = (hist - expected_bin).abs().max() / expected_bin
+        self.assertTrue(error < expected_error)
+
     @config.patch(fallback_random=True)
     def test_like_rands(self):
         def fn(x):
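The tolerance in this test is a 3-sigma bound: each of the 8 bins collects roughly Binomial(size, 1/8) samples, whose standard deviation is close to sqrt(expected_bin), so a fair generator should keep the relative deviation per bin under 3 * sqrt(expected_bin) / expected_bin. Plugging in the test's numbers (my arithmetic, not part of the PR):

```python
import math

size, n_bins = 1_000_000, 8
expected_bin = size / n_bins                               # 125000 samples per bin
sigma = math.sqrt(size * (1 / n_bins) * (1 - 1 / n_bins))  # exact binomial std, ~331
tolerance = 3 * math.sqrt(expected_bin) / expected_bin     # the test's slightly looser bound
print(f"sigma={sigma:.0f}, tolerance={tolerance:.3%}")     # ~0.849% per bin
```

For n_max = 0.75 * 2**32, the old 32-bit `%` path would have made the lowest third of the range twice as likely as the rest (by my arithmetic), so the check fails by a wide margin without the fix.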

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions
@@ -231,6 +231,7 @@ def run(*ex, **kwargs):
     "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
+    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
     "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
     "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_single_elem_dynamic_shapes": TestFailure(("cpu",)),

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@
         ("cpu", "cuda", "xpu")
     ),
     "test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
+    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
 }
 
 if TEST_WITH_ROCM:

test/test_tensor_creation_ops.py

Lines changed: 18 additions & 0 deletions
@@ -3499,6 +3499,24 @@ def seed(generator):
         self.assertTrue((res1 < 6).all().item())
         self.assertTrue((res1 >= 0).all().item())
 
+
+    def test_randint_distribution(self, device):
+        size = 1_000_000
+        n_max = int(0.75 * 2 ** 32)
+        n_bins = 8
+
+        def bin(index, max_size):
+            return index // (max_size // n_bins)
+        res = torch.randint(n_max, (size,), device=device)
+        # histogram implemented for float only
+        bins = bin(res, n_max).float().cpu()
+        hist, _ = bins.histogram(8, range=(0, n_bins))
+        expected_bin = res.shape[0] / 8
+        expected_error = math.sqrt(expected_bin) / expected_bin * 3
+        error = (hist - expected_bin).abs().max() / expected_bin
+        self.assertTrue(error < expected_error)
+
+
     @dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
             torch.complex32, torch.complex64, torch.complex128)
     def test_randn(self, device, dtype):

0 commit comments