From 0ce7003bbbaf83b01f2dfc3cfcbc9e7d817d551e Mon Sep 17 00:00:00 2001 From: Andrei Ivanov Date: Tue, 14 Jan 2025 11:30:37 -0800 Subject: [PATCH] Fixing bug in `get_namespace_and_device`. --- sklearn/utils/_array_api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 65503a0674a70..44a4af6397642 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -619,11 +619,10 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)) skip_remove_kwargs = dict(remove_none=False, remove_types=[]) xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs) - arrays_device = device(*array_list, **skip_remove_kwargs) if is_array_api: - return xp, is_array_api, arrays_device + return xp, is_array_api, device(*array_list, **skip_remove_kwargs) else: - return xp, False, arrays_device + return xp, False, None def _expit(X, xp=None):