8000 [BE] Fix + parametrize `test_min_max_nan_propagation` (#144250) · pytorch/pytorch@ebeb433 · GitHub
[go: up one dir, main page]

Skip to content

Commit ebeb433

Browse files
malfetpytorchmergebot
authored andcommitted
[BE] Fix + parametrize test_min_max_nan_propagation (#144250)
- `dtype` was not passed as argument to `torch.rand` before - Condition bfloat16 testing on MacOS14+ Pull Request resolved: #144250 Approved by: https://github.com/Skylion007 ghstack dependencies: #144249
1 parent 11a0663 commit ebeb433

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

test/test_mps.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8621,29 +8621,27 @@ def test_min_max(self, dtype):
86218621
z_cpu = x_cpu.min()
86228622
self.assertEqual(z, z_cpu)
86238623

8624+
@parametrize("dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if MACOS_VERSION >= 14.0 else []))
8625+
def test_min_max_nan_propagation(self, dtype):
8626+
cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
8627+
mps_x = cpu_x.detach().clone().to('mps')
86248628

8625-
def test_min_max_nan_propagation(self):
8626-
def helper(dtype):
8627-
cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu")
8628-
mps_x = cpu_x.detach().clone().to('mps')
8629-
8630-
cpu_max = torch.max(cpu_x)
8631-
mps_max = torch.max(mps_x).to('cpu')
8629+
cpu_max = torch.max(cpu_x)
8630+
mps_max = torch.max(mps_x).to('cpu')
86328631

8633-
cpu_amax = torch.amax(cpu_x)
8634-
mps_amax = torch.amax(mps_x).to('cpu')
8632+
cpu_amax = torch.amax(cpu_x)
8633+
mps_amax = torch.amax(mps_x).to('cpu')
86358634

8636-
cpu_min = torch.min(cpu_x)
8637-
mps_min = torch.min(mps_x).to('cpu')
8635+
cpu_min = torch.min(cpu_x)
8636+
mps_min = torch.min(mps_x).to('cpu')
86388637

8639-
cpu_amin = torch.amin(cpu_x)
8640-
mps_amin = torch.amin(mps_x).to('cpu')
8638+
cpu_amin = torch.amin(cpu_x)
8639+
mps_amin = torch.amin(mps_x).to('cpu')
86418640

8642-
self.assertEqual(cpu_max, mps_max)
8643-
self.assertEqual(cpu_amax, mps_amax)
8644-
self.assertEqual(cpu_min, mps_min)
8645-
self.assertEqual(cpu_amin, mps_amin)
8646-
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]]
8641+
self.assertEqual(cpu_max, mps_max)
8642+
self.assertEqual(cpu_amax, mps_amax)
8643+
self.assertEqual(cpu_min, mps_min)
8644+
self.assertEqual(cpu_amin, mps_amin)
86478645

86488646
def test_isin(self):
86498647
def helper(dtype):

0 commit comments

Comments
 (0)
0