fix randint distribution for large max by ngimel · Pull Request #143787 · pytorch/pytorch · GitHub

fix randint distribution for large max #143787


Closed · wants to merge 6 commits

Changes from all commits
6 changes: 1 addition & 5 deletions aten/src/ATen/core/DistributionsHelper.h
@@ -40,11 +40,7 @@ struct uniform_int_from_to_distribution {

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
-   if ((
-     std::is_same_v<T, int64_t> ||
-     std::is_same_v<T, double> ||
-     std::is_same_v<T, float> ||
-     std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
+   if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
    {
      return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
    } else {
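Not part of the diff, but for context: a minimal Python sketch (illustrative only, my own arithmetic) of the modulo bias this threshold is about. Reducing a uniform w-bit word with `% range` makes some residues one count more likely than others, with worst-case relative skew of roughly range / 2^w. At the new 2^28 cutoff a 32-bit word gives at most about 1/16 skew (the "approx 5%" in the comment), while ranges near 2^32 reduced from a 32-bit word can be skewed by up to 100%; drawing from random64() instead makes the skew negligible.

# Minimal sketch (illustrative only, not PyTorch code) of the worst-case relative
# skew from reducing a uniform w-bit word with the % operator.
def modulo_skew(range_, word_bits):
    total = 1 << word_bits            # number of equally likely raw words
    floor_count = total // range_     # every residue is hit at least this many times
    if floor_count == 0:
        return float("inf")           # range exceeds the word size: unusable
    # some residues are hit floor_count + 1 times, so max/min - 1 == 1 / floor_count
    return 1.0 / floor_count

print(modulo_skew(1 << 28, 32))   # 0.0625: the worst case at the new threshold, ~"approx 5%"
print(modulo_skew(3 << 30, 32))   # 1.0: pre-fix behavior for range = 0.75 * 2**32 with a 32-bit word
print(modulo_skew(3 << 30, 64))   # ~1.7e-10: negligible once random64() is used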
6 changes: 1 addition & 5 deletions aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -280,11 +280,7 @@ namespace cuda {
 template<typename RNG>
 void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
   AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
-    if ((
-      std::is_same_v<scalar_t, int64_t> ||
-      std::is_same_v<scalar_t, double> ||
-      std::is_same_v<scalar_t, float> ||
-      std::is_same_v<scalar_t, at::BFloat16>) && range >= 1ULL << 32)
+    if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
     {
       // define lambda to mod with range and add base
       auto random_func = [range, base] __device__ (uint64_t rand) {
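The lambda body itself is cut off above; going only by the "mod with range and add base" comment, a rough Python sketch (illustrative only, not the actual kernel) of the per-element transform on a full 64-bit word:

# Rough sketch (illustrative only) of the transform the comment above describes:
# map a raw 64-bit word into the half-open interval [base, base + range).
def mod_with_range_and_add_base(rand64, range_, base):
    return rand64 % range_ + base

# e.g. torch.randint(-5, 10, ...) corresponds to base = -5, range = 15
print(mod_with_range_and_add_base(0xDEADBEEFDEADBEEF, 15, -5))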
4 changes: 3 additions & 1 deletion aten/src/ATen/test/rng_test.h
@@ -137,7 +137,9 @@ void test_random_from_to(const at::Device& device) {
       range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
       from_case_covered = true;
     }
-    if (range < (1ULL << 32)) {
+    // this is leaking details of implementation into test
+    // we are starting to use random64() at 2^28 to minimize skew due to %
+    if (range < (1ULL << 28)) {
       exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
     } else {
       exp = static_cast<T>(static_cast<int64_t>((val % range + from)));
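As a reading aid (not part of the PR), a Python sketch of the expected-value rule this test now encodes: given a known raw 64-bit engine output val, the reference result depends on whether the implementation consumed 32 or 64 bits, with the switch at 2^28.

# Sketch (illustrative only) mirroring the branches in the diff above.
def expected_randint(val, range_, from_):
    if range_ < (1 << 28):
        # below the threshold a 32-bit word is drawn, so only val's low 32 bits matter
        return (val & 0xFFFFFFFF) % range_ + from_
    # at or above 2**28 a full 64-bit word is drawn
    return val % range_ + from_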
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -8658,6 +8658,26 @@ def fn(x):
        self.assertGreater(c0.max(), 2**40)
        self.assertLess(c0.max(), 2**50)

    def test_randint_distribution(self):
        @torch.compile(fullgraph=True)
        def fn(n_max, size):
            return torch.randint(n_max, (size,), device=self.device)

        def bin(index, max_size):
            return index // (max_size // n_bins)

        size = 1_000_000
        n_max = int(0.75 * 2**32)
        n_bins = 8

        res = fn(n_max, size)
        bins = bin(res, n_max).float().cpu()
        hist, _ = bins.histogram(8, range=(0, n_bins))
        expected_bin = res.shape[0] / 8
        expected_error = math.sqrt(expected_bin) / expected_bin * 3
        error = (hist - expected_bin).abs().max() / expected_bin
        self.assertTrue(error < expected_error)

    @config.patch(fallback_random=True)
    def test_like_rands(self):
        def fn(x):
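A quick worked check of the tolerance used in the test above (my own arithmetic, not from the PR): with 1,000,000 samples in 8 bins, each bin expects 125,000 hits, its binomial/Poisson-style spread is about sqrt(125000) ≈ 354, so the 3-sigma relative tolerance comes out just under 1%, while the pre-fix modulo skew at this n_max would miss by tens of percent.

import math

size, n_bins = 1_000_000, 8
expected_bin = size / n_bins              # 125000.0 expected hits per bin
sigma = math.sqrt(expected_bin)           # ~353.6 per-bin spread
tolerance = 3 * sigma / expected_bin      # ~0.0085, i.e. under 1% relative deviation allowed
print(expected_bin, round(sigma, 1), round(tolerance, 4))

# By contrast, the pre-fix skew at n_max = 0.75 * 2**32 with a 32-bit word doubles the
# probability of every value below 2**30, putting the lowest bins roughly 50% above
# their expected count -- far outside the 3-sigma band, so the test catches the bug.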
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -231,6 +231,7 @@ def run(*ex, **kwargs):
"test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
"test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
"test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_single_elem_dynamic_shapes": TestFailure(("cpu",)),
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_dynamic_shapes.py
@@ -59,6 +59,7 @@
("cpu", "cuda", "xpu")
),
"test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
}

if TEST_WITH_ROCM:
18 changes: 18 additions & 0 deletions test/test_tensor_creation_ops.py
@@ -3495,6 +3495,24 @@ def seed(generator):
        self.assertTrue((res1 < 6).all().item())
        self.assertTrue((res1 >= 0).all().item())

    def test_randint_distribution(self, device):
        size = 1_000_000
        n_max = int(0.75 * 2 ** 32)
        n_bins = 8

        def bin(index, max_size):
            return index // (max_size // n_bins)

        res = torch.randint(n_max, (size,), device=device)
        # histogram implemented for float only
        bins = bin(res, n_max).float().cpu()
        hist, _ = bins.histogram(8, range=(0, n_bins))
        expected_bin = res.shape[0] / 8
        expected_error = math.sqrt(expected_bin) / expected_bin * 3
        error = (hist - expected_bin).abs().max() / expected_bin
        self.assertTrue(error < expected_error)

    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_randn(self, device, dtype):