8000 FIX skip array API tests when running with device="mps" without the P… · scikit-learn/scikit-learn@b7d80a0 · GitHub
[go: up one dir, main page]

Skip to content

Commit b7d80a0

Browse files
betatimogrisel
andauthored
FIX skip array API tests when running with device="mps" without the PYTORCH_ENABLE_MPS_FALLBACK env var (#27199)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent a31e108 commit b7d80a0

File tree

1 file changed

+18
-9
lines changed

sklearn/utils/_testing.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,17 +1071,26 @@ def _array_api_for_tests(array_namespace, device, dtype):
10711071
xp = array_api_compat.get_namespace(array_mod.asarray(1))
10721072
if array_namespace == "torch" and device == "cuda" and not xp.has_cuda:
10731073
raise SkipTest("PyTorch test requires cuda, which is not available")
1074-
elif array_namespace == "torch" and device == "mps" and not xp.has_mps:
1075-
if not xp.backends.mps.is_built():
1074+
elif array_namespace == "torch" and device == "mps":
1075+
if os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
1076+
# For now we need PYTORCH_ENABLE_MPS_FALLBACK=1 for all estimators to work
1077+
# when using the MPS device.
10761078
raise SkipTest(
1077-
"MPS is not available because the current PyTorch install was not "
1078-
"built with MPS enabled."
1079-
)
1080-
else:
1081-
raise SkipTest(
1082-
"MPS is not available because the current MacOS version is not 12.3+ "
1083-
"and/or you do not have an MPS-enabled device on this machine."
1079+
"Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not "
1080+
"set."
10841081
)
1082+
if not xp.has_mps:
1083+
if not xp.backends.mps.is_built():
1084+
raise SkipTest(
1085+
"MPS is not available because the current PyTorch install was not "
1086+
"built with MPS enabled."
1087+
)
1088+
else:
1089+
raise SkipTest(
1090+
"MPS is not available because the current MacOS version is not"
1091+
" 12.3+ and/or you do not have an MPS-enabled device on this"
1092+
" machine."
1093+
)
10851094
elif array_namespace in {"cupy", "cupy.array_api"}: # pragma: nocover
10861095
import cupy
10871096

0 commit comments

Comments
 (0)
0