diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 8374dc35ff4f0..8d1ef05d9ee65 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -304,6 +304,7 @@ def isdtype(self, dtype, kind): def _check_device_cpu(device): # noqa + device = getattr(device, "type", device) if device not in {"cpu", None}: raise ValueError(f"Unsupported device for NumPy: {device!r}")