8000 Workaround MPS bug in torch.nextafter · scikit-learn/scikit-learn@7a2f907 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7a2f907

Browse files
committed
Workaround MPS bug in torch.nextafter
1 parent 020cbe8 commit 7a2f907

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

sklearn/utils/tests/test_stats.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,14 @@ def test_weighted_percentile_array_api_consistency(
186186

187187
xp = _array_api_for_tests(array_namespace, device)
188188

189+
# Skip test for percentile=0 edge case (#20528) on namespace/device where
190+
# xp.nextafter is broken (e.g. torch with MPS device at the time of
191+
# writing).
192+
zero = xp.zeros(1, device=device)
193+
one = xp.ones(1, device=device)
194+
if percentile == 0 and xp.all(xp.nextafter(zero, one) == zero):
195+
pytest.xfail(f"xp.nextafter is broken on {device}")
196+
189197
rng = np.random.RandomState(global_random_seed)
190198
X_np = data(rng) if callable(data) else data
191199
weights_np = weights(rng) if callable(weights) else weights

0 commit comments

Comments
 (0)
0