|
38 | 38 | from torch.testing._internal.common_utils import ( |
39 | 39 | DeterministicGuard, |
40 | 40 | freeze_rng_state, |
| 41 | + instantiate_parametrized_tests, |
41 | 42 | IS_FBCODE, |
42 | 43 | MI350_ARCH, |
| 44 | + parametrize, |
43 | 45 | skipIfRocmArch, |
44 | 46 | TEST_WITH_ASAN, |
45 | 47 | TEST_WITH_ROCM, |
|
85 | 87 | aten = torch.ops.aten |
86 | 88 |
|
87 | 89 |
|
| 90 | +@instantiate_parametrized_tests |
88 | 91 | class CudaReproTests(TestCase): |
89 | 92 | device = "cuda" |
90 | 93 | common = check_model_cuda |
@@ -2441,6 +2444,60 @@ def forward(self, x): |
2441 | 2444 | f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}", |
2442 | 2445 | ) |
2443 | 2446 |
|
@parametrize(
    "quantiles_shape,quantiles_strides,batch_size",
    [
        ((100, 10), (10, 1), 16),  # Contiguous C-order
        ((100, 10), (1, 100), 16),  # Transposed/F-order
        ((80, 12), (1, 80), 16),  # Transposed different size
        ((50, 20), (1, 50), 16),  # Transposed medium
        ((200, 8), (1, 200), 16),  # Transposed large x small
        ((25, 40), (1, 25), 16),  # Transposed small x large
        ((20, 5, 8), (40, 1, 5), 16),  # 3D case with mixed strides
        ((20, 5, 8), (1, 20, 100), 16),  # 3D case different stride order
    ],
)
def test_searchsorted_stride_permutations(
    self, quantiles_shape, quantiles_strides, batch_size
):
    """Check that compiled ``torch.searchsorted`` matches eager mode.

    A random buffer is viewed with ``torch.as_strided`` using the
    parametrized shape/strides, sorted along dim 0, and stored in a module
    with its dimensions reversed, so the sorted axis becomes the innermost
    dimension that ``searchsorted`` expects. The same module is then run in
    eager mode and through ``torch.compile(fullgraph=True)``, and the two
    results are compared.

    Args:
        quantiles_shape: shape of the boundaries tensor; dim 0 is the
            sorted axis.
        quantiles_strides: strides used to view the flat buffer with
            ``quantiles_shape``.
        batch_size: size of dim 0 of the query tensor; its remaining dims
            are ``quantiles_shape[1:]``.
    """

    def reverse_dims(t: torch.Tensor) -> torch.Tensor:
        # Equivalent to ``t.T`` for 2-D input. ``Tensor.T`` is deprecated for
        # ndim != 2 (the 3-D parametrizations), so reverse dims explicitly.
        return t.permute(*range(t.ndim - 1, -1, -1))

    class Foo(torch.nn.Module):
        def __init__(self, quantiles: torch.Tensor) -> None:
            super().__init__()
            assert quantiles.shape[0] > 0
            # Move the sorted axis (dim 0) to the last position.
            quantiles = reverse_dims(quantiles)
            self.q = torch.nn.Parameter(quantiles, requires_grad=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Reverse the query dims to line up with ``self.q``, then restore
            # the caller's layout on the way out.
            return reverse_dims(torch.searchsorted(self.q, reverse_dims(x)))

    torch.manual_seed(42)

    # Flat buffer holding exactly as many elements as the strided view covers.
    numel = torch.Size(quantiles_shape).numel()
    data = torch.randn(numel, dtype=torch.float32, device="cuda")

    # Create tensor with specified shape and strides
    quantiles = torch.as_strided(
        data, size=quantiles_shape, stride=quantiles_strides
    )

    # searchsorted needs sorted boundaries; sort along the axis that ends up
    # innermost after the dimension reversal in Foo.
    # NOTE(review): presumably torch.sort keeps the input's stride layout for
    # these dense views, so the module still sees the permuted strides - confirm.
    quantiles = torch.sort(quantiles, dim=0)[0]

    x_shape = (batch_size,) + quantiles_shape[1:]
    x = torch.randn(*x_shape, dtype=torch.float32, device="cuda")

    foo = Foo(quantiles)
    foo_compiled = torch.compile(Foo(quantiles), fullgraph=True)

    # Test eager vs compiled
    with torch.no_grad():
        eager = foo(x)
        compiled = foo_compiled(x)

    self.assertEqual(eager, compiled)
| 2500 | + |
2444 | 2501 | def test_identity_load(self): |
2445 | 2502 | device = "cuda" |
2446 | 2503 |
|
|
0 commit comments