[wip] fix searchsorted non dense (#165064) · pytorch/pytorch@f5543e3 · GitHub

Commit f5543e3
eellison authored and pytorchmergebot committed

[wip] fix searchsorted non dense (#165064)

Fix for #163528

Pull Request resolved: #165064
Approved by: https://github.com/benjaminglass1, https://github.com/mlazos

1 parent 5fc2c7a commit f5543e3
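For context, a minimal sketch of the failing pattern, distilled from the test added in this commit (the helper name f and the shapes are illustrative, not copied from the issue): torch.searchsorted compiled over a non-contiguous (here, transposed) boundaries tensor.

    import torch

    def f(q: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # q holds sorted boundaries per row; transposing x keeps the
        # searched dim aligned with q's innermost (sorted) dim.
        return torch.searchsorted(q, x.T).T

    # Sort along dim 0, then transpose: q has shape (10, 100) and
    # strides (1, 10), i.e. it is not contiguous.
    quantiles = torch.sort(torch.randn(100, 10, device="cuda"), dim=0)[0]
    q = quantiles.T
    x = torch.randn(16, 10, device="cuda")

    eager = f(q, x)
    compiled = torch.compile(f, fullgraph=True)(q, x)
    torch.testing.assert_close(eager, compiled)  # could mismatch before this fix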

File tree

2 files changed: +67 -3 lines changed

test/inductor/test_cuda_repro.py (57 additions, 0 deletions)

@@ -38,8 +38,10 @@
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     freeze_rng_state,
+    instantiate_parametrized_tests,
     IS_FBCODE,
     MI350_ARCH,
+    parametrize,
     skipIfRocmArch,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
@@ -85,6 +87,7 @@
 aten = torch.ops.aten


+@instantiate_parametrized_tests
 class CudaReproTests(TestCase):
     device = "cuda"
     common = check_model_cuda
@@ -2441,6 +2444,60 @@ def forward(self, x):
             f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}",
         )

+    @parametrize(
+        "quantiles_shape,quantiles_strides,batch_size",
+        [
+            ((100, 10), (10, 1), 16),  # Contiguous C-order
+            ((100, 10), (1, 100), 16),  # Transposed/F-order
+            ((80, 12), (1, 80), 16),  # Transposed different size
+            ((50, 20), (1, 50), 16),  # Transposed medium
+            ((200, 8), (1, 200), 16),  # Transposed large x small
+            ((25, 40), (1, 25), 16),  # Transposed small x large
+            ((20, 5, 8), (40, 1, 5), 16),  # 3D case with mixed strides
+            ((20, 5, 8), (1, 20, 100), 16),  # 3D case different stride order
+        ],
+    )
+    def test_searchsorted_stride_permutations(
+        self, quantiles_shape, quantiles_strides, batch_size
+    ):
+        class Foo(torch.nn.Module):
+            def __init__(self, quantiles: torch.Tensor) -> None:
+                super().__init__()
+                assert quantiles.shape[0] > 0
+                quantiles = quantiles.T
+                self.q = torch.nn.Parameter(quantiles, requires_grad=False)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.searchsorted(self.q, x.T).T
+
+        torch.manual_seed(42)
+
+        # Create contiguous tensor first
+        numel = 1
+        for dim in quantiles_shape:
+            numel *= dim
+        data = torch.randn(numel, dtype=torch.float32, device="cuda")
+
+        # Create tensor with specified shape and strides
+        quantiles = torch.as_strided(
+            data, size=quantiles_shape, stride=quantiles_strides
+        )
+
+        quantiles = torch.sort(quantiles, dim=0)[0]
+
+        x_shape = (batch_size,) + quantiles_shape[1:]
+        x = torch.randn(*x_shape, dtype=torch.float32, device="cuda")
+
+        foo = Foo(quantiles)
+        foo_compiled = torch.compile(Foo(quantiles), fullgraph=True)
+
+        # Test eager vs compiled
+        with torch.no_grad():
+            eager = foo(x)
+            compiled = foo_compiled(x)
+
+        self.assertEqual(eager, compiled)
+
     def test_identity_load(self):
         device = "cuda"
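One detail worth noting about the setup above: torch.as_strided does no bounds checking, so the test is only safe because every parametrized stride tuple describes a dense permutation of a contiguous layout, meaning the view touches exactly the numel elements of the flat backing buffer. A standalone sanity check of that invariant (illustrative, not part of the commit):

    # The eight (shape, strides) pairs from the parametrization above
    cases = [
        ((100, 10), (10, 1)),
        ((100, 10), (1, 100)),
        ((80, 12), (1, 80)),
        ((50, 20), (1, 50)),
        ((200, 8), (1, 200)),
        ((25, 40), (1, 25)),
        ((20, 5, 8), (40, 1, 5)),
        ((20, 5, 8), (1, 20, 100)),
    ]
    for shape, strides in cases:
        numel = 1
        for d in shape:
            numel *= d
        # Largest flat offset any element of the strided view can touch
        max_offset = sum((s - 1) * st for s, st in zip(shape, strides))
        assert max_offset == numel - 1, (shape, strides)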

torch/_inductor/lowering.py (10 additions, 3 deletions)

@@ -2490,11 +2490,18 @@ def inner_fn(index):


 def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
+    # Calculate the maximum offset for the boundaries tensor
+    # For a strided tensor, this is sum((size[i] - 1) * stride[i]) + stride[-1]
+    # This ensures the mask check in bucketize_binary_search works correctly
+    # for both contiguous and non-contiguous tensors.
+    size = tb.get_size()
+    stride = tb.get_stride()
+    max_offset = sum((s - 1) * st for s, st in zip(size, stride)) + stride[-1]
     return (
         tb.get_name(),
-        tb.get_size()[-1],
-        tb.get_size()[0] * tb.get_stride()[0],
-        tb.get_stride()[-1],
+        size[-1],
+        max_offset,
+        stride[-1],
     )
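To see why the old bound broke, take the transposed (100, 10) case from the new test: after sorting, Foo stores quantiles.T, a (10, 100) boundaries tensor with strides (1, 10). Evaluating both expressions for that layout in plain Python (the numbers are specific to this example):

    size, stride = (10, 100), (1, 10)  # transposed boundaries from the test

    # Old bound: only valid when rows sit back to back in memory
    old_bound = size[0] * stride[0]  # 10

    # New bound: largest reachable element offset plus the last-dim stride
    new_bound = sum((s - 1) * st for s, st in zip(size, stride)) + stride[-1]
    # (10 - 1) * 1 + (100 - 1) * 10 + 10 == 1009

Per the comment in the diff, bucketize_binary_search masks loads against this value, so the old bound of 10 could reject valid boundary elements for this layout, while the new bound of 1009 covers the full strided extent. For a contiguous boundaries tensor the two formulas agree (e.g. strides (100, 1) give 1000 either way), so contiguous behavior is unchanged.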
