From 73bd27cce5306e69f0bf0465cd02e1d0b91f5696 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 23 Mar 2023 13:29:21 -0400 Subject: [PATCH 01/42] ENH Adds PyTorch support to LinearDiscriminantAnalysis --- doc/modules/array_api.rst | 30 +- doc/whats_new/v1.3.rst | 7 + maint_tools/vendor_array_api_compat.sh | 13 + setup.cfg | 2 + sklearn/discriminant_analysis.py | 14 +- .../externals/_array_api_compat/__init__.py | 22 + .../externals/_array_api_compat/_internal.py | 43 ++ .../_array_api_compat/common/__init__.py | 1 + .../_array_api_compat/common/_aliases.py | 523 ++++++++++++++ .../_array_api_compat/common/_helpers.py | 229 ++++++ .../_array_api_compat/common/_linalg.py | 146 ++++ .../_array_api_compat/common/_typing.py | 20 + .../_array_api_compat/cupy/__init__.py | 16 + .../_array_api_compat/cupy/_aliases.py | 69 ++ .../_array_api_compat/cupy/_typing.py | 46 ++ .../_array_api_compat/cupy/linalg.py | 41 ++ .../_array_api_compat/numpy/__init__.py | 22 + .../_array_api_compat/numpy/_aliases.py | 69 ++ .../_array_api_compat/numpy/_typing.py | 46 ++ .../_array_api_compat/numpy/linalg.py | 34 + .../_array_api_compat/torch/__init__.py | 22 + .../_array_api_compat/torch/_aliases.py | 666 ++++++++++++++++++ .../_array_api_compat/torch/linalg.py | 27 + sklearn/linear_model/_base.py | 4 +- sklearn/tests/test_discriminant_analysis.py | 57 ++ sklearn/utils/_array_api.py | 165 +++-- sklearn/utils/extmath.py | 4 +- sklearn/utils/multiclass.py | 4 +- sklearn/utils/tests/test_array_api.py | 100 ++- sklearn/utils/validation.py | 23 +- 30 files changed, 2343 insertions(+), 122 deletions(-) create mode 100755 maint_tools/vendor_array_api_compat.sh create mode 100644 sklearn/externals/_array_api_compat/__init__.py create mode 100644 sklearn/externals/_array_api_compat/_internal.py create mode 100644 sklearn/externals/_array_api_compat/common/__init__.py create mode 100644 sklearn/externals/_array_api_compat/common/_aliases.py create mode 100644 sklearn/externals/_array_api_compat/common/_helpers.py create mode 100644 sklearn/externals/_array_api_compat/common/_linalg.py create mode 100644 sklearn/externals/_array_api_compat/common/_typing.py create mode 100644 sklearn/externals/_array_api_compat/cupy/__init__.py create mode 100644 sklearn/externals/_array_api_compat/cupy/_aliases.py create mode 100644 sklearn/externals/_array_api_compat/cupy/_typing.py create mode 100644 sklearn/externals/_array_api_compat/cupy/linalg.py create mode 100644 sklearn/externals/_array_api_compat/numpy/__init__.py create mode 100644 sklearn/externals/_array_api_compat/numpy/_aliases.py create mode 100644 sklearn/externals/_array_api_compat/numpy/_typing.py create mode 100644 sklearn/externals/_array_api_compat/numpy/linalg.py create mode 100644 sklearn/externals/_array_api_compat/torch/__init__.py create mode 100644 sklearn/externals/_array_api_compat/torch/_aliases.py create mode 100644 sklearn/externals/_array_api_compat/torch/linalg.py diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 0d89ec2ef5879..7b4e3453badb1 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -23,8 +23,8 @@ At this stage, this support is **considered experimental** and must be enabled explicitly as explained in the following. .. note:: - Currently, only `cupy.array_api` and `numpy.array_api` are known to work - with scikit-learn's estimators. + Currently, only `cupy.array_api`, `numpy.array_api`, `cupy`, and `PyTorch` + are known to work with scikit-learn's estimators. 
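Support is enabled per call site with `config_context` or globally with `set_config`; a minimal sketch of the global form (assuming only the public `sklearn.set_config` option)::

    >>> import sklearn
    >>> sklearn.set_config(array_api_dispatch=True)
    >>> # ... fit estimators on array API inputs ...
    >>> sklearn.set_config(array_api_dispatch=False)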
Example usage ============= @@ -36,11 +36,11 @@ Here is an example code snippet to demonstrate how to use `CuPy >>> from sklearn.datasets import make_classification >>> from sklearn import config_context >>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis - >>> import cupy.array_api as xp + >>> import cupy >>> X_np, y_np = make_classification(random_state=0) - >>> X_cu = xp.asarray(X_np) - >>> y_cu = xp.asarray(y_np) + >>> X_cu = cupy.asarray(X_np) + >>> y_cu = cupy.asarray(y_np) >>> X_cu.device @@ -57,12 +57,30 @@ GPU. We provide an experimental `_estimator_with_converted_arrays` utility that transfers an estimator's attributes from Array API to an ndarray:: >>> from sklearn.utils._array_api import _estimator_with_converted_arrays - >>> cupy_to_ndarray = lambda array : array._array.get() + >>> cupy_to_ndarray = lambda array : array.get() >>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray) >>> X_trans = lda_np.transform(X_np) >>> type(X_trans) +PyTorch Support +--------------- + +PyTorch Tensors are supported by setting `array_api_dispatch=True` and passing in +the tensors directly:: + + >>> import torch + >>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32) + >>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32) + + >>> with config_context(array_api_dispatch=True): + ... lda = LinearDiscriminantAnalysis() + ... X_trans = lda.fit_transform(X_torch, y_torch) + >>> type(X_trans) + + >>> X_trans.device.type + 'cuda' + .. _array_api_estimators: Estimators with support for `Array API`-compatible inputs diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index c50672a712f93..ce40c38aff299 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -199,6 +199,13 @@ Changelog :class:`decomposition.MiniBatchNMF` which can produce different results than previous versions. :pr:`25438` by :user:`Yotam Avidar-Constantini `. + :mod:`sklearn.discriminant_analysis` .................................... + +- |Enhancement| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now + supports `PyTorch `__. See + :ref:`array_api` for more details. :pr:`xxxxx` by `Thomas Fan`_. + :mod:`sklearn.ensemble` .......................
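The shell script in the next diff pins the vendored copy of `array-api-compat` to a fixed upstream commit. A rough sketch of what the vendored package then exposes (illustrative only; the real integration point is `sklearn/utils/_array_api.py`)::

    >>> from sklearn.externals import _array_api_compat
    >>> import numpy as np
    >>> xp = _array_api_compat.array_namespace(np.ones(3))
    >>> xp.__name__
    'sklearn.externals._array_api_compat.numpy'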
diff --git a/maint_tools/vendor_array_api_compat.sh b/maint_tools/vendor_array_api_compat.sh new file mode 100755 index 0000000000000..7ce2b7995b72d --- /dev/null +++ b/maint_tools/vendor_array_api_compat.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Vendors https://github.com/data-apis/array-api-compat/ into sklearn/externals + +ARRAY_API_COMPAT_SHA="b32a5b32892f5f4b5052ef54a04b8ed51936b008" +URL="https://github.com/data-apis/array-api-compat/archive/$ARRAY_API_COMPAT_SHA.tar.gz" + +rm -rf sklearn/externals/_array_api_compat + +curl -s -L $URL | + tar xvz --strip-components=1 -C sklearn/externals array-api-compat-$ARRAY_API_COMPAT_SHA/array_api_compat + +mv sklearn/externals/array_api_compat sklearn/externals/_array_api_compat diff --git a/setup.cfg b/setup.cfg index 3ed576cedf92f..dc18059dca8a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -76,6 +76,8 @@ per-file-ignores = [mypy] ignore_missing_imports = True allow_redefinition = True +exclude = + sklearn/externals [check-manifest] # ignore files missing in VCS diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index 0017c218e2fe0..52a4b495f8c23 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -21,7 +21,7 @@ from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance from .utils.multiclass import unique_labels from .utils.validation import check_is_fitted -from .utils._array_api import get_namespace, _expit +from .utils._array_api import get_namespace, _expit, device, size from .utils.multiclass import check_classification_targets from .utils.extmath import softmax from .utils._param_validation import StrOptions, Interval, HasMethods @@ -109,7 +109,7 @@ def _class_means(X, y): """ xp, is_array_api = get_namespace(X) classes, y = xp.unique_inverse(y) - means = xp.zeros(shape=(classes.shape[0], X.shape[1])) + means = xp.zeros((classes.shape[0], X.shape[1]), device=device(X), dtype=X.dtype) if is_array_api: for i in range(classes.shape[0]): @@ -586,9 +586,9 @@ def fit(self, X, y): if self.priors is None: # estimate priors from sample _, cnts = xp.unique_counts(y) # non-negative ints - self.priors_ = xp.astype(cnts, xp.float64) / float(y.shape[0]) + self.priors_ = xp.astype(cnts, X.dtype) / float(y.shape[0]) else: - self.priors_ = xp.asarray(self.priors) + self.priors_ = xp.asarray(self.priors, dtype=X.dtype) if xp.any(self.priors_ < 0): raise ValueError("priors must be non-negative") @@ -634,13 +634,13 @@ def fit(self, X, y): shrinkage=self.shrinkage, covariance_estimator=self.covariance_estimator, ) - if self.classes_.size == 2: # treat binary case as a special case + if size(self.classes_) == 2: # treat binary case as a special case coef_ = xp.asarray(self.coef_[1, :] - self.coef_[0, :], dtype=X.dtype) self.coef_ = xp.reshape(coef_, (1, -1)) intercept_ = xp.asarray( self.intercept_[1] - self.intercept_[0], dtype=X.dtype ) - self.intercept_ = xp.reshape(intercept_, 1) + self.intercept_ = xp.reshape(intercept_, (1,)) self._n_features_out = self._max_components return self @@ -690,7 +690,7 @@ def predict_proba(self, X): check_is_fitted(self) xp, is_array_api = get_namespace(X) decision = self.decision_function(X) - if self.classes_.size == 2: + if size(self.classes_) == 2: proba = _expit(decision) return xp.stack([1 - proba, proba], axis=1) else: diff --git a/sklearn/externals/_array_api_compat/__init__.py b/sklearn/externals/_array_api_compat/__init__.py new file mode 100644 index 0000000000000..c92d3d89e3c63 --- /dev/null +++ 
b/sklearn/externals/_array_api_compat/__init__.py @@ -0,0 +1,22 @@ +""" +NumPy Array API compatibility library + +This is a small wrapper around NumPy and CuPy that is compatible with the +Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 +https://numpy.org/neps/nep-0047-array-api-standard.html. + +Unlike numpy.array_api, this is not a strict minimal implementation of the +Array API, but rather just an extension of the main NumPy namespace with +changes needed to be compliant with the Array API. See +https://numpy.org/doc/stable/reference/array_api.html for a full list of +changes. In particular, unlike numpy.array_api, this package does not use a +separate Array object, but rather just uses numpy.ndarray directly. + +Library authors using the Array API may wish to test against numpy.array_api +to ensure they are not using functionality outside of the standard, but prefer +this implementation for the default when working with NumPy arrays. + +""" +__version__ = '1.1.1' + +from .common import * diff --git a/sklearn/externals/_array_api_compat/_internal.py b/sklearn/externals/_array_api_compat/_internal.py new file mode 100644 index 0000000000000..553c03561b45e --- /dev/null +++ b/sklearn/externals/_array_api_compat/_internal.py @@ -0,0 +1,43 @@ +""" +Internal helpers +""" + +from functools import wraps +from inspect import signature + +def get_xp(xp): + """ + Decorator to automatically replace xp with the corresponding array module. + + Use like + + import numpy as np + + @get_xp(np) + def func(x, /, xp, kwarg=None): + return xp.func(x, kwarg=kwarg) + + Note that xp must be a keyword argument and come after all non-keyword + arguments. + + """ + def inner(f): + @wraps(f) + def wrapped_f(*args, **kwargs): + return f(*args, xp=xp, **kwargs) + + sig = signature(f) + new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) + + if wrapped_f.__doc__ is None: + wrapped_f.__doc__ = f"""\ +Array API compatibility wrapper for {f.__name__}. + +See the corresponding documentation in NumPy/CuPy and/or the array API +specification for more details. + +""" + wrapped_f.__signature__ = new_sig + return wrapped_f + + return inner diff --git a/sklearn/externals/_array_api_compat/common/__init__.py b/sklearn/externals/_array_api_compat/common/__init__.py new file mode 100644 index 0000000000000..ce3f44dd486cb --- /dev/null +++ b/sklearn/externals/_array_api_compat/common/__init__.py @@ -0,0 +1 @@ +from ._helpers import * diff --git a/sklearn/externals/_array_api_compat/common/_aliases.py b/sklearn/externals/_array_api_compat/common/_aliases.py new file mode 100644 index 0000000000000..87f0d766db03d --- /dev/null +++ b/sklearn/externals/_array_api_compat/common/_aliases.py @@ -0,0 +1,523 @@ +""" +These are functions that are just aliases of existing functions in NumPy. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Sequence, Tuple, Union, List + from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol + +from typing import NamedTuple +from types import ModuleType +import inspect + +from ._helpers import _check_device, _is_numpy_array, array_namespace + +# These functions are modified from the NumPy versions. 
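# A minimal sketch (hypothetical call site) of how the xp-parameterized
# functions below are consumed: the per-backend modules bind xp with the
# get_xp decorator from _internal.py, which hides xp from the public
# signature so the wrapper matches the array API:
#
#     import numpy as np
#     from .._internal import get_xp
#     zeros = get_xp(np)(zeros)   # xp=np is now supplied implicitly
#     zeros((2, 3))               # callers never pass xp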
+ +def arange( + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + xp, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs +) -> ndarray: + _check_device(xp, device) + return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + +def empty( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs +) -> ndarray: + _check_device(xp, device) + return xp.empty(shape, dtype=dtype, **kwargs) + +def empty_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs +) -> ndarray: + _check_device(xp, device) + return xp.empty_like(x, dtype=dtype, **kwargs) + +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + xp, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[int, float], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.full(shape, fill_value, dtype=dtype, **kwargs) + +def full_like( + x: ndarray, + /, + fill_value: Union[int, float], + *, + xp, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + xp, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + +def ones( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.ones(shape, dtype=dtype, **kwargs) + +def ones_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.ones_like(x, dtype=dtype, **kwargs) + +def zeros( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.zeros(shape, dtype=dtype, **kwargs) + +def zeros_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, +) -> ndarray: + _check_device(xp, device) + return xp.zeros_like(x, dtype=dtype, **kwargs) + +# np.unique() is split into four functions in the array API: +# unique_all, unique_counts, unique_inverse, and unique_values (this is done +# to remove polymorphic return types). + +# The functions here return namedtuples (np.unique() returns a normal +# tuple). +class UniqueAllResult(NamedTuple): + values: ndarray + indices: ndarray + inverse_indices: ndarray + counts: ndarray + + +class UniqueCountsResult(NamedTuple): + values: ndarray + counts: ndarray + + +class UniqueInverseResult(NamedTuple): + values: ndarray + inverse_indices: ndarray + + +def _unique_kwargs(xp): + # Older versions of NumPy and CuPy do not have equal_nan. Rather than + # trying to parse version numbers, just check if equal_nan is in the + # signature. 
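# Worked example of the four-way split described above (illustrative,
# with NumPy as xp):
#
#     x = np.asarray([2, 1, 2])
#     unique_values(x, np)    # -> array([1, 2])
#     unique_counts(x, np)    # -> values array([1, 2]), counts array([1, 2])
#     unique_inverse(x, np)   # -> values array([1, 2]) and inverse_indices
#                             #    array([1, 0, 1]), reshaped to x's shape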
+ s = inspect.signature(xp.unique) + if 'equal_nan' in s.parameters: + return {'equal_nan': False} + return {} + +def unique_all(x: ndarray, /, xp) -> UniqueAllResult: + kwargs = _unique_kwargs(xp) + values, indices, inverse_indices, counts = xp.unique( + x, + return_counts=True, + return_index=True, + return_inverse=True, + **kwargs, + ) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueAllResult( + values, + indices, + inverse_indices, + counts, + ) + + +def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: + kwargs = _unique_kwargs(xp) + res = xp.unique( + x, + return_counts=True, + return_index=False, + return_inverse=False, + **kwargs + ) + + return UniqueCountsResult(*res) + + +def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: + kwargs = _unique_kwargs(xp) + values, inverse_indices = xp.unique( + x, + return_counts=False, + return_index=False, + return_inverse=True, + **kwargs, + ) + # xp.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueInverseResult(values, inverse_indices) + + +def unique_values(x: ndarray, /, xp) -> ndarray: + kwargs = _unique_kwargs(xp) + return xp.unique( + x, + return_counts=False, + return_index=False, + return_inverse=False, + **kwargs, + ) + +def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: + if not copy and dtype == x.dtype: + return x + return x.astype(dtype=dtype, copy=copy) + +# These functions have different keyword argument names + +def std( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, # correction instead of ddof + keepdims: bool = False, + **kwargs, +) -> ndarray: + return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + +def var( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, # correction instead of ddof + keepdims: bool = False, + **kwargs, +) -> ndarray: + return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + +# Unlike transpose(), the axes argument to permute_dims() is required. +def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: + return xp.transpose(x, axes) + +# Creation functions add the device keyword (which does nothing for NumPy) + +# asarray also adds the copy keyword +def _asarray( + obj: Union[ + ndarray, + bool, + int, + float, + NestedSequence[bool | int | float], + SupportsBufferProtocol, + ], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: "Optional[Union[bool, np._CopyMode]]" = None, + namespace = None, + **kwargs, +) -> ndarray: + """ + Array API compatibility wrapper for asarray(). + + See the corresponding documentation in NumPy/CuPy and/or the array API + specification for more details. + + """ + if namespace is None: + try: + xp = array_namespace(obj, _use_compat=False) + except ValueError: + # TODO: What about lists of arrays? 
+ raise ValueError("A namespace must be specified for asarray() with non-array input") + elif isinstance(namespace, ModuleType): + xp = namespace + elif namespace == 'numpy': + import numpy as xp + elif namespace == 'cupy': + import cupy as xp + else: + raise ValueError("Unrecognized namespace argument to asarray()") + + _check_device(xp, device) + if _is_numpy_array(obj): + import numpy as np + if hasattr(np, '_CopyMode'): + # Not present in older NumPys + COPY_FALSE = (False, np._CopyMode.IF_NEEDED) + COPY_TRUE = (True, np._CopyMode.ALWAYS) + else: + COPY_FALSE = (False,) + COPY_TRUE = (True,) + else: + COPY_FALSE = (False,) + COPY_TRUE = (True,) + if copy in COPY_FALSE: + # copy=False is not yet implemented in xp.asarray + raise NotImplementedError("copy=False is not yet implemented") + if isinstance(obj, xp.ndarray): + if dtype is not None and obj.dtype != dtype: + copy = True + if copy in COPY_TRUE: + return xp.array(obj, copy=True, dtype=dtype) + return obj + + return xp.asarray(obj, dtype=dtype, **kwargs) + +# xp.reshape calls the keyword argument 'newshape' instead of 'shape' +def reshape(x: ndarray, + /, + shape: Tuple[int, ...], + xp, copy: Optional[bool] = None, + **kwargs) -> ndarray: + if copy is True: + x = x.copy() + elif copy is False: + x.shape = shape + return x + return xp.reshape(x, shape, **kwargs) + +# The descending keyword is new in sort and argsort, and 'kind' replaced with +# 'stable' +def argsort( + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + **kwargs, +) -> ndarray: + # Note: this keyword argument is different, and the default is different. + # We set it in kwargs like this because numpy.sort uses kind='quicksort' + # as the default whereas cupy.sort uses kind=None. + if stable: + kwargs['kind'] = "stable" + if not descending: + res = xp.argsort(x, axis=axis, **kwargs) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of xp.argsort(x, ...) would not + # respect the relative order like it would in native descending sorts. + res = xp.flip( + xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + res = max_i - res + return res + +def sort( + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + **kwargs, +) -> ndarray: + # Note: this keyword argument is different, and the default is different. + # We set it in kwargs like this because numpy.sort uses kind='quicksort' + # as the default whereas cupy.sort uses kind=None. 
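# Worked example for the descending branch of argsort above
# (illustrative): for x = [1, 2, 2, 3], a stable descending argsort must
# keep the tied 2s in their original order, giving [3, 1, 2, 0]; naively
# flipping the ascending result would yield [3, 2, 1, 0] instead.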
+ if stable: + kwargs['kind'] = "stable" + res = xp.sort(x, axis=axis, **kwargs) + if descending: + res = xp.flip(res, axis=axis) + return res + +# sum() and prod() should always upcast when dtype=None +def sum( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs, +) -> ndarray: + # `xp.sum` already upcasts integers, but not floats + if dtype is None and x.dtype == xp.float32: + dtype = xp.float64 + return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) + +def prod( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs, +) -> ndarray: + if dtype is None and x.dtype == xp.float32: + dtype = xp.float64 + return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs) + +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): + return x + return xp.ceil(x, **kwargs) + +def floor(x: ndarray, /, xp, **kwargs) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): + return x + return xp.floor(x, **kwargs) + +def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): + return x + return xp.trunc(x, **kwargs) + +# linear algebra functions + +def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + return xp.matmul(x1, x2, **kwargs) + +# Unlike transpose, matrix_transpose only transposes the last two axes. +def matrix_transpose(x: ndarray, /, xp) -> ndarray: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return xp.swapaxes(x, -1, -2) + +def tensordot(x1: ndarray, + x2: ndarray, + /, + xp, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> ndarray: + return xp.tensordot(x1, x2, axes=axes, **kwargs) + +def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + if hasattr(xp, 'broadcast_tensors'): + _broadcast = xp.broadcast_tensors + else: + _broadcast = xp.broadcast_arrays + + x1_, x2_ = _broadcast(x1, x2) + x1_ = xp.moveaxis(x1_, axis, -1) + x2_ = xp.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + +# isdtype is a new function in the 2022.12 array API specification. + +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, + *, _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + Note that outside of this function, this compat library does not yet fully + support complex numbers. 
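    Illustrative checks, assuming ``xp`` is NumPy:

        isdtype(xp.float32, 'real floating', xp)       # True
        isdtype(xp.int64, ('bool', 'integral'), xp)    # True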
+ + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details. + """ + if isinstance(kind, tuple) and _tuple: + return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype == xp.bool_ + elif kind == 'signed integer': + return xp.issubdtype(dtype, xp.signedinteger) + elif kind == 'unsigned integer': + return xp.issubdtype(dtype, xp.unsignedinteger) + elif kind == 'integral': + return xp.issubdtype(dtype, xp.integer) + elif kind == 'real floating': + return xp.issubdtype(dtype, xp.floating) + elif kind == 'complex floating': + return xp.issubdtype(dtype, xp.complexfloating) + elif kind == 'numeric': + return xp.issubdtype(dtype, xp.number) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + # This will allow things that aren't required by the spec, like + # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be + # more strict here to match the type annotation? Note that the + # numpy.array_api implementation will be very strict. + return dtype == kind + +__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', + 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', + 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', + 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/sklearn/externals/_array_api_compat/common/_helpers.py b/sklearn/externals/_array_api_compat/common/_helpers.py new file mode 100644 index 0000000000000..e6adc948522bd --- /dev/null +++ b/sklearn/externals/_array_api_compat/common/_helpers.py @@ -0,0 +1,229 @@ +""" +Various helper functions which are not part of the spec. + +Functions which start with an underscore are for internal use only but helpers +that are in __all__ are intended as additional helper functions for use by end +users of the compat library. +""" +from __future__ import annotations + +import sys +import math + +def _is_numpy_array(x): + # Avoid importing NumPy if it isn't already + if 'numpy' not in sys.modules: + return False + + import numpy as np + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, (np.ndarray, np.generic)) + +def _is_cupy_array(x): + # Avoid importing CuPy if it isn't already + if 'cupy' not in sys.modules: + return False + + import cupy as cp + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, (cp.ndarray, cp.generic)) + +def _is_torch_array(x): + # Avoid importing torch if it isn't already + if 'torch' not in sys.modules: + return False + + import torch + + # TODO: Should we reject Tensor subclasses? + return isinstance(x, torch.Tensor) + +def is_array_api_obj(x): + """ + Check if x is an array API compatible array object. + """ + return _is_numpy_array(x) \ + or _is_cupy_array(x) \ + or _is_torch_array(x) \ + or hasattr(x, '__array_namespace__') + +def _check_api_version(api_version): + if api_version is not None and api_version != '2021.12': + raise ValueError("Only the 2021.12 version of the array API specification is currently supported") + +def array_namespace(*xs, api_version=None, _use_compat=True): + """ + Get the array API compatible namespace for the arrays `xs`. + + `xs` should contain one or more arrays.
+ + Typical usage is + + def your_function(x, y): + xp = array_api_compat.array_namespace(x, y) + # Now use xp as the array library namespace + return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) + + api_version should be the newest version of the spec that you need support + for (currently the compat library wrapped APIs only support v2021.12). + """ + namespaces = set() + for x in xs: + if isinstance(x, (tuple, list)): + namespaces.add(array_namespace(*x, _use_compat=_use_compat)) + elif hasattr(x, '__array_namespace__'): + namespaces.add(x.__array_namespace__(api_version=api_version)) + elif _is_numpy_array(x): + _check_api_version(api_version) + if _use_compat: + from .. import numpy as numpy_namespace + namespaces.add(numpy_namespace) + else: + import numpy as np + namespaces.add(np) + elif _is_cupy_array(x): + _check_api_version(api_version) + if _use_compat: + from .. import cupy as cupy_namespace + namespaces.add(cupy_namespace) + else: + import cupy as cp + namespaces.add(cp) + elif _is_torch_array(x): + _check_api_version(api_version) + if _use_compat: + from .. import torch as torch_namespace + namespaces.add(torch_namespace) + else: + import torch + namespaces.add(torch) + else: + # TODO: Support Python scalars? + raise TypeError("The input is not a supported array type") + + if not namespaces: + raise TypeError("Unrecognized array input") + + if len(namespaces) != 1: + raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") + + xp, = namespaces + + return xp + +# backwards compatibility alias +get_namespace = array_namespace + +def _check_device(xp, device): + if xp == sys.modules.get('numpy'): + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device for NumPy: {device!r}") + +# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray +# or cupy.ndarray. They are not included in array objects of this library +# because this library just reuses the respective ndarray classes without +# wrapping or subclassing them. These helper functions can be used instead of +# the wrapper functions for libraries that need to support both NumPy/CuPy and +# other libraries that use devices. +def device(x: "Array", /) -> "Device": + """ + Hardware device the array data resides on. + + Parameters + ---------- + x: array + array instance from NumPy or an array API compatible library. + + Returns + ------- + out: device + a ``device`` object (see the "Device Support" section of the array API specification).
+ """ + if _is_numpy_array(x): + return "cpu" + return x.device + +# Based on cupy.array_api.Array.to_device +def _cupy_to_device(x, device, /, stream=None): + import cupy as cp + from cupy.cuda import Device as _Device + from cupy.cuda import stream as stream_module + from cupy_backends.cuda.api import runtime + + if device == x.device: + return x + elif not isinstance(device, _Device): + raise ValueError(f"Unsupported device {device!r}") + else: + # see cupy/cupy#5985 for the reason how we handle device/stream here + prev_device = runtime.getDevice() + prev_stream: stream_module.Stream = None + if stream is not None: + prev_stream = stream_module.get_current_stream() + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = cp.cuda.ExternalStream(stream) + elif isinstance(stream, cp.cuda.Stream): + pass + else: + raise ValueError('the input stream is not recognized') + stream.use() + try: + runtime.setDevice(device.id) + arr = x.copy() + finally: + runtime.setDevice(prev_device) + if stream is not None: + prev_stream.use() + return arr + +def _torch_to_device(x, device, /, stream=None): + if stream is not None: + raise NotImplementedError + return x.to(device) + +def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": + """ + Copy the array from the device on which it currently resides to the specified ``device``. + + Parameters + ---------- + x: array + array instance from NumPy or an array API compatible library. + device: device + a ``device`` object (see the "Device Support" section of the array API specification). + stream: Optional[Union[int, Any]] + stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. + + Returns + ------- + out: array + an array with the same data and data type as ``x`` and located on the specified ``device``. + + .. note:: + If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. 
+ """ + if _is_numpy_array(x): + if stream is not None: + raise ValueError("The stream argument to to_device() is not supported") + if device == 'cpu': + return x + raise ValueError(f"Unsupported device {device!r}") + elif _is_cupy_array(x): + # cupy does not yet have to_device + return _cupy_to_device(x, device, stream=stream) + elif _is_torch_array(x): + return _torch_to_device(x, device, stream=stream) + return x.to_device(device, stream=stream) + +def size(x): + """ + Return the total number of elements of x + """ + if None in x.shape: + return None + return math.prod(x.shape) + +__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] diff --git a/sklearn/externals/_array_api_compat/common/_linalg.py b/sklearn/externals/_array_api_compat/common/_linalg.py new file mode 100644 index 0000000000000..07daefd9cfd99 --- /dev/null +++ b/sklearn/externals/_array_api_compat/common/_linalg.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple +if TYPE_CHECKING: + from typing import Literal, Optional, Sequence, Tuple, Union + from ._typing import ndarray + +from numpy.core.numeric import normalize_axis_tuple + +from ._aliases import matmul, matrix_transpose, tensordot, vecdot +from .._internal import get_xp + +# These are in the main NumPy namespace but not in numpy.linalg +def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: + return xp.cross(x1, x2, axis=axis, **kwargs) + +def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + return xp.outer(x1, x2, **kwargs) + +class EighResult(NamedTuple): + eigenvalues: ndarray + eigenvectors: ndarray + +class QRResult(NamedTuple): + Q: ndarray + R: ndarray + +class SlogdetResult(NamedTuple): + sign: ndarray + logabsdet: ndarray + +class SVDResult(NamedTuple): + U: ndarray + S: ndarray + Vh: ndarray + +# These functions are the same as their NumPy counterparts except they return +# a namedtuple. +def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: + return EighResult(*xp.linalg.eigh(x, **kwargs)) + +def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', + **kwargs) -> QRResult: + return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) + +def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: + return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) + +def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: + return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) + +# These functions have additional keyword arguments + +# The upper keyword argument is new from NumPy +def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: + L = xp.linalg.cholesky(x, **kwargs) + if upper: + return get_xp(xp)(matrix_transpose)(L) + return L + +# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. +# Note that it has a different semantic meaning from tol and rcond. +def matrix_rank(x: ndarray, + /, + xp, + *, + rtol: Optional[Union[float, ndarray]] = None, + **kwargs) -> ndarray: + # this is different from xp.linalg.matrix_rank, which supports 1 + # dimensional arrays. + if x.ndim < 2: + raise xp.linalg.LinAlgError("1-dimensional array given. 
Array must be at least two-dimensional") + S = xp.linalg.svd(x, compute_uv=False, **kwargs) + if rtol is None: + tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps + else: + # this is different from xp.linalg.matrix_rank, which does not + # multiply the tolerance by the largest singular value. + tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] + return xp.count_nonzero(S > tol, axis=-1) + +def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: + # this is different from xp.linalg.pinv, which does not multiply the + # default tolerance by max(M, N). + if rtol is None: + rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps + return xp.linalg.pinv(x, rcond=rtol, **kwargs) + +# These functions are new in the array API spec + +def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: + return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) + +# svdvals is not in NumPy (but it is in SciPy). It is equivalent to +# xp.linalg.svd(compute_uv=False). +def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: + return xp.linalg.svd(x, compute_uv=False) + +def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: + # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or + # when axis=None and the input is 2-D, so to force a vector norm, we make + # it so the input is 1-D (for axis=None), or reshape so that norm is done + # on a single dimension. Use a separate name (_x) so that x keeps its + # original shape for the keepdims handling below. + if axis is None: + # Note: xp.linalg.norm() doesn't handle 0-D arrays + _x = x.ravel() + _axis = 0 + elif isinstance(axis, tuple): + # Note: The axis argument supports any number of axes, whereas + # xp.linalg.norm() only supports a single axis for vector norm. + normalized_axis = normalize_axis_tuple(axis, x.ndim) + rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) + newshape = axis + rest + _x = xp.transpose(x, newshape).reshape( + (xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) + _axis = 0 + else: + _x = x + _axis = axis + + res = xp.linalg.norm(_x, axis=_axis, ord=ord) + + if keepdims: + # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks + # above to avoid matrix norm logic.
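# Worked shape example (illustrative): for x with shape (2, 3, 4) and
# axis=(0, 2), the tuple branch above turns _x into shape (8, 3) and the
# norm reduces axis 0, so res has shape (3,); the code below then
# re-inserts the reduced axes, reshaping res to (1, 3, 1).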
+ shape = list(x.shape) + _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in _axis: + shape[i] = 1 + res = xp.reshape(res, tuple(shape)) + + return res + +# xp.diagonal and xp.trace operate on the first two axes whereas these +# operates on the last two + +def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: + return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) + +def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: + return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs)) + +__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', + 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', + 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', + 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', + 'trace'] diff --git a/sklearn/externals/_array_api_compat/common/_typing.py b/sklearn/externals/_array_api_compat/common/_typing.py new file mode 100644 index 0000000000000..3f17806094baa --- /dev/null +++ b/sklearn/externals/_array_api_compat/common/_typing.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +__all__ = [ + "NestedSequence", + "SupportsBufferProtocol", +] + +from typing import ( + Any, + TypeVar, + Protocol, +) + +_T_co = TypeVar("_T_co", covariant=True) + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... + +SupportsBufferProtocol = Any diff --git a/sklearn/externals/_array_api_compat/cupy/__init__.py b/sklearn/externals/_array_api_compat/cupy/__init__.py new file mode 100644 index 0000000000000..10c31bc6aad3a --- /dev/null +++ b/sklearn/externals/_array_api_compat/cupy/__init__.py @@ -0,0 +1,16 @@ +from cupy import * + +# from cupy import * doesn't overwrite these builtin names +from cupy import abs, max, min, round + +# These imports may overwrite names from the import * above. 
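# Net effect (illustrative sketch, using the vendored import path): the
# module re-exports cupy's namespace with the array API names layered on
# top:
#
#     from sklearn.externals._array_api_compat import cupy as xp
#     xp.concat           # alias of cupy.concatenate from ._aliases
#     xp.linalg.svdvals   # added by the linalg submodule imported below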
+from ._aliases import * + +# See the comment in the numpy __init__.py +__import__(__package__ + '.linalg') + +from .linalg import matrix_transpose, vecdot + +from ..common._helpers import * + +__array_api_version__ = '2021.12' diff --git a/sklearn/externals/_array_api_compat/cupy/_aliases.py b/sklearn/externals/_array_api_compat/cupy/_aliases.py new file mode 100644 index 0000000000000..b43c371f34f43 --- /dev/null +++ b/sklearn/externals/_array_api_compat/cupy/_aliases.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from functools import partial + +from ..common import _aliases + +from .._internal import get_xp + +asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') +asarray.__doc__ = _aliases._asarray.__doc__ +del partial + +import cupy as cp +bool = cp.bool_ + +# Basic renames +acos = cp.arccos +acosh = cp.arccosh +asin = cp.arcsin +asinh = cp.arcsinh +atan = cp.arctan +atan2 = cp.arctan2 +atanh = cp.arctanh +bitwise_left_shift = cp.left_shift +bitwise_invert = cp.invert +bitwise_right_shift = cp.right_shift +concat = cp.concatenate +pow = cp.power + +arange = get_xp(cp)(_aliases.arange) +empty = get_xp(cp)(_aliases.empty) +empty_like = get_xp(cp)(_aliases.empty_like) +eye = get_xp(cp)(_aliases.eye) +full = get_xp(cp)(_aliases.full) +full_like = get_xp(cp)(_aliases.full_like) +linspace = get_xp(cp)(_aliases.linspace) +ones = get_xp(cp)(_aliases.ones) +ones_like = get_xp(cp)(_aliases.ones_like) +zeros = get_xp(cp)(_aliases.zeros) +zeros_like = get_xp(cp)(_aliases.zeros_like) +UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult) +UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult) +UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult) +unique_all = get_xp(cp)(_aliases.unique_all) +unique_counts = get_xp(cp)(_aliases.unique_counts) +unique_inverse = get_xp(cp)(_aliases.unique_inverse) +unique_values = get_xp(cp)(_aliases.unique_values) +astype = _aliases.astype +std = get_xp(cp)(_aliases.std) +var = get_xp(cp)(_aliases.var) +permute_dims = get_xp(cp)(_aliases.permute_dims) +reshape = get_xp(cp)(_aliases.reshape) +argsort = get_xp(cp)(_aliases.argsort) +sort = get_xp(cp)(_aliases.sort) +sum = get_xp(cp)(_aliases.sum) +prod = get_xp(cp)(_aliases.prod) +ceil = get_xp(cp)(_aliases.ceil) +floor = get_xp(cp)(_aliases.floor) +trunc = get_xp(cp)(_aliases.trunc) +matmul = get_xp(cp)(_aliases.matmul) +matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) +tensordot = get_xp(cp)(_aliases.tensordot) +vecdot = get_xp(cp)(_aliases.vecdot) +isdtype = get_xp(cp)(_aliases.isdtype) + +__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow'] diff --git a/sklearn/externals/_array_api_compat/cupy/_typing.py b/sklearn/externals/_array_api_compat/cupy/_typing.py new file mode 100644 index 0000000000000..f3d9aab67e52f --- /dev/null +++ b/sklearn/externals/_array_api_compat/cupy/_typing.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +__all__ = [ + "ndarray", + "Device", + "Dtype", +] + +import sys +from typing import ( + Union, + TYPE_CHECKING, +) + +from cupy import ( + ndarray, + dtype, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +) + +from cupy.cuda.device import Device + +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: 
+ Dtype = dtype diff --git a/sklearn/externals/_array_api_compat/cupy/linalg.py b/sklearn/externals/_array_api_compat/cupy/linalg.py new file mode 100644 index 0000000000000..99c4cc68d783c --- /dev/null +++ b/sklearn/externals/_array_api_compat/cupy/linalg.py @@ -0,0 +1,41 @@ +from cupy.linalg import * +# cupy.linalg doesn't have __all__. If it is added, replace this with +# +# from cupy.linalg import __all__ as linalg_all +_n = {} +exec('from cupy.linalg import *', _n) +del _n['__builtins__'] +linalg_all = list(_n) +del _n + +from ..common import _linalg +from .._internal import get_xp +from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) + +import cupy as cp + +cross = get_xp(cp)(_linalg.cross) +outer = get_xp(cp)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(cp)(_linalg.eigh) +qr = get_xp(cp)(_linalg.qr) +slogdet = get_xp(cp)(_linalg.slogdet) +svd = get_xp(cp)(_linalg.svd) +cholesky = get_xp(cp)(_linalg.cholesky) +matrix_rank = get_xp(cp)(_linalg.matrix_rank) +pinv = get_xp(cp)(_linalg.pinv) +matrix_norm = get_xp(cp)(_linalg.matrix_norm) +svdvals = get_xp(cp)(_linalg.svdvals) +vector_norm = get_xp(cp)(_linalg.vector_norm) +diagonal = get_xp(cp)(_linalg.diagonal) +trace = get_xp(cp)(_linalg.trace) + +__all__ = linalg_all + _linalg.__all__ + +del get_xp +del cp +del linalg_all +del _linalg diff --git a/sklearn/externals/_array_api_compat/numpy/__init__.py b/sklearn/externals/_array_api_compat/numpy/__init__.py new file mode 100644 index 0000000000000..745367bc8705e --- /dev/null +++ b/sklearn/externals/_array_api_compat/numpy/__init__.py @@ -0,0 +1,22 @@ +from numpy import * + +# from numpy import * doesn't overwrite these builtin names +from numpy import abs, max, min, round + +# These imports may overwrite names from the import * above. +from ._aliases import * + +# Don't know why, but we have to do an absolute import to import linalg. If we +# instead do +# +# from . import linalg +# +# It doesn't overwrite np.linalg from above. The import is generated +# dynamically so that the library can be vendored. 
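# Illustrative expansion (assuming the vendored package path): with
# __package__ == 'sklearn.externals._array_api_compat.numpy', the call
# below behaves like
#
#     import sklearn.externals._array_api_compat.numpy.linalg
#
# which binds the submodule as this module's 'linalg' attribute,
# shadowing the numpy.linalg brought in by 'from numpy import *'.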
+__import__(__package__ + '.linalg') + +from .linalg import matrix_transpose, vecdot + +from ..common._helpers import * + +__array_api_version__ = '2021.12' diff --git a/sklearn/externals/_array_api_compat/numpy/_aliases.py b/sklearn/externals/_array_api_compat/numpy/_aliases.py new file mode 100644 index 0000000000000..08f4de0bafeeb --- /dev/null +++ b/sklearn/externals/_array_api_compat/numpy/_aliases.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from functools import partial + +from ..common import _aliases + +from .._internal import get_xp + +asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') +asarray.__doc__ = _aliases._asarray.__doc__ +del partial + +import numpy as np +bool = np.bool_ + +# Basic renames +acos = np.arccos +acosh = np.arccosh +asin = np.arcsin +asinh = np.arcsinh +atan = np.arctan +atan2 = np.arctan2 +atanh = np.arctanh +bitwise_left_shift = np.left_shift +bitwise_invert = np.invert +bitwise_right_shift = np.right_shift +concat = np.concatenate +pow = np.power + +arange = get_xp(np)(_aliases.arange) +empty = get_xp(np)(_aliases.empty) +empty_like = get_xp(np)(_aliases.empty_like) +eye = get_xp(np)(_aliases.eye) +full = get_xp(np)(_aliases.full) +full_like = get_xp(np)(_aliases.full_like) +linspace = get_xp(np)(_aliases.linspace) +ones = get_xp(np)(_aliases.ones) +ones_like = get_xp(np)(_aliases.ones_like) +zeros = get_xp(np)(_aliases.zeros) +zeros_like = get_xp(np)(_aliases.zeros_like) +UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult) +UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult) +UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult) +unique_all = get_xp(np)(_aliases.unique_all) +unique_counts = get_xp(np)(_aliases.unique_counts) +unique_inverse = get_xp(np)(_aliases.unique_inverse) +unique_values = get_xp(np)(_aliases.unique_values) +astype = _aliases.astype +std = get_xp(np)(_aliases.std) +var = get_xp(np)(_aliases.var) +permute_dims = get_xp(np)(_aliases.permute_dims) +reshape = get_xp(np)(_aliases.reshape) +argsort = get_xp(np)(_aliases.argsort) +sort = get_xp(np)(_aliases.sort) +sum = get_xp(np)(_aliases.sum) +prod = get_xp(np)(_aliases.prod) +ceil = get_xp(np)(_aliases.ceil) +floor = get_xp(np)(_aliases.floor) +trunc = get_xp(np)(_aliases.trunc) +matmul = get_xp(np)(_aliases.matmul) +matrix_transpose = get_xp(np)(_aliases.matrix_transpose) +tensordot = get_xp(np)(_aliases.tensordot) +vecdot = get_xp(np)(_aliases.vecdot) +isdtype = get_xp(np)(_aliases.isdtype) + +__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow'] diff --git a/sklearn/externals/_array_api_compat/numpy/_typing.py b/sklearn/externals/_array_api_compat/numpy/_typing.py new file mode 100644 index 0000000000000..c5ebb5abb9875 --- /dev/null +++ b/sklearn/externals/_array_api_compat/numpy/_typing.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +__all__ = [ + "ndarray", + "Device", + "Dtype", +] + +import sys +from typing import ( + Literal, + Union, + TYPE_CHECKING, +) + +from numpy import ( + ndarray, + dtype, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +) + +Device = Literal["cpu"] +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: + Dtype = dtype diff --git 
a/sklearn/externals/_array_api_compat/numpy/linalg.py b/sklearn/externals/_array_api_compat/numpy/linalg.py new file mode 100644 index 0000000000000..26d6e88e1af47 --- /dev/null +++ b/sklearn/externals/_array_api_compat/numpy/linalg.py @@ -0,0 +1,34 @@ +from numpy.linalg import * +from numpy.linalg import __all__ as linalg_all + +from ..common import _linalg +from .._internal import get_xp +from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) + +import numpy as np + +cross = get_xp(np)(_linalg.cross) +outer = get_xp(np)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(np)(_linalg.eigh) +qr = get_xp(np)(_linalg.qr) +slogdet = get_xp(np)(_linalg.slogdet) +svd = get_xp(np)(_linalg.svd) +cholesky = get_xp(np)(_linalg.cholesky) +matrix_rank = get_xp(np)(_linalg.matrix_rank) +pinv = get_xp(np)(_linalg.pinv) +matrix_norm = get_xp(np)(_linalg.matrix_norm) +svdvals = get_xp(np)(_linalg.svdvals) +vector_norm = get_xp(np)(_linalg.vector_norm) +diagonal = get_xp(np)(_linalg.diagonal) +trace = get_xp(np)(_linalg.trace) + +__all__ = linalg_all + _linalg.__all__ + +del get_xp +del np +del linalg_all +del _linalg diff --git a/sklearn/externals/_array_api_compat/torch/__init__.py b/sklearn/externals/_array_api_compat/torch/__init__.py new file mode 100644 index 0000000000000..18776f1a0f73b --- /dev/null +++ b/sklearn/externals/_array_api_compat/torch/__init__.py @@ -0,0 +1,22 @@ +from torch import * + +# Several names are not included in the above import * +import torch +for n in dir(torch): + if (n.startswith('_') + or n.endswith('_') + or 'cuda' in n + or 'cpu' in n + or 'backward' in n): + continue + exec(n + ' = torch.' + n) + +# These imports may overwrite names from the import * above. 
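# Illustrative effect of the dir(torch) loop above: plain public names
# such as torch.float32, torch.asarray or torch.linspace are re-exported,
# while names that start or end with '_' or contain 'cuda', 'cpu' or
# 'backward' are skipped so they do not clobber this namespace.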
+from ._aliases import * + +# See the comment in the numpy __init__.py +__import__(__package__ + '.linalg') + +from ..common._helpers import * + +__array_api_version__ = '2021.12' diff --git a/sklearn/externals/_array_api_compat/torch/_aliases.py b/sklearn/externals/_array_api_compat/torch/_aliases.py new file mode 100644 index 0000000000000..dbd4d8d9dccfa --- /dev/null +++ b/sklearn/externals/_array_api_compat/torch/_aliases.py @@ -0,0 +1,666 @@ +from __future__ import annotations + +from functools import wraps +from builtins import all as builtin_all, any as builtin_any + +from ..common._aliases import (UniqueAllResult, UniqueCountsResult, + UniqueInverseResult, + matrix_transpose as _aliases_matrix_transpose, + vecdot as _aliases_vecdot) +from .._internal import get_xp + +import torch + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import List, Optional, Sequence, Tuple, Union + from ..common._typing import Device + from torch import dtype as Dtype + + array = torch.Tensor + +_int_dtypes = { + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, +} + +_array_api_dtypes = { + torch.bool, + *_int_dtypes, + torch.float32, + torch.float64, +} + +_promotion_table = { + # bool + (torch.bool, torch.bool): torch.bool, + # ints + (torch.int8, torch.int8): torch.int8, + (torch.int8, torch.int16): torch.int16, + (torch.int8, torch.int32): torch.int32, + (torch.int8, torch.int64): torch.int64, + (torch.int16, torch.int8): torch.int16, + (torch.int16, torch.int16): torch.int16, + (torch.int16, torch.int32): torch.int32, + (torch.int16, torch.int64): torch.int64, + (torch.int32, torch.int8): torch.int32, + (torch.int32, torch.int16): torch.int32, + (torch.int32, torch.int32): torch.int32, + (torch.int32, torch.int64): torch.int64, + (torch.int64, torch.int8): torch.int64, + (torch.int64, torch.int16): torch.int64, + (torch.int64, torch.int32): torch.int64, + (torch.int64, torch.int64): torch.int64, + # uints + (torch.uint8, torch.uint8): torch.uint8, + # ints and uints (mixed sign) + (torch.int8, torch.uint8): torch.int16, + (torch.int16, torch.uint8): torch.int16, + (torch.int32, torch.uint8): torch.int32, + (torch.int64, torch.uint8): torch.int64, + (torch.uint8, torch.int8): torch.int16, + (torch.uint8, torch.int16): torch.int16, + (torch.uint8, torch.int32): torch.int32, + (torch.uint8, torch.int64): torch.int64, + # floats + (torch.float32, torch.float32): torch.float32, + (torch.float32, torch.float64): torch.float64, + (torch.float64, torch.float32): torch.float64, + (torch.float64, torch.float64): torch.float64, +} + + +def _two_arg(f): + @wraps(f) + def _f(x1, x2, /, **kwargs): + x1, x2 = _fix_promotion(x1, x2) + return f(x1, x2, **kwargs) + if _f.__doc__ is None: + _f.__doc__ = f"""\ +Array API compatibility wrapper for torch.{f.__name__}. + +See the corresponding PyTorch documentation and/or the array API specification +for more details. 
+ +""" + return _f + +def _fix_promotion(x1, x2, only_scalar=True): + if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes: + return x1, x2 + # If an argument is 0-D pytorch downcasts the other argument + if not only_scalar or x1.shape == (): + dtype = result_type(x1, x2) + x2 = x2.to(dtype) + if not only_scalar or x2.shape == (): + dtype = result_type(x1, x2) + x1 = x1.to(dtype) + return x1, x2 + +def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: + if len(arrays_and_dtypes) == 0: + raise TypeError("At least one array or dtype must be provided") + if len(arrays_and_dtypes) == 1: + x = arrays_and_dtypes[0] + if isinstance(x, torch.dtype): + return x + return x.dtype + if len(arrays_and_dtypes) > 2: + return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) + + x, y = arrays_and_dtypes + xdt = x.dtype if not isinstance(x, torch.dtype) else x + ydt = y.dtype if not isinstance(y, torch.dtype) else y + + if (xdt, ydt) in _promotion_table: + return _promotion_table[xdt, ydt] + + # This doesn't result_type(dtype, dtype) for non-array API dtypes + # because torch.result_type only accepts tensors. This does however, allow + # cross-kind promotion. + return torch.result_type(x, y) + +def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: + if not isinstance(from_, torch.dtype): + from_ = from_.dtype + return torch.can_cast(from_, to) + +# Basic renames +permute_dims = torch.permute +bitwise_invert = torch.bitwise_not + +# Two-arg elementwise functions +# These require a wrapper to do the correct type promotion on 0-D tensors +add = _two_arg(torch.add) +atan2 = _two_arg(torch.atan2) +bitwise_and = _two_arg(torch.bitwise_and) +bitwise_left_shift = _two_arg(torch.bitwise_left_shift) +bitwise_or = _two_arg(torch.bitwise_or) +bitwise_right_shift = _two_arg(torch.bitwise_right_shift) +bitwise_xor = _two_arg(torch.bitwise_xor) +divide = _two_arg(torch.divide) +# Also a rename. torch.equal does not broadcast +equal = _two_arg(torch.eq) +floor_divide = _two_arg(torch.floor_divide) +greater = _two_arg(torch.greater) +greater_equal = _two_arg(torch.greater_equal) +less = _two_arg(torch.less) +less_equal = _two_arg(torch.less_equal) +logaddexp = _two_arg(torch.logaddexp) +# logical functions are not included here because they only accept bool in the +# spec, so type promotion is irrelevant. +multiply = _two_arg(torch.multiply) +not_equal = _two_arg(torch.not_equal) +pow = _two_arg(torch.pow) +remainder = _two_arg(torch.remainder) +subtract = _two_arg(torch.subtract) + +# These wrappers are mostly based on the fact that pytorch uses 'dim' instead +# of 'axis'. 
+
+# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
+def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+    # https://github.com/pytorch/pytorch/issues/29137
+    if axis == ():
+        return torch.clone(x)
+    return torch.amax(x, axis, keepdims=keepdims)
+
+def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+    # https://github.com/pytorch/pytorch/issues/29137
+    if axis == ():
+        return torch.clone(x)
+    return torch.amin(x, axis, keepdims=keepdims)
+
+# torch.sort also returns a tuple
+# https://github.com/pytorch/pytorch/issues/70921
+def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
+    return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
+
+def _normalize_axes(axis, ndim):
+    axes = []
+    if ndim == 0 and axis:
+        # Better error message in this case
+        raise IndexError(f"Dimension out of range: {axis[0]}")
+    lower, upper = -ndim, ndim - 1
+    for a in axis:
+        if a < lower or a > upper:
+            # Match torch error message (e.g., from sum())
+            raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a})")
+        if a < 0:
+            a = a + ndim
+        if a in axes:
+            # Use IndexError instead of RuntimeError, and "axis" instead of "dim"
+            raise IndexError(f"Axis {a} appears multiple times in the list of axes")
+        axes.append(a)
+    return sorted(axes)
+
+def _axis_none_keepdims(x, ndim, keepdims):
+    # Apply keepdims when axis=None
+    # (https://github.com/pytorch/pytorch/issues/71209)
+    # Note that this is only valid for the axis=None case.
+    if keepdims:
+        for i in range(ndim):
+            x = torch.unsqueeze(x, 0)
+    return x
+
+def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
+    # Some reductions don't support multiple axes
+    # (https://github.com/pytorch/pytorch/issues/56586).
+    axes = _normalize_axes(axis, x.ndim)
+    for a in reversed(axes):
+        x = torch.movedim(x, a, -1)
+    x = torch.flatten(x, -len(axes))
+
+    out = f(x, -1, **kwargs)
+
+    if keepdims:
+        for a in axes:
+            out = torch.unsqueeze(out, a)
+    return out
+
+def prod(x: array,
+         /,
+         *,
+         axis: Optional[Union[int, Tuple[int, ...]]] = None,
+         dtype: Optional[Dtype] = None,
+         keepdims: bool = False,
+         **kwargs) -> array:
+    x = torch.asarray(x)
+    ndim = x.ndim
+
+    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
+    # below because it still needs to upcast.
+    if axis == ():
+        if dtype is None:
+            # We can't upcast uint8 according to the spec because there is no
+            # torch.uint64, so at least upcast to int64 which is what sum does
+            # when axis=None.
+            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
+                return x.to(torch.int64)
+            return x.clone()
+        return x.to(dtype)
+
+    # torch.prod doesn't support multiple axes
+    # (https://github.com/pytorch/pytorch/issues/56586).
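+    # _reduce_multiple_axes (defined above) works around this by moving the
+    # requested axes to the end, flattening them into a single dimension,
+    # and reducing over that last dimension in one torch call.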
+ if isinstance(axis, tuple): + return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.prod(x, dtype=dtype, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res + + return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) + + +def sum(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + + # https://github.com/pytorch/pytorch/issues/29137. + # Make sure it upcasts. + if axis == (): + if dtype is None: + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what sum does + # when axis=None. + if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: + return x.to(torch.int64) + return x.clone() + return x.to(dtype) + + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.sum(x, dtype=dtype, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res + + return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) + +def any(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + if axis == (): + return x.to(torch.bool) + # torch.any doesn't support multiple axes + # (https://github.com/pytorch/pytorch/issues/56586). + if isinstance(axis, tuple): + res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs) + return res.to(torch.bool) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.any(x, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res.to(torch.bool) + + # torch.any doesn't return bool for uint8 + return torch.any(x, axis, keepdims=keepdims).to(torch.bool) + +def all(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + if axis == (): + return x.to(torch.bool) + # torch.all doesn't support multiple axes + # (https://github.com/pytorch/pytorch/issues/56586). 
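+    # The same _reduce_multiple_axes flattening used by prod() and sum()
+    # applies here; the result is cast back to bool below because torch
+    # returns uint8 for uint8 inputs.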
+    if isinstance(axis, tuple):
+        res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
+        return res.to(torch.bool)
+    if axis is None:
+        # torch doesn't support keepdims with axis=None
+        # (https://github.com/pytorch/pytorch/issues/71209)
+        res = torch.all(x, **kwargs)
+        res = _axis_none_keepdims(res, ndim, keepdims)
+        return res.to(torch.bool)
+
+    # torch.all doesn't return bool for uint8
+    return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
+
+def mean(x: array,
+         /,
+         *,
+         axis: Optional[Union[int, Tuple[int, ...]]] = None,
+         keepdims: bool = False,
+         **kwargs) -> array:
+    # https://github.com/pytorch/pytorch/issues/29137
+    if axis == ():
+        return torch.clone(x)
+    if axis is None:
+        # torch doesn't support keepdims with axis=None
+        # (https://github.com/pytorch/pytorch/issues/71209)
+        res = torch.mean(x, **kwargs)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
+        return res
+    return torch.mean(x, axis, keepdims=keepdims, **kwargs)
+
+def std(x: array,
+        /,
+        *,
+        axis: Optional[Union[int, Tuple[int, ...]]] = None,
+        correction: Union[int, float] = 0.0,
+        keepdims: bool = False,
+        **kwargs) -> array:
+    # Note, float correction is not supported
+    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+    # implement it here for now.
+
+    # Keep _correction defined for integer corrections as well; only a float
+    # correction with a fractional part is rejected below.
+    _correction = correction
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError("float correction in torch std() is not yet supported")
+
+    # https://github.com/pytorch/pytorch/issues/29137
+    if axis == ():
+        return torch.zeros_like(x)
+    if isinstance(axis, int):
+        axis = (axis,)
+    if axis is None:
+        # torch doesn't support keepdims with axis=None
+        # (https://github.com/pytorch/pytorch/issues/71209)
+        res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
+        return res
+    return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
+
+def var(x: array,
+        /,
+        *,
+        axis: Optional[Union[int, Tuple[int, ...]]] = None,
+        correction: Union[int, float] = 0.0,
+        keepdims: bool = False,
+        **kwargs) -> array:
+    # Note, float correction is not supported
+    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+    # implement it here for now.
+
+    # if isinstance(correction, float):
+    #     correction = int(correction)
+
+    # https://github.com/pytorch/pytorch/issues/29137
+    if axis == ():
+        return torch.zeros_like(x)
+    if isinstance(axis, int):
+        axis = (axis,)
+    if axis is None:
+        # torch doesn't support keepdims with axis=None
+        # (https://github.com/pytorch/pytorch/issues/71209)
+        res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
+        return res
+    return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
+
+# torch.concat doesn't support dim=None
+# https://github.com/pytorch/pytorch/issues/70925
+def concat(arrays: Union[Tuple[array, ...], List[array]],
+           /,
+           *,
+           axis: Optional[int] = 0,
+           **kwargs) -> array:
+    if axis is None:
+        arrays = tuple(ar.flatten() for ar in arrays)
+        axis = 0
+    return torch.concat(arrays, axis, **kwargs)
+
+# torch.squeeze only accepts int dim and doesn't require it
+# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
+# added at https://github.com/pytorch/pytorch/pull/89017.
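+#
+# For example (illustrative): squeeze(torch.ones(2, 1, 3), axis=1) returns a
+# tensor of shape (2, 3), while squeeze(torch.ones(2, 3), axis=1) raises
+# ValueError because dimension 1 has size 3, not 1.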
+def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: + if isinstance(axis, int): + axis = (axis,) + for a in axis: + if x.shape[a] != 1: + raise ValueError("squeezed dimensions must be equal to 1") + axes = _normalize_axes(axis, x.ndim) + # Remove this once pytorch 1.14 is released with the above PR #89017. + sequence = [a - i for i, a in enumerate(axes)] + for a in sequence: + x = torch.squeeze(x, a) + return x + +# The axis parameter doesn't work for flip() and roll() +# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't +# accept axis=None +def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: + if axis is None: + axis = tuple(range(x.ndim)) + # torch.flip doesn't accept dim as an int but the method does + # https://github.com/pytorch/pytorch/issues/18095 + return x.flip(axis) + +def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: + return torch.roll(x, shift, axis) + +def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: + return torch.nonzero(x, as_tuple=True, **kwargs) + +def where(condition: array, x1: array, x2: array, /) -> array: + x1, x2 = _fix_promotion(x1, x2) + return torch.where(condition, x1, x2) + +# torch.arange doesn't support returning empty arrays +# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some +# keyword argument combinations +# (https://github.com/pytorch/pytorch/issues/70914) +def arange(start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + if stop is None: + start, stop = 0, start + if step > 0 and stop <= start or step < 0 and stop >= start: + if dtype is None: + if builtin_all(isinstance(i, int) for i in [start, stop, step]): + dtype = torch.int64 + else: + dtype = torch.float32 + return torch.empty(0, dtype=dtype, device=device, **kwargs) + return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) + +# torch.eye does not accept None as a default for the second argument and +# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) +def eye(n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + if n_cols is None: + n_cols = n_rows + z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) + if abs(k) <= n_rows + n_cols: + z.diagonal(k).fill_(1) + return z + +# torch.linspace doesn't have the endpoint parameter +def linspace(start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, + **kwargs) -> array: + if not endpoint: + return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] + return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) + +# torch.full does not accept an int size +# https://github.com/pytorch/pytorch/issues/70906 +def full(shape: Union[int, Tuple[int, ...]], + fill_value: Union[bool, int, float, complex], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + if isinstance(shape, int): + shape = (shape,) + + return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) + +# ones, zeros, and empty do not accept shape as a keyword argument +def ones(shape: 
Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + return torch.ones(shape, dtype=dtype, device=device, **kwargs) + +def zeros(shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + return torch.zeros(shape, dtype=dtype, device=device, **kwargs) + +def empty(shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + return torch.empty(shape, dtype=dtype, device=device, **kwargs) + +# tril and triu do not call the keyword argument k + +def tril(x: array, /, *, k: int = 0) -> array: + return torch.tril(x, k) + +def triu(x: array, /, *, k: int = 0) -> array: + return torch.triu(x, k) + +# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 +def expand_dims(x: array, /, *, axis: int = 0) -> array: + return torch.unsqueeze(x, axis) + +def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: + return x.to(dtype, copy=copy) + +def broadcast_arrays(*arrays: array) -> List[array]: + shape = torch.broadcast_shapes(*[a.shape for a in arrays]) + return [torch.broadcast_to(a, shape) for a in arrays] + +# https://github.com/pytorch/pytorch/issues/70920 +def unique_all(x: array) -> UniqueAllResult: + # torch.unique doesn't support returning indices. + # https://github.com/pytorch/pytorch/issues/36748. The workaround + # suggested in that issue doesn't actually function correctly (it relies + # on non-deterministic behavior of scatter()). + raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)") + + # values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True) + # # torch.unique incorrectly gives a 0 count for nan values. + # # https://github.com/pytorch/pytorch/issues/94106 + # counts[torch.isnan(values)] = 1 + # return UniqueAllResult(values, indices, inverse_indices, counts) + +def unique_counts(x: array) -> UniqueCountsResult: + values, counts = torch.unique(x, return_counts=True) + + # torch.unique incorrectly gives a 0 count for nan values. + # https://github.com/pytorch/pytorch/issues/94106 + counts[torch.isnan(values)] = 1 + return UniqueCountsResult(values, counts) + +def unique_inverse(x: array) -> UniqueInverseResult: + values, inverse = torch.unique(x, return_inverse=True) + return UniqueInverseResult(values, inverse) + +def unique_values(x: array) -> array: + return torch.unique(x) + +def matmul(x1: array, x2: array, /, **kwargs) -> array: + # torch.matmul doesn't type promote (but differently from _fix_promotion) + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.matmul(x1, x2, **kwargs) + +matrix_transpose = get_xp(torch)(_aliases_matrix_transpose) +_vecdot = get_xp(torch)(_aliases_vecdot) + +def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return _vecdot(x1, x2, axis=axis) + +# torch.tensordot uses dims instead of axes +def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: + # Note: torch.tensordot fails with integer dtypes when there is only 1 + # element in the axis (https://github.com/pytorch/pytorch/issues/84530). 
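+    # e.g. axes=2 contracts the last two axes of x1 with the first two axes
+    # of x2, and axes=([1], [0]) contracts x1's axis 1 with x2's axis 0;
+    # both forms map directly onto torch.tensordot's `dims` argument.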
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.tensordot(x1, x2, dims=axes, **kwargs) + + +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + *, _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + Note that outside of this function, this compat library does not yet fully + support complex numbers. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + if isinstance(kind, tuple) and _tuple: + return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype == torch.bool + elif kind == 'signed integer': + return dtype in _int_dtypes and dtype.is_signed + elif kind == 'unsigned integer': + return dtype in _int_dtypes and not dtype.is_signed + elif kind == 'integral': + return dtype in _int_dtypes + elif kind == 'real floating': + return dtype.is_floating_point + elif kind == 'complex floating': + return dtype.is_complex + elif kind == 'numeric': + return isdtype(dtype, ('integral', 'real floating', 'complex floating')) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + return dtype == kind + +__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', + 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', + 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', + 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', + 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', + 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', + 'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll', + 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones', + 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', + 'broadcast_arrays', 'unique_all', 'unique_counts', + 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', + 'vecdot', 'tensordot', 'isdtype'] diff --git a/sklearn/externals/_array_api_compat/torch/linalg.py b/sklearn/externals/_array_api_compat/torch/linalg.py new file mode 100644 index 0000000000000..c803228abc604 --- /dev/null +++ b/sklearn/externals/_array_api_compat/torch/linalg.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import torch + array = torch.Tensor + +from torch.linalg import * + +# torch.linalg doesn't define __all__ +# from torch.linalg import __all__ as linalg_all +from torch import linalg as torch_linalg +linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] + +# These are implemented in torch but aren't in the linalg namespace +from torch import outer, trace +from ._aliases import _fix_promotion, matrix_transpose, tensordot + +# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the +# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch_linalg.cross(x1, x2, dim=axis) + +__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot'] + +del linalg_all diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index 987ae57c12250..cdf7fa43220a1 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -399,7 +399,7 @@ def decision_function(self, X): X = self._validate_data(X, 
accept_sparse="csr", reset=False)
 
         scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
-        return xp.reshape(scores, -1) if scores.shape[1] == 1 else scores
+        return xp.reshape(scores, (-1,)) if scores.shape[1] == 1 else scores
 
     def predict(self, X):
         """
@@ -422,7 +422,7 @@ def predict(self, X):
         else:
             indices = xp.argmax(scores, axis=1)
 
-        return xp.take(self.classes_, indices, axis=0)
+        return xp.take(self.classes_, indices)
 
     def _predict_proba_lr(self, X):
         """Probability estimation for OvR logistic regression.
diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py
index b005d821b4a94..e431d09282b49 100644
--- a/sklearn/tests/test_discriminant_analysis.py
+++ b/sklearn/tests/test_discriminant_analysis.py
@@ -731,3 +731,60 @@ def test_lda_array_api(array_namespace):
             err_msg=f"{method} did not the return the same result",
             atol=1e-6,
         )
+
+
+@pytest.mark.parametrize("device", ["cuda", "cpu"])
+@pytest.mark.parametrize("dtype", ["float32", "float64"])
+def test_lda_array_torch(device, dtype):
+    """Check running on PyTorch Tensors gives the same results as NumPy"""
+    torch = pytest.importorskip("torch")
+    if device == "cuda" and not torch.has_cuda:
+        pytest.skip("test requires cuda")
+
+    lda = LinearDiscriminantAnalysis()
+    X_np = X6.astype(dtype)
+    y_np = y6.astype(dtype)
+    lda.fit(X_np, y_np)
+
+    X_torch = torch.asarray(X_np, device=device)
+    y_torch = torch.asarray(y_np, device=device)
+    lda_xp = clone(lda)
+    with config_context(array_api_dispatch=True):
+        lda_xp.fit(X_torch, y_torch)
+
+    array_attributes = {
+        key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
+    }
+
+    for key, attribute in array_attributes.items():
+        lda_xp_param = getattr(lda_xp, key)
+        assert isinstance(lda_xp_param, torch.Tensor)
+
+        lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=torch)
+        assert_allclose(
+            attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3
+        )
+
+    # Check predictions are the same
+    methods = (
+        "decision_function",
+        "predict",
+        "predict_log_proba",
+        "predict_proba",
+        "transform",
+    )
+    for method in methods:
+        result = getattr(lda, method)(X_np)
+        with config_context(array_api_dispatch=True):
+            result_xp = getattr(lda_xp, method)(X_torch)
+
+        assert isinstance(result_xp, torch.Tensor)
+
+        result_xp_np = _convert_to_numpy(result_xp, xp=torch)
+
+        assert_allclose(
+            result,
+            result_xp_np,
+            err_msg=f"{method} did not return the same result",
+            atol=1e-6,
+        )
diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py
index fff8e1ee33a49..ba27af35324f9 100644
--- a/sklearn/utils/_array_api.py
+++ b/sklearn/utils/_array_api.py
@@ -1,8 +1,22 @@
 """Tools to support array_api."""
 import numpy
-from .._config import get_config
 import scipy.special as special
+import sklearn.externals._array_api_compat as array_api_compat
+
+import sklearn.externals._array_api_compat.numpy as array_api_compat_numpy
+from sklearn.externals._array_api_compat import device, size  # noqa
+
+from .._config import get_config
+
+
+def _is_numpy_namespace(xp):
+    return xp.__name__ in {
+        "numpy",
+        "sklearn.externals._array_api_compat.numpy",
+        "numpy.array_api",
+    }
+
 
 
 class _ArrayAPIWrapper:
     """sklearn specific Array API compatibility wrapper
@@ -24,7 +38,7 @@ def __init__(self, array_namespace):
     def __getattr__(self, name):
         return getattr(self._namespace, name)
 
-    def take(self, X, indices, *, axis):
+    def take(self, X, indices, *, axis=0):
         # When array_api supports `take` we can use this
directly
         # https://github.com/data-apis/array-api/issues/177
         if self._namespace.__name__ == "numpy.array_api":
@@ -48,43 +62,43 @@ def take(self, X, indices, *, axis):
             selected = [X[:, i] for i in indices]
         return self._namespace.stack(selected, axis=axis)
 
+    def isdtype(self, dtype, kind):
+        """Returns a boolean indicating whether a provided dtype is of type "kind".
 
-class _NumPyApiWrapper:
-    """Array API compat wrapper for any numpy version
-
-    NumPy < 1.22 does not expose the numpy.array_api namespace. This
-    wrapper makes it possible to write code that uses the standard
-    Array API while working with any version of NumPy supported by
-    scikit-learn.
-
-    See the `get_namespace()` public function for more details.
-    """
-
-    def __getattr__(self, name):
-        return getattr(numpy, name)
-
-    def astype(self, x, dtype, *, copy=True, casting="unsafe"):
-        # astype is not defined in the top level NumPy namespace
-        return x.astype(dtype, copy=copy, casting=casting)
-
-    def asarray(self, x, *, dtype=None, device=None, copy=None):
-        # Support copy in NumPy namespace
-        if copy is True:
-            return numpy.array(x, copy=True, dtype=dtype)
+        Included in v2022.12 of the Array API spec.
+        https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
+        """
+        if isinstance(kind, tuple):
+            return any(self._isdtype_single(dtype, k) for k in kind)
         else:
-            return numpy.asarray(x, dtype=dtype)
-
-    def unique_inverse(self, x):
-        return numpy.unique(x, return_inverse=True)
-
-    def unique_counts(self, x):
-        return numpy.unique(x, return_counts=True)
-
-    def unique_values(self, x):
-        return numpy.unique(x)
-
-    def concat(self, arrays, *, axis=None):
-        return numpy.concatenate(arrays, axis=axis)
+            return self._isdtype_single(dtype, kind)
+
+    def _isdtype_single(self, dtype, kind):
+        xp = self._namespace
+        if isinstance(kind, str):
+            if kind == "bool":
+                return dtype == xp.bool
+            elif kind == "signed integer":
+                return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
+            elif kind == "unsigned integer":
+                return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
+            elif kind == "integral":
+                return self.isdtype(dtype, ("signed integer", "unsigned integer"))
+            elif kind == "real floating":
+                return dtype in {xp.float32, xp.float64}
+            elif kind == "complex floating":
+                # cupy.array_api and numpy.array_api do not have complex dtypes
+                if xp.__name__ in {"cupy.array_api", "numpy.array_api"}:
+                    return False
+                return dtype in {xp.complex64, xp.complex128}
+            elif kind == "numeric":
+                return self.isdtype(
+                    dtype, ("integral", "real floating", "complex floating")
+                )
+            else:
+                raise ValueError(f"Unrecognized data type kind: {kind!r}")
+        else:
+            return dtype == kind
 
 
 def get_namespace(*arrays):
@@ -123,43 +137,36 @@ def get_namespace(*arrays):
     Returns
     -------
     namespace : module
-        Namespace shared by array objects.
+        Namespace shared by array objects. If any of the `arrays` is not an
+        array, the namespace defaults to NumPy.
 
     is_array_api : bool
         True of the arrays are containers that implement the Array API spec.
     """
-    # `arrays` contains one or more arrays, or possibly Python scalars (accepting
-    # those is a matter of taste, but doesn't seem unreasonable).
- # Returns a tuple: (array_namespace, is_array_api) - - if not get_config()["array_api_dispatch"]: - return _NumPyApiWrapper(), False + return _get_namespace( + *arrays, array_api_dispatch=get_config()["array_api_dispatch"] + ) - namespaces = { - x.__array_namespace__() if hasattr(x, "__array_namespace__") else None - for x in arrays - if not isinstance(x, (bool, int, float, complex)) - } - if not namespaces: - # one could special-case np.ndarray above or use np.asarray here if - # older numpy versions need to be supported. - raise ValueError("Unrecognized array input") +def _get_namespace(*arrays, array_api_dispatch=False): + if not array_api_dispatch: + return array_api_compat_numpy, False + try: + namespace, is_array = array_api_compat.get_namespace(*arrays), True + except TypeError as e: + if str(e).startswith("The input is not a supported array type"): + return array_api_compat_numpy, False + raise - if len(namespaces) != 1: - raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") + if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: + namespace = _ArrayAPIWrapper(namespace) - (xp,) = namespaces - if xp is None: - # Use numpy as default - return _NumPyApiWrapper(), False - - return _ArrayAPIWrapper(xp), True + return namespace, is_array def _expit(X): xp, _ = get_namespace(X) - if xp.__name__ in {"numpy", "numpy.array_api"}: + if _is_numpy_namespace(xp): return xp.asarray(special.expit(numpy.asarray(X))) return 1.0 / (1.0 + xp.exp(-X)) @@ -180,28 +187,31 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None): """ if xp is None: xp, _ = get_namespace(array) - if xp.__name__ in {"numpy", "numpy.array_api"}: + if _is_numpy_namespace(xp): # Use NumPy API to support order - array = numpy.asarray(array, order=order, dtype=dtype) - return xp.asarray(array, copy=copy) + if copy is True: + array = numpy.array(array, order=order, dtype=dtype) + else: + array = numpy.asarray(array, order=order, dtype=dtype) + return xp.asarray(array) else: return xp.asarray(array, dtype=dtype, copy=copy) def _convert_to_numpy(array, xp): - """Convert X into a NumPy ndarray. + """Convert X into a NumPy ndarray on the CPU.""" + xp_name = xp.__name__ - Only works on cupy.array_api and numpy.array_api and is used for testing. 
- """ - supported_array_api = ["numpy.array_api", "cupy.array_api"] - if xp.__name__ not in supported_array_api: - support_array_api_str = ", ".join(supported_array_api) - raise ValueError(f"Supported namespaces are: {support_array_api_str}") - - if xp.__name__ == "cupy.array_api": + if _is_numpy_namespace(xp): + return numpy.asarray(array) + elif xp_name in {"sklearn.externals._array_api_compat.torch", "torch"}: + return array.cpu().numpy() + elif xp_name == "cupy.array_api": return array._array.get() + elif xp_name in {"sklearn.externals._array_api_compat.cupy", "cupy"}: + return array.get() else: - return numpy.asarray(array) + raise ValueError(f"{xp_name} is an unsupported namespace") def _estimator_with_converted_arrays(estimator, converter): @@ -224,9 +234,8 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): - if hasattr(attribute, "__array_namespace__") or isinstance( - attribute, numpy.ndarray - ): + _, is_array = _get_namespace(attribute, array_api_dispatch=True) + if is_array: attribute = converter(attribute) setattr(new_estimator, key, attribute) return new_estimator diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 49908fdf1083d..cf933cef783c4 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -20,7 +20,7 @@ from ._logistic_sigmoid import _log_logistic_sigmoid from .sparsefuncs_fast import csr_row_norms from .validation import check_array -from ._array_api import get_namespace +from ._array_api import get_namespace, _is_numpy_namespace def squared_norm(x): @@ -886,7 +886,7 @@ def softmax(X, copy=True): max_prob = xp.reshape(xp.max(X, axis=1), (-1, 1)) X -= max_prob - if xp.__name__ in {"numpy", "numpy.array_api"}: + if _is_numpy_namespace(xp): # optimization for NumPy arrays np.exp(X, out=np.asarray(X)) else: diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index f14b981f9b83a..72bf499a6e31d 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -374,10 +374,10 @@ def type_of_target(y, input_name=""): suffix = "" # [1, 2, 3] or [[1], [2], [3]] # Check float and contains non-integer float values - if y.dtype.kind == "f": + if xp.isdtype(y.dtype, "real floating"): # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] 
data = y.data if issparse(y) else y - if xp.any(data != data.astype(int)): + if xp.any(data != xp.astype(data, int)): _assert_all_finite(data, input_name=input_name) return "continuous" + suffix diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 7318382ae9d66..cb8cdeda32376 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -1,14 +1,16 @@ import numpy -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal import pytest from sklearn.base import BaseEstimator from sklearn.utils._array_api import get_namespace -from sklearn.utils._array_api import _NumPyApiWrapper from sklearn.utils._array_api import _ArrayAPIWrapper + from sklearn.utils._array_api import _asarray_with_order from sklearn.utils._array_api import _convert_to_numpy from sklearn.utils._array_api import _estimator_with_converted_arrays + +import sklearn.externals._array_api_compat.numpy as array_api_compat_numpy from sklearn._config import config_context pytestmark = pytest.mark.filterwarnings( @@ -22,12 +24,11 @@ def test_get_namespace_ndarray(): X_np = numpy.asarray([[1, 2, 3]]) - # Dispatching on Numpy regardless or the value of array_api_dispatch. for array_api_dispatch in [True, False]: with config_context(array_api_dispatch=array_api_dispatch): xp_out, is_array_api = get_namespace(X_np) - assert not is_array_api - assert isinstance(xp_out, _NumPyApiWrapper) + assert is_array_api == array_api_dispatch + assert xp_out is array_api_compat_numpy def test_get_namespace_array_api(): @@ -42,11 +43,12 @@ def test_get_namespace_array_api(): assert isinstance(xp_out, _ArrayAPIWrapper) # check errors - with pytest.raises(ValueError, match="Multiple namespaces"): + with pytest.raises(TypeError, match="Multiple namespaces"): get_namespace(X_np, X_xp) - with pytest.raises(ValueError, match="Unrecognized array input"): - get_namespace(1) + xp_out, is_array_api = get_namespace(1) + assert xp_out == array_api_compat_numpy + assert not is_array_api class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper): @@ -150,10 +152,37 @@ def test_convert_to_numpy_error(): X = xp_.asarray([1.2, 3.4]) - with pytest.raises(ValueError, match="Supported namespaces are:"): + with pytest.raises(ValueError, match="wrapped.array_api is an unsupported"): _convert_to_numpy(X, xp=xp_) +@pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"]) +def test_convert_to_numpy_gpu(library): + """Check convert_to_numpy for GPU backed libraries.""" + xp = pytest.importorskip(library) + + if library == "torch": + if not xp.has_cuda: + pytest.skip("test requires cuda") + X_gpu = xp.asarray([1.0, 2.0, 3.0], device="cuda") + else: + X_gpu = xp.asarray([1.0, 2.0, 3.0]) + + X_cpu = _convert_to_numpy(X_gpu, xp=xp) + expected_output = numpy.asarray([1.0, 2.0, 3.0]) + assert_allclose(X_cpu, expected_output) + + +def test_convert_to_numpy_cpu(): + """Check convert_to_numpy for PyTorch CPU arrays.""" + torch = pytest.importorskip("torch") + X_torch = torch.asarray([1.0, 2.0, 3.0], device="cpu") + + X_cpu = _convert_to_numpy(X_torch, xp=torch) + expected_output = numpy.asarray([1.0, 2.0, 3.0]) + assert_allclose(X_cpu, expected_output) + + class SimpleEstimator(BaseEstimator): def fit(self, X, y=None): self.X_ = X @@ -161,16 +190,18 @@ def fit(self, X, y=None): return self -@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"]) -def test_convert_estimator_to_ndarray(array_namespace): 
+@pytest.mark.parametrize( + "array_namespace, converter", + [ + ("torch", lambda array: array.cpu().numpy()), + ("numpy.array_api", lambda array: numpy.asarray(array)), + ("cupy.array_api", lambda array: array._array.get()), + ], +) +def test_convert_estimator_to_ndarray(array_namespace, converter): """Convert estimator attributes to ndarray.""" xp = pytest.importorskip(array_namespace) - if array_namespace == "numpy.array_api": - converter = lambda array: numpy.asarray(array) # noqa - else: # pragma: no cover - converter = lambda array: array._array.get() # noqa - X = xp.asarray([[1.3, 4.5]]) est = SimpleEstimator().fit(X) @@ -187,3 +218,40 @@ def test_convert_estimator_to_array_api(): new_est = _estimator_with_converted_arrays(est, lambda array: xp.asarray(array)) assert hasattr(new_est.X_, "__array_namespace__") + + +@pytest.mark.parametrize("array_api_dispatch", [True, False]) +def test_get_namespace_array_api_isdtype(array_api_dispatch): + """Test isdtype implementation from _ArrayAPIWrapper and array_api_compat.""" + xp = pytest.importorskip("numpy.array_api") + + X_xp = xp.asarray([[1, 2, 3]]) + with config_context(array_api_dispatch=array_api_dispatch): + xp_out, _ = get_namespace(X_xp) + assert xp_out.isdtype(xp_out.float32, "real floating") + assert xp_out.isdtype(xp_out.float64, "real floating") + assert not xp_out.isdtype(xp_out.int32, "real floating") + + assert xp_out.isdtype(xp_out.bool, "bool") + assert not xp_out.isdtype(xp_out.float32, "bool") + + assert xp_out.isdtype(xp_out.int16, "signed integer") + assert not xp_out.isdtype(xp_out.uint32, "signed integer") + + assert xp_out.isdtype(xp_out.uint16, "unsigned integer") + assert not xp_out.isdtype(xp_out.int64, "unsigned integer") + + assert xp_out.isdtype(xp_out.int64, "numeric") + assert xp_out.isdtype(xp_out.float32, "numeric") + assert xp_out.isdtype(xp_out.uint32, "numeric") + + +@pytest.mark.parametrize("array_api_dispatch", [True, False]) +def test_get_namespace_list(array_api_dispatch): + """Test get_namespace for lists.""" + + X = [1, 2, 3] + with config_context(array_api_dispatch=array_api_dispatch): + xp_out, is_array = get_namespace(X) + assert not is_array + assert xp_out is array_api_compat_numpy diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index e8978a086d001..887c46e972310 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -31,6 +31,7 @@ from ..exceptions import DataConversionWarning from ..utils._array_api import get_namespace from ..utils._array_api import _asarray_with_order +from ..utils._array_api import _is_numpy_namespace from ._isfinite import cy_isfinite, FiniteStatus FLOAT_DTYPES = (np.float64, np.float32, np.float16) @@ -111,7 +112,7 @@ def _assert_all_finite( raise ValueError("Input contains NaN") # We need only consider float arrays, hence can early return for all else. - if X.dtype.kind not in "fc": + if not xp.isdtype(X.dtype, ("real floating", "complex floating")): return # First try an O(n) time, O(1) space solution for the common case that @@ -759,7 +760,7 @@ def check_array( dtype_numeric = isinstance(dtype, str) and dtype == "numeric" dtype_orig = getattr(array, "dtype", None) - if not hasattr(dtype_orig, "kind"): + if not is_array_api and not hasattr(dtype_orig, "kind"): # not a data type (e.g. 
a column named dtype in a pandas DataFrame) dtype_orig = None @@ -801,7 +802,11 @@ def check_array( dtype_orig = None if dtype_numeric: - if dtype_orig is not None and dtype_orig.kind == "O": + if ( + dtype_orig is not None + and hasattr(dtype_orig, "kind") + and dtype_orig.kind == "O" + ): # if input is object, convert to float. dtype = xp.float64 else: @@ -875,12 +880,12 @@ def check_array( with warnings.catch_warnings(): try: warnings.simplefilter("error", ComplexWarning) - if dtype is not None and np.dtype(dtype).kind in "iu": + if dtype is not None and xp.isdtype(dtype, "integral"): # Conversion float -> int should not contain NaN or # inf (numpy#14412). We cannot use casting='safe' because # then conversion float -> int would be disallowed. array = _asarray_with_order(array, order=order, xp=xp) - if array.dtype.kind == "f": + if xp.isdtype(array.dtype, ("real floating", "complex floating")): _assert_all_finite( array, allow_nan=False, @@ -920,7 +925,7 @@ def check_array( "if it contains a single sample.".format(array) ) - if dtype_numeric and array.dtype.kind in "USV": + if dtype_numeric and hasattr(array.dtype, "kind") and array.dtype.kind in "USV": raise ValueError( "dtype='numeric' is not compatible with arrays of bytes/strings." "Convert your data to numeric values explicitly instead." @@ -958,7 +963,7 @@ def check_array( ) if copy: - if xp.__name__ in {"numpy", "numpy.array_api"}: + if _is_numpy_namespace(xp): # only make a copy if `array` and `array_orig` may share memory` if np.may_share_memory(array, array_orig): array = _asarray_with_order( @@ -1201,7 +1206,7 @@ def column_or_1d(y, *, dtype=None, warn=False): shape = y.shape if len(shape) == 1: - return _asarray_with_order(xp.reshape(y, -1), order="C", xp=xp) + return _asarray_with_order(xp.reshape(y, (-1,)), order="C", xp=xp) if len(shape) == 2 and shape[1] == 1: if warn: warnings.warn( @@ -1211,7 +1216,7 @@ def column_or_1d(y, *, dtype=None, warn=False): DataConversionWarning, stacklevel=2, ) - return _asarray_with_order(xp.reshape(y, -1), order="C", xp=xp) + return _asarray_with_order(xp.reshape(y, (-1,)), order="C", xp=xp) raise ValueError( "y should be a 1d array, got an array of shape {} instead.".format(shape) From e8937e98930e87853ec2085574aaffc961f93e93 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 23 Mar 2023 15:12:53 -0400 Subject: [PATCH 02/42] DOC Adds PR number --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index ce40c38aff299..b9bbc213edbe6 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -204,7 +204,7 @@ Changelog - |Enhancement| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now supports the `PyTorch `__. See - :ref:`array_api` for more details. :pr:`xxxxx` by `Thomas Fan`_. + :ref:`array_api` for more details. :pr:`25956` by `Thomas Fan`_. :mod:`sklearn.ensemble` ....................... From 18d6aa91f1af96a6c9f079203def8e88f4a2582b Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 23 Mar 2023 15:37:53 -0400 Subject: [PATCH 03/42] CI Do not follow imports for vendored files --- pyproject.toml | 11 +++++++++++ setup.cfg | 6 +----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fbb1d53ef6602..1e3683752ba8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,14 @@ exclude = ''' | asv_benchmarks/env )/ ''' + +[tool.mypy] +ignore_missing_imports = true +allow_redefinition = true +exclude = [ + "sklearn/externals" +] + +[[tool.mypy.overrides]] +module = "sklearn.externals.*" +follow_imports = "skip" diff --git a/setup.cfg b/setup.cfg index dc18059dca8a9..637e0ba0675f2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,11 +73,7 @@ per-file-ignores = examples/*: E402 doc/conf.py: E402 -[mypy] -ignore_missing_imports = True -allow_redefinition = True -exclude = - sklearn/externals + [check-manifest] # ignore files missing in VCS From 8678389e551242b5dafec8cab808cc79bf01b7bb Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 23 Mar 2023 15:45:00 -0400 Subject: [PATCH 04/42] CI Do not check imports in externals._array_api_compat --- sklearn/tests/test_common.py | 2 +- sklearn/tests/test_docstring_parameters.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 6ef0eaa433d20..96cf0835979a1 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -202,7 +202,7 @@ def test_import_all_consistency(): ) submods = [modname for _, modname, _ in pkgs] for modname in submods + ["sklearn"]: - if ".tests." in modname: + if ".tests." in modname or ".externals._array_api_compat." in modname: continue if IS_PYPY and ( "_svmlight_format_io" in modname diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index 8bf3e5dd7b24a..aa043eeff41c7 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -162,6 +162,9 @@ def test_tabs(): ): continue + if ".externals._array_api_compat." in modname: + continue + # because we don't import mod = importlib.import_module(modname) From 873f46e81f0ae03cb2846461be2498ab637ec9e5 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 23 Mar 2023 16:21:13 -0400 Subject: [PATCH 05/42] CI Fix Failures --- doc/whats_new/v1.3.rst | 2 +- sklearn/tests/test_discriminant_analysis.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index b9bbc213edbe6..37df2731e5cb3 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -199,7 +199,7 @@ Changelog :class:`decomposition.MiniBatchNMF` which can produce different results than previous versions. :pr:`25438` by :user:`Yotam Avidar-Constantini `. - :mod:`sklearn.discriminant_analysis` +:mod:`sklearn.discriminant_analysis` .................................... - |Enhancement| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index e431d09282b49..90337cc826132 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -729,7 +729,7 @@ def test_lda_array_api(array_namespace): result, result_xp_np, err_msg=f"{method} did not the return the same result", - atol=1e-6, + atol=1e-5, ) From ef20a479b24d9bc75c21021294093bf6b54545b9 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 29 Mar 2023 10:23:54 -0400 Subject: [PATCH 06/42] CLN Address comments --- sklearn/tests/test_discriminant_analysis.py | 10 +++--- sklearn/utils/_array_api.py | 38 +++++++++------------ sklearn/utils/tests/test_array_api.py | 15 ++------ 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index 90337cc826132..afebb66a05685 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -701,7 +701,7 @@ def test_lda_array_api(array_namespace): lda_xp_param = getattr(lda_xp, key) assert hasattr(lda_xp_param, "__array_namespace__") - lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=xp) + lda_xp_param_np = _convert_to_numpy(lda_xp_param) assert_allclose( attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3 ) @@ -723,7 +723,7 @@ def test_lda_array_api(array_namespace): result_xp, "__array_namespace__" ), f"{method} did not output an array_namespace" - result_xp_np = _convert_to_numpy(result_xp, xp=xp) + result_xp_np = _convert_to_numpy(result_xp) assert_allclose( result, @@ -759,8 +759,9 @@ def test_lda_array_torch(device, dtype): for key, attribute in array_attributes.items(): lda_xp_param = getattr(lda_xp, key) assert isinstance(lda_xp_param, torch.Tensor) + assert lda_xp_param.device.type == device - lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=torch) + lda_xp_param_np = _convert_to_numpy(lda_xp_param) assert_allclose( attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3 ) @@ -779,8 +780,9 @@ def test_lda_array_torch(device, dtype): result_xp = getattr(lda_xp, method)(X_torch) assert isinstance(result_xp, torch.Tensor) + assert result_xp.device.type == device - result_xp_np = _convert_to_numpy(result_xp, xp=torch) + result_xp_np = _convert_to_numpy(result_xp) assert_allclose( result, diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index ba27af35324f9..357830c96ba39 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -7,7 +7,7 @@ import sklearn.externals._array_api_compat.numpy as array_api_compat_numpy from sklearn.externals._array_api_compat import device, size # noqa -from .._config import get_config +from .._config import get_config, config_context def _is_numpy_namespace(xp): @@ -143,25 +143,19 @@ def get_namespace(*arrays): is_array_api : bool True of the arrays are containers that implement the Array API spec. 
""" - return _get_namespace( - *arrays, array_api_dispatch=get_config()["array_api_dispatch"] - ) - - -def _get_namespace(*arrays, array_api_dispatch=False): + array_api_dispatch = get_config()["array_api_dispatch"] if not array_api_dispatch: return array_api_compat_numpy, False + try: - namespace, is_array = array_api_compat.get_namespace(*arrays), True - except TypeError as e: - if str(e).startswith("The input is not a supported array type"): - return array_api_compat_numpy, False - raise + namespace, is_array_api = array_api_compat.get_namespace(*arrays), True + except TypeError: + return array_api_compat_numpy, False if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: namespace = _ArrayAPIWrapper(namespace) - return namespace, is_array + return namespace, is_array_api def _expit(X): @@ -198,20 +192,21 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None): return xp.asarray(array, dtype=dtype, copy=copy) -def _convert_to_numpy(array, xp): +def _convert_to_numpy(array): """Convert X into a NumPy ndarray on the CPU.""" + with config_context(array_api_dispatch=True): + xp, _ = get_namespace(array) + xp_name = xp.__name__ - if _is_numpy_namespace(xp): - return numpy.asarray(array) - elif xp_name in {"sklearn.externals._array_api_compat.torch", "torch"}: + if xp_name in {"sklearn.externals._array_api_compat.torch", "torch"}: return array.cpu().numpy() elif xp_name == "cupy.array_api": return array._array.get() elif xp_name in {"sklearn.externals._array_api_compat.cupy", "cupy"}: return array.get() - else: - raise ValueError(f"{xp_name} is an unsupported namespace") + + return numpy.asarray(array) def _estimator_with_converted_arrays(estimator, converter): @@ -234,8 +229,9 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): - _, is_array = _get_namespace(attribute, array_api_dispatch=True) - if is_array: + with config_context(array_api_dispatch=True): + _, is_array_api = get_namespace(attribute) + if is_array_api: attribute = converter(attribute) setattr(new_estimator, key, attribute) return new_estimator diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index cb8cdeda32376..65d56fa77713d 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -145,17 +145,6 @@ def test_asarray_with_order_ignored(): assert not X_new_np.flags["F_CONTIGUOUS"] -def test_convert_to_numpy_error(): - """Test convert to numpy errors for unsupported namespaces.""" - xp = pytest.importorskip("numpy.array_api") - xp_ = _AdjustableNameAPITestWrapper(xp, "wrapped.array_api") - - X = xp_.asarray([1.2, 3.4]) - - with pytest.raises(ValueError, match="wrapped.array_api is an unsupported"): - _convert_to_numpy(X, xp=xp_) - - @pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"]) def test_convert_to_numpy_gpu(library): """Check convert_to_numpy for GPU backed libraries.""" @@ -168,7 +157,7 @@ def test_convert_to_numpy_gpu(library): else: X_gpu = xp.asarray([1.0, 2.0, 3.0]) - X_cpu = _convert_to_numpy(X_gpu, xp=xp) + X_cpu = _convert_to_numpy(X_gpu) expected_output = numpy.asarray([1.0, 2.0, 3.0]) assert_allclose(X_cpu, expected_output) @@ -178,7 +167,7 @@ def test_convert_to_numpy_cpu(): torch = pytest.importorskip("torch") X_torch = torch.asarray([1.0, 2.0, 3.0], device="cpu") - X_cpu = _convert_to_numpy(X_torch, xp=torch) + X_cpu = _convert_to_numpy(X_torch) expected_output = numpy.asarray([1.0, 2.0, 
3.0]) assert_allclose(X_cpu, expected_output) From 52fdd9935a98af7c344e0764b7a9c996e2b587b3 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 12:32:49 -0400 Subject: [PATCH 07/42] API Move array_api_compat to optional dependency --- maint_tools/vendor_array_api_compat.sh | 13 - .../externals/_array_api_compat/__init__.py | 22 - .../externals/_array_api_compat/_internal.py | 43 -- .../_array_api_compat/common/__init__.py | 1 - .../_array_api_compat/common/_aliases.py | 523 -------------- .../_array_api_compat/common/_helpers.py | 229 ------ .../_array_api_compat/common/_linalg.py | 146 ---- .../_array_api_compat/common/_typing.py | 20 - .../_array_api_compat/cupy/__init__.py | 16 - .../_array_api_compat/cupy/_aliases.py | 69 -- .../_array_api_compat/cupy/_typing.py | 46 -- .../_array_api_compat/cupy/linalg.py | 41 -- .../_array_api_compat/numpy/__init__.py | 22 - .../_array_api_compat/numpy/_aliases.py | 69 -- .../_array_api_compat/numpy/_typing.py | 46 -- .../_array_api_compat/numpy/linalg.py | 34 - .../_array_api_compat/torch/__init__.py | 22 - .../_array_api_compat/torch/_aliases.py | 666 ------------------ .../_array_api_compat/torch/linalg.py | 27 - sklearn/tests/test_discriminant_analysis.py | 3 + sklearn/utils/_array_api.py | 238 +++++-- sklearn/utils/_testing.py | 11 + sklearn/utils/tests/test_array_api.py | 60 +- 23 files changed, 250 insertions(+), 2117 deletions(-) delete mode 100755 maint_tools/vendor_array_api_compat.sh delete mode 100644 sklearn/externals/_array_api_compat/__init__.py delete mode 100644 sklearn/externals/_array_api_compat/_internal.py delete mode 100644 sklearn/externals/_array_api_compat/common/__init__.py delete mode 100644 sklearn/externals/_array_api_compat/common/_aliases.py delete mode 100644 sklearn/externals/_array_api_compat/common/_helpers.py delete mode 100644 sklearn/externals/_array_api_compat/common/_linalg.py delete mode 100644 sklearn/externals/_array_api_compat/common/_typing.py delete mode 100644 sklearn/externals/_array_api_compat/cupy/__init__.py delete mode 100644 sklearn/externals/_array_api_compat/cupy/_aliases.py delete mode 100644 sklearn/externals/_array_api_compat/cupy/_typing.py delete mode 100644 sklearn/externals/_array_api_compat/cupy/linalg.py delete mode 100644 sklearn/externals/_array_api_compat/numpy/__init__.py delete mode 100644 sklearn/externals/_array_api_compat/numpy/_aliases.py delete mode 100644 sklearn/externals/_array_api_compat/numpy/_typing.py delete mode 100644 sklearn/externals/_array_api_compat/numpy/linalg.py delete mode 100644 sklearn/externals/_array_api_compat/torch/__init__.py delete mode 100644 sklearn/externals/_array_api_compat/torch/_aliases.py delete mode 100644 sklearn/externals/_array_api_compat/torch/linalg.py diff --git a/maint_tools/vendor_array_api_compat.sh b/maint_tools/vendor_array_api_compat.sh deleted file mode 100755 index 7ce2b7995b72d..0000000000000 --- a/maint_tools/vendor_array_api_compat.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -# Vendors https://github.com/data-apis/array-api-compat/ into sklearn/externals - -ARRAY_API_COMPAT_SHA="b32a5b32892f5f4b5052ef54a04b8ed51936b008" -URL="https://github.com/data-apis/array-api-compat/archive/$ARRAY_API_COMPAT_SHA.tar.gz" - -rm -rf sklearn/externals/_array_api_compat - -curl -s -L $URL | - tar xvz --strip-components=1 -C sklearn/externals array-api-compat-$ARRAY_API_COMPAT_SHA/array_api_compat - -mv sklearn/externals/array_api_compat sklearn/externals/_array_api_compat diff --git 
a/sklearn/externals/_array_api_compat/__init__.py b/sklearn/externals/_array_api_compat/__init__.py deleted file mode 100644 index c92d3d89e3c63..0000000000000 --- a/sklearn/externals/_array_api_compat/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -NumPy Array API compatibility library - -This is a small wrapper around NumPy and CuPy that is compatible with the -Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 -https://numpy.org/neps/nep-0047-array-api-standard.html. - -Unlike numpy.array_api, this is not a strict minimal implementation of the -Array API, but rather just an extension of the main NumPy namespace with -changes needed to be compliant with the Array API. See -https://numpy.org/doc/stable/reference/array_api.html for a full list of -changes. In particular, unlike numpy.array_api, this package does not use a -separate Array object, but rather just uses numpy.ndarray directly. - -Library authors using the Array API may wish to test against numpy.array_api -to ensure they are not using functionality outside of the standard, but prefer -this implementation for the default when working with NumPy arrays. - -""" -__version__ = '1.1.1' - -from .common import * diff --git a/sklearn/externals/_array_api_compat/_internal.py b/sklearn/externals/_array_api_compat/_internal.py deleted file mode 100644 index 553c03561b45e..0000000000000 --- a/sklearn/externals/_array_api_compat/_internal.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Internal helpers -""" - -from functools import wraps -from inspect import signature - -def get_xp(xp): - """ - Decorator to automatically replace xp with the corresponding array module. - - Use like - - import numpy as np - - @get_xp(np) - def func(x, /, xp, kwarg=None): - return xp.func(x, kwarg=kwarg) - - Note that xp must be a keyword argument and come after all non-keyword - arguments. - - """ - def inner(f): - @wraps(f) - def wrapped_f(*args, **kwargs): - return f(*args, xp=xp, **kwargs) - - sig = signature(f) - new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) - - if wrapped_f.__doc__ is None: - wrapped_f.__doc__ = f"""\ -Array API compatibility wrapper for {f.__name__}. - -See the corresponding documentation in NumPy/CuPy and/or the array API -specification for more details. - -""" - wrapped_f.__signature__ = new_sig - return wrapped_f - - return inner diff --git a/sklearn/externals/_array_api_compat/common/__init__.py b/sklearn/externals/_array_api_compat/common/__init__.py deleted file mode 100644 index ce3f44dd486cb..0000000000000 --- a/sklearn/externals/_array_api_compat/common/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ._helpers import * diff --git a/sklearn/externals/_array_api_compat/common/_aliases.py b/sklearn/externals/_array_api_compat/common/_aliases.py deleted file mode 100644 index 87f0d766db03d..0000000000000 --- a/sklearn/externals/_array_api_compat/common/_aliases.py +++ /dev/null @@ -1,523 +0,0 @@ -""" -These are functions that are just aliases of existing functions in NumPy. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union, List - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol - -from typing import NamedTuple -from types import ModuleType -import inspect - -from ._helpers import _check_device, _is_numpy_array, array_namespace - -# These functions are modified from the NumPy versions. 
- -def arange( - start: Union[int, float], - /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, - *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs -) -> ndarray: - _check_device(xp, device) - return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) - -def empty( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs -) -> ndarray: - _check_device(xp, device) - return xp.empty(shape, dtype=dtype, **kwargs) - -def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: - _check_device(xp, device) - return xp.empty_like(x, dtype=dtype, **kwargs) - -def eye( - n_rows: int, - n_cols: Optional[int] = None, - /, - *, - xp, - k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) - -def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.full(shape, fill_value, dtype=dtype, **kwargs) - -def full_like( - x: ndarray, - /, - fill_value: Union[int, float], - *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.full_like(x, fill_value, dtype=dtype, **kwargs) - -def linspace( - start: Union[int, float], - stop: Union[int, float], - /, - num: int, - *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - endpoint: bool = True, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) - -def ones( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.ones(shape, dtype=dtype, **kwargs) - -def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.ones_like(x, dtype=dtype, **kwargs) - -def zeros( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.zeros(shape, dtype=dtype, **kwargs) - -def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs, -) -> ndarray: - _check_device(xp, device) - return xp.zeros_like(x, dtype=dtype, **kwargs) - -# np.unique() is split into four functions in the array API: -# unique_all, unique_counts, unique_inverse, and unique_values (this is done -# to remove polymorphic return types). - -# The functions here return namedtuples (np.unique() returns a normal -# tuple). -class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray - - -class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray - - -class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray - - -def _unique_kwargs(xp): - # Older versions of NumPy and CuPy do not have equal_nan. Rather than - # trying to parse version numbers, just check if equal_nan is in the - # signature. 
- s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} - return {} - -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: - kwargs = _unique_kwargs(xp) - values, indices, inverse_indices, counts = xp.unique( - x, - return_counts=True, - return_index=True, - return_inverse=True, - **kwargs, - ) - # np.unique() flattens inverse indices, but they need to share x's shape - # See https://github.com/numpy/numpy/issues/20638 - inverse_indices = inverse_indices.reshape(x.shape) - return UniqueAllResult( - values, - indices, - inverse_indices, - counts, - ) - - -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: - kwargs = _unique_kwargs(xp) - res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs - ) - - return UniqueCountsResult(*res) - - -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: - kwargs = _unique_kwargs(xp) - values, inverse_indices = xp.unique( - x, - return_counts=False, - return_index=False, - return_inverse=True, - **kwargs, - ) - # xp.unique() flattens inverse indices, but they need to share x's shape - # See https://github.com/numpy/numpy/issues/20638 - inverse_indices = inverse_indices.reshape(x.shape) - return UniqueInverseResult(values, inverse_indices) - - -def unique_values(x: ndarray, /, xp) -> ndarray: - kwargs = _unique_kwargs(xp) - return xp.unique( - x, - return_counts=False, - return_index=False, - return_inverse=False, - **kwargs, - ) - -def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: - if not copy and dtype == x.dtype: - return x - return x.astype(dtype=dtype, copy=copy) - -# These functions have different keyword argument names - -def std( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof - keepdims: bool = False, - **kwargs, -) -> ndarray: - return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) - -def var( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof - keepdims: bool = False, - **kwargs, -) -> ndarray: - return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) - -# Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: - return xp.transpose(x, axes) - -# Creation functions add the device keyword (which does nothing for NumPy) - -# asarray also adds the copy keyword -def _asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], - /, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, - namespace = None, - **kwargs, -) -> ndarray: - """ - Array API compatibility wrapper for asarray(). - - See the corresponding documentation in NumPy/CuPy and/or the array API - specification for more details. - - """ - if namespace is None: - try: - xp = array_namespace(obj, _use_compat=False) - except ValueError: - # TODO: What about lists of arrays? 
- raise ValueError("A namespace must be specified for asarray() with non-array input") - elif isinstance(namespace, ModuleType): - xp = namespace - elif namespace == 'numpy': - import numpy as xp - elif namespace == 'cupy': - import cupy as xp - else: - raise ValueError("Unrecognized namespace argument to asarray()") - - _check_device(xp, device) - if _is_numpy_array(obj): - import numpy as np - if hasattr(np, '_CopyMode'): - # Not present in older NumPys - COPY_FALSE = (False, np._CopyMode.IF_NEEDED) - COPY_TRUE = (True, np._CopyMode.ALWAYS) - else: - COPY_FALSE = (False,) - COPY_TRUE = (True,) - else: - COPY_FALSE = (False,) - COPY_TRUE = (True,) - if copy in COPY_FALSE: - # copy=False is not yet implemented in xp.asarray - raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, xp.ndarray): - if dtype is not None and obj.dtype != dtype: - copy = True - if copy in COPY_TRUE: - return xp.array(obj, copy=True, dtype=dtype) - return obj - - return xp.asarray(obj, dtype=dtype, **kwargs) - -# xp.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: - if copy is True: - x = x.copy() - elif copy is False: - x.shape = shape - return x - return xp.reshape(x, shape, **kwargs) - -# The descending keyword is new in sort and argsort, and 'kind' replaced with -# 'stable' -def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, -) -> ndarray: - # Note: this keyword argument is different, and the default is different. - # We set it in kwargs like this because numpy.sort uses kind='quicksort' - # as the default whereas cupy.sort uses kind=None. - if stable: - kwargs['kind'] = "stable" - if not descending: - res = xp.argsort(x, axis=axis, **kwargs) - else: - # As NumPy has no native descending sort, we imitate it here. Note that - # simply flipping the results of xp.argsort(x, ...) would not - # respect the relative order like it would in native descending sorts. - res = xp.flip( - xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs), - axis=axis, - ) - # Rely on flip()/argsort() to validate axis - normalised_axis = axis if axis >= 0 else x.ndim + axis - max_i = x.shape[normalised_axis] - 1 - res = max_i - res - return res - -def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, -) -> ndarray: - # Note: this keyword argument is different, and the default is different. - # We set it in kwargs like this because numpy.sort uses kind='quicksort' - # as the default whereas cupy.sort uses kind=None. 
- if stable: - kwargs['kind'] = "stable" - res = xp.sort(x, axis=axis, **kwargs) - if descending: - res = xp.flip(res, axis=axis) - return res - -# sum() and prod() should always upcast when dtype=None -def sum( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - # `xp.sum` already upcasts integers, but not floats - if dtype is None and x.dtype == xp.float32: - dtype = xp.float64 - return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - -def prod( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - if dtype is None and x.dtype == xp.float32: - dtype = xp.float64 - return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs) - -# ceil, floor, and trunc return integers for integer inputs - -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) - -# linear algebra functions - -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: - return xp.matmul(x1, x2, **kwargs) - -# Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return xp.swapaxes(x, -1, -2) - -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: - return xp.tensordot(x1, x2, axes=axes, **kwargs) - -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - if hasattr(xp, 'broadcast_tensors'): - _broadcast = xp.broadcast_tensors - else: - _broadcast = xp.broadcast_arrays - - x1_, x2_ = _broadcast(x1, x2) - x1_ = xp.moveaxis(x1_, axis, -1) - x2_ = xp.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return res[..., 0, 0] - -# isdtype is a new function in the 2022.12 array API specification. - -def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples -) -> bool: - """ - Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. - - Note that outside of this function, this compat library does not yet fully - support complex numbers. 
- - See - https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html - for more details - """ - if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) - elif isinstance(kind, str): - if kind == 'bool': - return dtype == xp.bool_ - elif kind == 'signed integer': - return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': - return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': - return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': - return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': - return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': - return xp.issubdtype(dtype, xp.number) - else: - raise ValueError(f"Unrecognized data type kind: {kind!r}") - else: - # This will allow things that aren't required by the spec, like - # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be - # more strict here to match the type annotation? Note that the - # numpy.array_api implementation will be very strict. - return dtype == kind - -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', - 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', - 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/sklearn/externals/_array_api_compat/common/_helpers.py b/sklearn/externals/_array_api_compat/common/_helpers.py deleted file mode 100644 index e6adc948522bd..0000000000000 --- a/sklearn/externals/_array_api_compat/common/_helpers.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Various helper functions which are not part of the spec. - -Functions which start with an underscore are for internal use only but helpers -that are in __all__ are intended as additional helper functions for use by end -users of the compat library. -""" -from __future__ import annotations - -import sys -import math - -def _is_numpy_array(x): - # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: - return False - - import numpy as np - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, (np.ndarray, np.generic)) - -def _is_cupy_array(x): - # Avoid importing NumPy if it isn't already - if 'cupy' not in sys.modules: - return False - - import cupy as cp - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, (cp.ndarray, cp.generic)) - -def _is_torch_array(x): - # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) - -def is_array_api_obj(x): - """ - Check if x is an array API compatible array object. - """ - return _is_numpy_array(x) \ - or _is_cupy_array(x) \ - or _is_torch_array(x) \ - or hasattr(x, '__array_namespace__') - -def _check_api_version(api_version): - if api_version is not None and api_version != '2021.12': - raise ValueError("Only the 2021.12 version of the array API specification is currently supported") - -def array_namespace(*xs, api_version=None, _use_compat=True): - """ - Get the array API compatible namespace for the arrays `xs`. - - `xs` should contain one or more arrays. 
- - Typical usage is - - def your_function(x, y): - xp = array_api_compat.array_namespace(x, y) - # Now use xp as the array library namespace - return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) - - api_version should be the newest version of the spec that you need support - for (currently the compat library wrapped APIs only support v2021.12). - """ - namespaces = set() - for x in xs: - if isinstance(x, (tuple, list)): - namespaces.add(array_namespace(*x, _use_compat=_use_compat)) - elif hasattr(x, '__array_namespace__'): - namespaces.add(x.__array_namespace__(api_version=api_version)) - elif _is_numpy_array(x): - _check_api_version(api_version) - if _use_compat: - from .. import numpy as numpy_namespace - namespaces.add(numpy_namespace) - else: - import numpy as np - namespaces.add(np) - elif _is_cupy_array(x): - _check_api_version(api_version) - if _use_compat: - from .. import cupy as cupy_namespace - namespaces.add(cupy_namespace) - else: - import cupy as cp - namespaces.add(cp) - elif _is_torch_array(x): - _check_api_version(api_version) - if _use_compat: - from .. import torch as torch_namespace - namespaces.add(torch_namespace) - else: - import torch - namespaces.add(torch) - else: - # TODO: Support Python scalars? - raise TypeError("The input is not a supported array type") - - if not namespaces: - raise TypeError("Unrecognized array input") - - if len(namespaces) != 1: - raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - - xp, = namespaces - - return xp - -# backwards compatibility alias -get_namespace = array_namespace - -def _check_device(xp, device): - if xp == sys.modules.get('numpy'): - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device for NumPy: {device!r}") - -# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray -# or cupy.ndarray. They are not included in array objects of this library -# because this library just reuses the respective ndarray classes without -# wrapping or subclassing them. These helper functions can be used instead of -# the wrapper functions for libraries that need to support both NumPy/CuPy and -# other libraries that use devices. -def device(x: "Array", /) -> "Device": - """ - Hardware device the array data resides on. - - Parameters - ---------- - x: array - array instance from NumPy or an array API compatible library. - - Returns - ------- - out: device - a ``device`` object (see the "Device Support" section of the array API specification). 
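``array_namespace`` is the function downstream code is expected to call: it inspects its arguments, insists on exactly one array library, and returns the matching compat namespace (or the raw module when ``_use_compat=False``). Consumer code then stays library-agnostic; a sketch, shown with a NumPy input but identical for CuPy arrays or PyTorch tensors::

    import numpy as np
    from array_api_compat import array_namespace  # upstream package name

    def standardize(x):
        xp = array_namespace(x)  # resolves the namespace from x itself
        return (x - xp.mean(x, axis=0)) / xp.std(x, axis=0)

    print(standardize(np.asarray([[1.0, 2.0], [3.0, 5.0]])))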
- """ - if _is_numpy_array(x): - return "cpu" - return x.device - -# Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): - import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime - - if device == x.device: - return x - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None - if stream is not None: - prev_stream = stream_module.get_current_stream() - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): - pass - else: - raise ValueError('the input stream is not recognized') - stream.use() - try: - runtime.setDevice(device.id) - arr = x.copy() - finally: - runtime.setDevice(prev_device) - if stream is not None: - prev_stream.use() - return arr - -def _torch_to_device(x, device, /, stream=None): - if stream is not None: - raise NotImplementedError - return x.to(device) - -def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": - """ - Copy the array from the device on which it currently resides to the specified ``device``. - - Parameters - ---------- - x: array - array instance from NumPy or an array API compatible library. - device: device - a ``device`` object (see the "Device Support" section of the array API specification). - stream: Optional[Union[int, Any]] - stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. - - Returns - ------- - out: array - an array with the same data and data type as ``x`` and located on the specified ``device``. - - .. note:: - If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. 
- """ - if _is_numpy_array(x): - if stream is not None: - raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': - return x - raise ValueError(f"Unsupported device {device!r}") - elif _is_cupy_array(x): - # cupy does not yet have to_device - return _cupy_to_device(x, device, stream=stream) - elif _is_torch_array(x): - return _torch_to_device(x, device, stream=stream) - return x.to_device(device, stream=stream) - -def size(x): - """ - Return the total number of elements of x - """ - if None in x.shape: - return None - return math.prod(x.shape) - -__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] diff --git a/sklearn/externals/_array_api_compat/common/_linalg.py b/sklearn/externals/_array_api_compat/common/_linalg.py deleted file mode 100644 index 07daefd9cfd99..0000000000000 --- a/sklearn/externals/_array_api_compat/common/_linalg.py +++ /dev/null @@ -1,146 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Sequence, Tuple, Union - from ._typing import ndarray - -from numpy.core.numeric import normalize_axis_tuple - -from ._aliases import matmul, matrix_transpose, tensordot, vecdot -from .._internal import get_xp - -# These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: - return xp.cross(x1, x2, axis=axis, **kwargs) - -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: - return xp.outer(x1, x2, **kwargs) - -class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray - -class QRResult(NamedTuple): - Q: ndarray - R: ndarray - -class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray - -class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray - -# These functions are the same as their NumPy counterparts except they return -# a namedtuple. -def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: - return EighResult(*xp.linalg.eigh(x, **kwargs)) - -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: - return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) - -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: - return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) - -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: - return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) - -# These functions have additional keyword arguments - -# The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: - L = xp.linalg.cholesky(x, **kwargs) - if upper: - return get_xp(xp)(matrix_transpose)(L) - return L - -# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. -# Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, - /, - xp, - *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: - # this is different from xp.linalg.matrix_rank, which supports 1 - # dimensional arrays. - if x.ndim < 2: - raise xp.linalg.LinAlgError("1-dimensional array given. 
Array must be at least two-dimensional") - S = xp.linalg.svd(x, compute_uv=False, **kwargs) - if rtol is None: - tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps - else: - # this is different from xp.linalg.matrix_rank, which does not - # multiply the tolerance by the largest singular value. - tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] - return xp.count_nonzero(S > tol, axis=-1) - -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: - # this is different from xp.linalg.pinv, which does not multiply the - # default tolerance by max(M, N). - if rtol is None: - rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps - return xp.linalg.pinv(x, rcond=rtol, **kwargs) - -# These functions are new in the array API spec - -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: - return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) - -# svdvals is not in NumPy (but it is in SciPy). It is equivalent to -# xp.linalg.svd(compute_uv=False). -def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: - return xp.linalg.svd(x, compute_uv=False) - -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: - # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or - # when axis=None and the input is 2-D, so to force a vector norm, we make - # it so the input is 1-D (for axis=None), or reshape so that norm is done - # on a single dimension. - if axis is None: - # Note: xp.linalg.norm() doesn't handle 0-D arrays - x = x.ravel() - _axis = 0 - elif isinstance(axis, tuple): - # Note: The axis argument supports any number of axes, whereas - # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) - rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) - newshape = axis + rest - x = xp.transpose(x, newshape).reshape( - (xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) - _axis = 0 - else: - _axis = axis - - res = xp.linalg.norm(x, axis=_axis, ord=ord) - - if keepdims: - # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks - # above to avoid matrix norm logic. 
- shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) - for i in _axis: - shape[i] = 1 - res = xp.reshape(res, tuple(shape)) - - return res - -# xp.diagonal and xp.trace operate on the first two axes whereas these -# operates on the last two - -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: - return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) - -def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs)) - -__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', - 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', - 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', - 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', - 'trace'] diff --git a/sklearn/externals/_array_api_compat/common/_typing.py b/sklearn/externals/_array_api_compat/common/_typing.py deleted file mode 100644 index 3f17806094baa..0000000000000 --- a/sklearn/externals/_array_api_compat/common/_typing.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -__all__ = [ - "NestedSequence", - "SupportsBufferProtocol", -] - -from typing import ( - Any, - TypeVar, - Protocol, -) - -_T_co = TypeVar("_T_co", covariant=True) - -class NestedSequence(Protocol[_T_co]): - def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... - def __len__(self, /) -> int: ... - -SupportsBufferProtocol = Any diff --git a/sklearn/externals/_array_api_compat/cupy/__init__.py b/sklearn/externals/_array_api_compat/cupy/__init__.py deleted file mode 100644 index 10c31bc6aad3a..0000000000000 --- a/sklearn/externals/_array_api_compat/cupy/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from cupy import * - -# from cupy import * doesn't overwrite these builtin names -from cupy import abs, max, min, round - -# These imports may overwrite names from the import * above. 
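Each per-library subpackage, starting with the CuPy one here, follows the same layering: re-export the native namespace wholesale, pull in the builtin-shadowing names that ``import *`` skips, then overwrite the functions whose Array API names or semantics differ. Conceptually (a sketch using NumPy for concreteness)::

    # Later bindings shadow earlier ones.
    from numpy import *                      # 1. everything NumPy exports
    from numpy import abs, max, min, round   # 2. builtins that `import *` skips

    import numpy as np
    concat = np.concatenate                  # 3. Array API renames
    pow = np.power
    bool = np.bool_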
-from ._aliases import * - -# See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') - -from .linalg import matrix_transpose, vecdot - -from ..common._helpers import * - -__array_api_version__ = '2021.12' diff --git a/sklearn/externals/_array_api_compat/cupy/_aliases.py b/sklearn/externals/_array_api_compat/cupy/_aliases.py deleted file mode 100644 index b43c371f34f43..0000000000000 --- a/sklearn/externals/_array_api_compat/cupy/_aliases.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ..common import _aliases - -from .._internal import get_xp - -asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') -asarray.__doc__ = _aliases._asarray.__doc__ -del partial - -import cupy as cp -bool = cp.bool_ - -# Basic renames -acos = cp.arccos -acosh = cp.arccosh -asin = cp.arcsin -asinh = cp.arcsinh -atan = cp.arctan -atan2 = cp.arctan2 -atanh = cp.arctanh -bitwise_left_shift = cp.left_shift -bitwise_invert = cp.invert -bitwise_right_shift = cp.right_shift -concat = cp.concatenate -pow = cp.power - -arange = get_xp(cp)(_aliases.arange) -empty = get_xp(cp)(_aliases.empty) -empty_like = get_xp(cp)(_aliases.empty_like) -eye = get_xp(cp)(_aliases.eye) -full = get_xp(cp)(_aliases.full) -full_like = get_xp(cp)(_aliases.full_like) -linspace = get_xp(cp)(_aliases.linspace) -ones = get_xp(cp)(_aliases.ones) -ones_like = get_xp(cp)(_aliases.ones_like) -zeros = get_xp(cp)(_aliases.zeros) -zeros_like = get_xp(cp)(_aliases.zeros_like) -UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult) -UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult) -UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult) -unique_all = get_xp(cp)(_aliases.unique_all) -unique_counts = get_xp(cp)(_aliases.unique_counts) -unique_inverse = get_xp(cp)(_aliases.unique_inverse) -unique_values = get_xp(cp)(_aliases.unique_values) -astype = _aliases.astype -std = get_xp(cp)(_aliases.std) -var = get_xp(cp)(_aliases.var) -permute_dims = get_xp(cp)(_aliases.permute_dims) -reshape = get_xp(cp)(_aliases.reshape) -argsort = get_xp(cp)(_aliases.argsort) -sort = get_xp(cp)(_aliases.sort) -sum = get_xp(cp)(_aliases.sum) -prod = get_xp(cp)(_aliases.prod) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) -matmul = get_xp(cp)(_aliases.matmul) -matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) -tensordot = get_xp(cp)(_aliases.tensordot) -vecdot = get_xp(cp)(_aliases.vecdot) -isdtype = get_xp(cp)(_aliases.isdtype) - -__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] diff --git a/sklearn/externals/_array_api_compat/cupy/_typing.py b/sklearn/externals/_array_api_compat/cupy/_typing.py deleted file mode 100644 index f3d9aab67e52f..0000000000000 --- a/sklearn/externals/_array_api_compat/cupy/_typing.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -__all__ = [ - "ndarray", - "Device", - "Dtype", -] - -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) - -from cupy.cuda.device import Device - -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - 
]] -else: - Dtype = dtype diff --git a/sklearn/externals/_array_api_compat/cupy/linalg.py b/sklearn/externals/_array_api_compat/cupy/linalg.py deleted file mode 100644 index 99c4cc68d783c..0000000000000 --- a/sklearn/externals/_array_api_compat/cupy/linalg.py +++ /dev/null @@ -1,41 +0,0 @@ -from cupy.linalg import * -# cupy.linalg doesn't have __all__. If it is added, replace this with -# -# from cupy.linalg import __all__ as linalg_all -_n = {} -exec('from cupy.linalg import *', _n) -del _n['__builtins__'] -linalg_all = list(_n) -del _n - -from ..common import _linalg -from .._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) - -import cupy as cp - -cross = get_xp(cp)(_linalg.cross) -outer = get_xp(cp)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(cp)(_linalg.eigh) -qr = get_xp(cp)(_linalg.qr) -slogdet = get_xp(cp)(_linalg.slogdet) -svd = get_xp(cp)(_linalg.svd) -cholesky = get_xp(cp)(_linalg.cholesky) -matrix_rank = get_xp(cp)(_linalg.matrix_rank) -pinv = get_xp(cp)(_linalg.pinv) -matrix_norm = get_xp(cp)(_linalg.matrix_norm) -svdvals = get_xp(cp)(_linalg.svdvals) -vector_norm = get_xp(cp)(_linalg.vector_norm) -diagonal = get_xp(cp)(_linalg.diagonal) -trace = get_xp(cp)(_linalg.trace) - -__all__ = linalg_all + _linalg.__all__ - -del get_xp -del cp -del linalg_all -del _linalg diff --git a/sklearn/externals/_array_api_compat/numpy/__init__.py b/sklearn/externals/_array_api_compat/numpy/__init__.py deleted file mode 100644 index 745367bc8705e..0000000000000 --- a/sklearn/externals/_array_api_compat/numpy/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from numpy import * - -# from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round - -# These imports may overwrite names from the import * above. -from ._aliases import * - -# Don't know why, but we have to do an absolute import to import linalg. If we -# instead do -# -# from . import linalg -# -# It doesn't overwrite np.linalg from above. The import is generated -# dynamically so that the library can be vendored. 
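The dynamically built import that follows is what makes vendoring possible: ``__package__`` evaluates to whatever dotted path the package currently lives under, so the same source resolves to ``array_api_compat.numpy.linalg`` upstream and to ``sklearn.externals._array_api_compat.numpy.linalg`` once copied into scikit-learn::

    # Inside the package's __init__.py, this is equivalent to an absolute
    # `import <package>.linalg`, but the package name is resolved at runtime,
    # so it survives being vendored under a different top-level name.
    __import__(__package__ + '.linalg')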
-__import__(__package__ + '.linalg') - -from .linalg import matrix_transpose, vecdot - -from ..common._helpers import * - -__array_api_version__ = '2021.12' diff --git a/sklearn/externals/_array_api_compat/numpy/_aliases.py b/sklearn/externals/_array_api_compat/numpy/_aliases.py deleted file mode 100644 index 08f4de0bafeeb..0000000000000 --- a/sklearn/externals/_array_api_compat/numpy/_aliases.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ..common import _aliases - -from .._internal import get_xp - -asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') -asarray.__doc__ = _aliases._asarray.__doc__ -del partial - -import numpy as np -bool = np.bool_ - -# Basic renames -acos = np.arccos -acosh = np.arccosh -asin = np.arcsin -asinh = np.arcsinh -atan = np.arctan -atan2 = np.arctan2 -atanh = np.arctanh -bitwise_left_shift = np.left_shift -bitwise_invert = np.invert -bitwise_right_shift = np.right_shift -concat = np.concatenate -pow = np.power - -arange = get_xp(np)(_aliases.arange) -empty = get_xp(np)(_aliases.empty) -empty_like = get_xp(np)(_aliases.empty_like) -eye = get_xp(np)(_aliases.eye) -full = get_xp(np)(_aliases.full) -full_like = get_xp(np)(_aliases.full_like) -linspace = get_xp(np)(_aliases.linspace) -ones = get_xp(np)(_aliases.ones) -ones_like = get_xp(np)(_aliases.ones_like) -zeros = get_xp(np)(_aliases.zeros) -zeros_like = get_xp(np)(_aliases.zeros_like) -UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult) -UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult) -UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult) -unique_all = get_xp(np)(_aliases.unique_all) -unique_counts = get_xp(np)(_aliases.unique_counts) -unique_inverse = get_xp(np)(_aliases.unique_inverse) -unique_values = get_xp(np)(_aliases.unique_values) -astype = _aliases.astype -std = get_xp(np)(_aliases.std) -var = get_xp(np)(_aliases.var) -permute_dims = get_xp(np)(_aliases.permute_dims) -reshape = get_xp(np)(_aliases.reshape) -argsort = get_xp(np)(_aliases.argsort) -sort = get_xp(np)(_aliases.sort) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) -matmul = get_xp(np)(_aliases.matmul) -matrix_transpose = get_xp(np)(_aliases.matrix_transpose) -tensordot = get_xp(np)(_aliases.tensordot) -vecdot = get_xp(np)(_aliases.vecdot) -isdtype = get_xp(np)(_aliases.isdtype) - -__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] diff --git a/sklearn/externals/_array_api_compat/numpy/_typing.py b/sklearn/externals/_array_api_compat/numpy/_typing.py deleted file mode 100644 index c5ebb5abb9875..0000000000000 --- a/sklearn/externals/_array_api_compat/numpy/_typing.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -__all__ = [ - "ndarray", - "Device", - "Dtype", -] - -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) - -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) - -Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] -else: - Dtype = dtype diff --git 
a/sklearn/externals/_array_api_compat/numpy/linalg.py b/sklearn/externals/_array_api_compat/numpy/linalg.py deleted file mode 100644 index 26d6e88e1af47..0000000000000 --- a/sklearn/externals/_array_api_compat/numpy/linalg.py +++ /dev/null @@ -1,34 +0,0 @@ -from numpy.linalg import * -from numpy.linalg import __all__ as linalg_all - -from ..common import _linalg -from .._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) - -import numpy as np - -cross = get_xp(np)(_linalg.cross) -outer = get_xp(np)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(np)(_linalg.eigh) -qr = get_xp(np)(_linalg.qr) -slogdet = get_xp(np)(_linalg.slogdet) -svd = get_xp(np)(_linalg.svd) -cholesky = get_xp(np)(_linalg.cholesky) -matrix_rank = get_xp(np)(_linalg.matrix_rank) -pinv = get_xp(np)(_linalg.pinv) -matrix_norm = get_xp(np)(_linalg.matrix_norm) -svdvals = get_xp(np)(_linalg.svdvals) -vector_norm = get_xp(np)(_linalg.vector_norm) -diagonal = get_xp(np)(_linalg.diagonal) -trace = get_xp(np)(_linalg.trace) - -__all__ = linalg_all + _linalg.__all__ - -del get_xp -del np -del linalg_all -del _linalg diff --git a/sklearn/externals/_array_api_compat/torch/__init__.py b/sklearn/externals/_array_api_compat/torch/__init__.py deleted file mode 100644 index 18776f1a0f73b..0000000000000 --- a/sklearn/externals/_array_api_compat/torch/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from torch import * - -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(n + ' = torch.' + n) - -# These imports may overwrite names from the import * above. 
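``from torch import *`` only binds the names torch lists in ``__all__``, so the loop in ``torch/__init__.py`` above walks ``dir(torch)`` and re-binds everything that is not private, not an in-place variant (trailing underscore), and not a cuda/cpu/backward helper. An equivalent without ``exec``, as a sketch to clarify what the filter keeps::

    import torch

    def _skip(name):
        # Mirror the loop's exclusions: dunders/privates, in-place ops,
        # and device/autograd helpers.
        return (name.startswith('_') or name.endswith('_')
                or 'cuda' in name or 'cpu' in name or 'backward' in name)

    globals().update({n: getattr(torch, n) for n in dir(torch) if not _skip(n)})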
-from ._aliases import * - -# See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') - -from ..common._helpers import * - -__array_api_version__ = '2021.12' diff --git a/sklearn/externals/_array_api_compat/torch/_aliases.py b/sklearn/externals/_array_api_compat/torch/_aliases.py deleted file mode 100644 index dbd4d8d9dccfa..0000000000000 --- a/sklearn/externals/_array_api_compat/torch/_aliases.py +++ /dev/null @@ -1,666 +0,0 @@ -from __future__ import annotations - -from functools import wraps -from builtins import all as builtin_all, any as builtin_any - -from ..common._aliases import (UniqueAllResult, UniqueCountsResult, - UniqueInverseResult, - matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot) -from .._internal import get_xp - -import torch - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor - -_int_dtypes = { - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, -} - -_array_api_dtypes = { - torch.bool, - *_int_dtypes, - torch.float32, - torch.float64, -} - -_promotion_table = { - # bool - (torch.bool, torch.bool): torch.bool, - # ints - (torch.int8, torch.int8): torch.int8, - (torch.int8, torch.int16): torch.int16, - (torch.int8, torch.int32): torch.int32, - (torch.int8, torch.int64): torch.int64, - (torch.int16, torch.int8): torch.int16, - (torch.int16, torch.int16): torch.int16, - (torch.int16, torch.int32): torch.int32, - (torch.int16, torch.int64): torch.int64, - (torch.int32, torch.int8): torch.int32, - (torch.int32, torch.int16): torch.int32, - (torch.int32, torch.int32): torch.int32, - (torch.int32, torch.int64): torch.int64, - (torch.int64, torch.int8): torch.int64, - (torch.int64, torch.int16): torch.int64, - (torch.int64, torch.int32): torch.int64, - (torch.int64, torch.int64): torch.int64, - # uints - (torch.uint8, torch.uint8): torch.uint8, - # ints and uints (mixed sign) - (torch.int8, torch.uint8): torch.int16, - (torch.int16, torch.uint8): torch.int16, - (torch.int32, torch.uint8): torch.int32, - (torch.int64, torch.uint8): torch.int64, - (torch.uint8, torch.int8): torch.int16, - (torch.uint8, torch.int16): torch.int16, - (torch.uint8, torch.int32): torch.int32, - (torch.uint8, torch.int64): torch.int64, - # floats - (torch.float32, torch.float32): torch.float32, - (torch.float32, torch.float64): torch.float64, - (torch.float64, torch.float32): torch.float64, - (torch.float64, torch.float64): torch.float64, -} - - -def _two_arg(f): - @wraps(f) - def _f(x1, x2, /, **kwargs): - x1, x2 = _fix_promotion(x1, x2) - return f(x1, x2, **kwargs) - if _f.__doc__ is None: - _f.__doc__ = f"""\ -Array API compatibility wrapper for torch.{f.__name__}. - -See the corresponding PyTorch documentation and/or the array API specification -for more details. 
- -""" - return _f - -def _fix_promotion(x1, x2, only_scalar=True): - if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes: - return x1, x2 - # If an argument is 0-D pytorch downcasts the other argument - if not only_scalar or x1.shape == (): - dtype = result_type(x1, x2) - x2 = x2.to(dtype) - if not only_scalar or x2.shape == (): - dtype = result_type(x1, x2) - x1 = x1.to(dtype) - return x1, x2 - -def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: - if len(arrays_and_dtypes) == 0: - raise TypeError("At least one array or dtype must be provided") - if len(arrays_and_dtypes) == 1: - x = arrays_and_dtypes[0] - if isinstance(x, torch.dtype): - return x - return x.dtype - if len(arrays_and_dtypes) > 2: - return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) - - x, y = arrays_and_dtypes - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y - - if (xdt, ydt) in _promotion_table: - return _promotion_table[xdt, ydt] - - # This doesn't result_type(dtype, dtype) for non-array API dtypes - # because torch.result_type only accepts tensors. This does however, allow - # cross-kind promotion. - return torch.result_type(x, y) - -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: - if not isinstance(from_, torch.dtype): - from_ = from_.dtype - return torch.can_cast(from_, to) - -# Basic renames -permute_dims = torch.permute -bitwise_invert = torch.bitwise_not - -# Two-arg elementwise functions -# These require a wrapper to do the correct type promotion on 0-D tensors -add = _two_arg(torch.add) -atan2 = _two_arg(torch.atan2) -bitwise_and = _two_arg(torch.bitwise_and) -bitwise_left_shift = _two_arg(torch.bitwise_left_shift) -bitwise_or = _two_arg(torch.bitwise_or) -bitwise_right_shift = _two_arg(torch.bitwise_right_shift) -bitwise_xor = _two_arg(torch.bitwise_xor) -divide = _two_arg(torch.divide) -# Also a rename. torch.equal does not broadcast -equal = _two_arg(torch.eq) -floor_divide = _two_arg(torch.floor_divide) -greater = _two_arg(torch.greater) -greater_equal = _two_arg(torch.greater_equal) -less = _two_arg(torch.less) -less_equal = _two_arg(torch.less_equal) -logaddexp = _two_arg(torch.logaddexp) -# logical functions are not included here because they only accept bool in the -# spec, so type promotion is irrelevant. -multiply = _two_arg(torch.multiply) -not_equal = _two_arg(torch.not_equal) -pow = _two_arg(torch.pow) -remainder = _two_arg(torch.remainder) -subtract = _two_arg(torch.subtract) - -# These wrappers are mostly based on the fact that pytorch uses 'dim' instead -# of 'axis'. 
- -# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - # https://github.com/pytorch/pytorch/issues/29137 - if axis == (): - return torch.clone(x) - return torch.amax(x, axis, keepdims=keepdims) - -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - # https://github.com/pytorch/pytorch/issues/29137 - if axis == (): - return torch.clone(x) - return torch.amin(x, axis, keepdims=keepdims) - -# torch.sort also returns a tuple -# https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: - return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values - -def _normalize_axes(axis, ndim): - axes = [] - if ndim == 0 and axis: - # Better error message in this case - raise IndexError(f"Dimension out of range: {axis[0]}") - lower, upper = -ndim, ndim - 1 - for a in axis: - if a < lower or a > upper: - # Match torch error message (e.g., from sum()) - raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}") - if a < 0: - a = a + ndim - if a in axes: - # Use IndexError instead of RuntimeError, and "axis" instead of "dim" - raise IndexError(f"Axis {a} appears multiple times in the list of axes") - axes.append(a) - return sorted(axes) - -def _axis_none_keepdims(x, ndim, keepdims): - # Apply keepdims when axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - # Note that this is only valid for the axis=None case. - if keepdims: - for i in range(ndim): - x = torch.unsqueeze(x, 0) - return x - -def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): - # Some reductions don't support multiple axes - # (https://github.com/pytorch/pytorch/issues/56586). - axes = _normalize_axes(axis, x.ndim) - for a in reversed(axes): - x = torch.movedim(x, a, -1) - x = torch.flatten(x, -len(axes)) - - out = f(x, -1, **kwargs) - - if keepdims: - for a in axes: - out = torch.unsqueeze(out, a) - return out - -def prod(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim - - # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic - # below because it still needs to upcast. - if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - - # torch.prod doesn't support multiple axes - # (https://github.com/pytorch/pytorch/issues/56586). 
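``_reduce_multiple_axes`` (defined above) is the workaround used next: move every requested axis to the end, flatten them into one trailing axis, reduce once, then unsqueeze the reduced positions back if ``keepdims`` is set. Step by step, for a product over axes ``(0, 2)``::

    import torch

    x = torch.arange(1.0, 25.0).reshape(2, 3, 4)

    t = torch.movedim(x, 2, -1)   # axes are handled in reverse order; a no-op here
    t = torch.movedim(t, 0, -1)   # shape (3, 4, 2)
    t = torch.flatten(t, -2)      # shape (3, 8): both reduced axes merged into one
    res = torch.prod(t, -1)       # equals a prod over axes (0, 2) of x
    print(res.shape)              # torch.Size([3])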
- if isinstance(axis, tuple): - return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs) - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.prod(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) - return res - - return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) - - -def sum(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim - - # https://github.com/pytorch/pytorch/issues/29137. - # Make sure it upcasts. - if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.sum(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) - return res - - return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) - -def any(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim - if axis == (): - return x.to(torch.bool) - # torch.any doesn't support multiple axes - # (https://github.com/pytorch/pytorch/issues/56586). - if isinstance(axis, tuple): - res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs) - return res.to(torch.bool) - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.any(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) - return res.to(torch.bool) - - # torch.any doesn't return bool for uint8 - return torch.any(x, axis, keepdims=keepdims).to(torch.bool) - -def all(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim - if axis == (): - return x.to(torch.bool) - # torch.all doesn't support multiple axes - # (https://github.com/pytorch/pytorch/issues/56586). 
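``any`` and ``all`` (whose body continues below) paper over the same three quirks: ``axis=()`` must be an identity reduction (pytorch#29137), tuple axes are unsupported (pytorch#56586), and the result must be boolean even where torch returned ``uint8`` for ``uint8`` input. The final cast in isolation; the native result dtype depends on the torch version, which is why the wrappers cast unconditionally::

    import torch

    x = torch.asarray([0, 3], dtype=torch.uint8)
    print(torch.any(x).to(torch.bool).dtype)  # always torch.bool after the cast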
- if isinstance(axis, tuple): - res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs) - return res.to(torch.bool) - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.all(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) - return res.to(torch.bool) - - # torch.all doesn't return bool for uint8 - return torch.all(x, axis, keepdims=keepdims).to(torch.bool) - -def mean(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - keepdims: bool = False, - **kwargs) -> array: - # https://github.com/pytorch/pytorch/issues/29137 - if axis == (): - return torch.clone(x) - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.mean(x, **kwargs) - res = _axis_none_keepdims(res, x.ndim, keepdims) - return res - return torch.mean(x, axis, keepdims=keepdims, **kwargs) - -def std(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, - keepdims: bool = False, - **kwargs) -> array: - # Note, float correction is not supported - # https://github.com/pytorch/pytorch/issues/61492. We don't try to - # implement it here for now. - - if isinstance(correction, float): - _correction = int(correction) - if correction != _correction: - raise NotImplementedError("float correction in torch std() is not yet supported") - - # https://github.com/pytorch/pytorch/issues/29137 - if axis == (): - return torch.zeros_like(x) - if isinstance(axis, int): - axis = (axis,) - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs) - res = _axis_none_keepdims(res, x.ndim, keepdims) - return res - return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) - -def var(x: array, - /, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, - keepdims: bool = False, - **kwargs) -> array: - # Note, float correction is not supported - # https://github.com/pytorch/pytorch/issues/61492. We don't try to - # implement it here for now. - - # if isinstance(correction, float): - # correction = int(correction) - - # https://github.com/pytorch/pytorch/issues/29137 - if axis == (): - return torch.zeros_like(x) - if isinstance(axis, int): - axis = (axis,) - if axis is None: - # torch doesn't support keepdims with axis=None - # (https://github.com/pytorch/pytorch/issues/71209) - res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs) - res = _axis_none_keepdims(res, x.ndim, keepdims) - return res - return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs) - -# torch.concat doesn't support dim=None -# https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], - /, - *, - axis: Optional[int] = 0, - **kwargs) -> array: - if axis is None: - arrays = tuple(ar.flatten() for ar in arrays) - axis = 0 - return torch.concat(arrays, axis, **kwargs) - -# torch.squeeze only accepts int dim and doesn't require it -# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was -# added at https://github.com/pytorch/pytorch/pull/89017. 
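The ``std``/``var`` wrappers just above translate the Array API's ``correction`` parameter onto torch's integer-only degrees-of-freedom adjustment, accepting floats only when they convert to an integer exactly rather than silently rounding. A Bessel-corrected sample standard deviation, for example::

    import torch

    x = torch.asarray([1.0, 2.0, 4.0])
    print(torch.std(x, dim=0, correction=1))  # correction=1 is NumPy's ddof=1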
-def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: - if isinstance(axis, int): - axis = (axis,) - for a in axis: - if x.shape[a] != 1: - raise ValueError("squeezed dimensions must be equal to 1") - axes = _normalize_axes(axis, x.ndim) - # Remove this once pytorch 1.14 is released with the above PR #89017. - sequence = [a - i for i, a in enumerate(axes)] - for a in sequence: - x = torch.squeeze(x, a) - return x - -# The axis parameter doesn't work for flip() and roll() -# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't -# accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: - if axis is None: - axis = tuple(range(x.ndim)) - # torch.flip doesn't accept dim as an int but the method does - # https://github.com/pytorch/pytorch/issues/18095 - return x.flip(axis) - -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: - return torch.roll(x, shift, axis) - -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: - return torch.nonzero(x, as_tuple=True, **kwargs) - -def where(condition: array, x1: array, x2: array, /) -> array: - x1, x2 = _fix_promotion(x1, x2) - return torch.where(condition, x1, x2) - -# torch.arange doesn't support returning empty arrays -# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some -# keyword argument combinations -# (https://github.com/pytorch/pytorch/issues/70914) -def arange(start: Union[int, float], - /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: - if stop is None: - start, stop = 0, start - if step > 0 and stop <= start or step < 0 and stop >= start: - if dtype is None: - if builtin_all(isinstance(i, int) for i in [start, stop, step]): - dtype = torch.int64 - else: - dtype = torch.float32 - return torch.empty(0, dtype=dtype, device=device, **kwargs) - return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) - -# torch.eye does not accept None as a default for the second argument and -# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) -def eye(n_rows: int, - n_cols: Optional[int] = None, - /, - *, - k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: - if n_cols is None: - n_cols = n_rows - z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) - if abs(k) <= n_rows + n_cols: - z.diagonal(k).fill_(1) - return z - -# torch.linspace doesn't have the endpoint parameter -def linspace(start: Union[int, float], - stop: Union[int, float], - /, - num: int, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - endpoint: bool = True, - **kwargs) -> array: - if not endpoint: - return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] - return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) - -# torch.full does not accept an int size -# https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: - if isinstance(shape, int): - shape = (shape,) - - return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) - -# ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: 
Union[int, Tuple[int, ...]], - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: - return torch.ones(shape, dtype=dtype, device=device, **kwargs) - -def zeros(shape: Union[int, Tuple[int, ...]], - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: - return torch.zeros(shape, dtype=dtype, device=device, **kwargs) - -def empty(shape: Union[int, Tuple[int, ...]], - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: - return torch.empty(shape, dtype=dtype, device=device, **kwargs) - -# tril and triu do not call the keyword argument k - -def tril(x: array, /, *, k: int = 0) -> array: - return torch.tril(x, k) - -def triu(x: array, /, *, k: int = 0) -> array: - return torch.triu(x, k) - -# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: - return torch.unsqueeze(x, axis) - -def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: - return x.to(dtype, copy=copy) - -def broadcast_arrays(*arrays: array) -> List[array]: - shape = torch.broadcast_shapes(*[a.shape for a in arrays]) - return [torch.broadcast_to(a, shape) for a in arrays] - -# https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: - # torch.unique doesn't support returning indices. - # https://github.com/pytorch/pytorch/issues/36748. The workaround - # suggested in that issue doesn't actually function correctly (it relies - # on non-deterministic behavior of scatter()). - raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)") - - # values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True) - # # torch.unique incorrectly gives a 0 count for nan values. - # # https://github.com/pytorch/pytorch/issues/94106 - # counts[torch.isnan(values)] = 1 - # return UniqueAllResult(values, indices, inverse_indices, counts) - -def unique_counts(x: array) -> UniqueCountsResult: - values, counts = torch.unique(x, return_counts=True) - - # torch.unique incorrectly gives a 0 count for nan values. - # https://github.com/pytorch/pytorch/issues/94106 - counts[torch.isnan(values)] = 1 - return UniqueCountsResult(values, counts) - -def unique_inverse(x: array) -> UniqueInverseResult: - values, inverse = torch.unique(x, return_inverse=True) - return UniqueInverseResult(values, inverse) - -def unique_values(x: array) -> array: - return torch.unique(x) - -def matmul(x1: array, x2: array, /, **kwargs) -> array: - # torch.matmul doesn't type promote (but differently from _fix_promotion) - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch.matmul(x1, x2, **kwargs) - -matrix_transpose = get_xp(torch)(_aliases_matrix_transpose) -_vecdot = get_xp(torch)(_aliases_vecdot) - -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return _vecdot(x1, x2, axis=axis) - -# torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: - # Note: torch.tensordot fails with integer dtypes when there is only 1 - # element in the axis (https://github.com/pytorch/pytorch/issues/84530). 
- x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch.tensordot(x1, x2, dims=axes, **kwargs) - - -def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], - *, _tuple=True, # Disallow nested tuples -) -> bool: - """ - Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. - - Note that outside of this function, this compat library does not yet fully - support complex numbers. - - See - https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html - for more details - """ - if isinstance(kind, tuple) and _tuple: - return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) - elif isinstance(kind, str): - if kind == 'bool': - return dtype == torch.bool - elif kind == 'signed integer': - return dtype in _int_dtypes and dtype.is_signed - elif kind == 'unsigned integer': - return dtype in _int_dtypes and not dtype.is_signed - elif kind == 'integral': - return dtype in _int_dtypes - elif kind == 'real floating': - return dtype.is_floating_point - elif kind == 'complex floating': - return dtype.is_complex - elif kind == 'numeric': - return isdtype(dtype, ('integral', 'real floating', 'complex floating')) - else: - raise ValueError(f"Unrecognized data type kind: {kind!r}") - else: - return dtype == kind - -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', - 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', - 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', - 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', - 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', - 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', - 'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll', - 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones', - 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', - 'broadcast_arrays', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype'] diff --git a/sklearn/externals/_array_api_compat/torch/linalg.py b/sklearn/externals/_array_api_compat/torch/linalg.py deleted file mode 100644 index c803228abc604..0000000000000 --- a/sklearn/externals/_array_api_compat/torch/linalg.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - -from torch.linalg import * - -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] - -# These are implemented in torch but aren't in the linalg namespace -from torch import outer, trace -from ._aliases import _fix_promotion, matrix_transpose, tensordot - -# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the -# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch_linalg.cross(x1, x2, dim=axis) - -__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot'] - -del linalg_all diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index afebb66a05685..76b17c944e7af 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -13,6 +13,7 @@ from 
sklearn.utils._testing import assert_almost_equal from sklearn.utils._array_api import _convert_to_numpy from sklearn.utils._testing import _convert_container +from sklearn.utils._testing import skip_if_no_array_api_compat from sklearn.datasets import make_blobs from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -676,6 +677,7 @@ def test_get_feature_names_out(): assert_array_equal(names_out, expected_names_out) +@skip_if_no_array_api_compat @pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"]) def test_lda_array_api(array_namespace): """Check that the array_api Array gives the same results as ndarrays.""" @@ -733,6 +735,7 @@ def test_lda_array_api(array_namespace): ) +@skip_if_no_array_api_compat @pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.parametrize("dtype", ["float32", "float64"]) def test_lda_array_torch(device, dtype): diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 357830c96ba39..8b7d0897a1a31 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1,23 +1,92 @@ """Tools to support array_api.""" +from functools import wraps +import math + import numpy import scipy.special as special -import sklearn.externals._array_api_compat as array_api_compat +from .._config import get_config + -import sklearn.externals._array_api_compat.numpy as array_api_compat_numpy -from sklearn.externals._array_api_compat import device, size # noqa +def device(x): + """ + Hardware device the array data resides on. -from .._config import get_config, config_context + Parameters + ---------- + x: array + array instance from NumPy or an array API compatible library. + + Returns + ------- + out: device + a ``device`` object (see the "Device Support" section of the array API spec). + """ + if isinstance(x, (numpy.ndarray, numpy.generic)): + return "cpu" + return x.device + + +def size(x): + """ + Return the total number of elements of x + """ + if None in x.shape: + return None + return math.prod(x.shape) def _is_numpy_namespace(xp): return xp.__name__ in { "numpy", - "sklearn.externals._array_api_compat.numpy", + "_NumPyApiWrapper", + "array_api_compat.numpy", "numpy.array_api", } +def isdtype(dtype, kind, *, xp): + """Returns a boolean indicating whether a provided dtype is of type "kind". + + Included in the v2022.12 of the Array API spec. 
+    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
+    """
+    if isinstance(kind, tuple):
+        return any(_isdtype_single(dtype, k, xp=xp) for k in kind)
+    else:
+        return _isdtype_single(dtype, kind, xp=xp)
+
+
+def _isdtype_single(dtype, kind, *, xp):
+    if isinstance(kind, str):
+        if kind == "bool":
+            return dtype == xp.bool
+        elif kind == "signed integer":
+            return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
+        elif kind == "unsigned integer":
+            return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
+        elif kind == "integral":
+            return isdtype(dtype, ("signed integer", "unsigned integer"), xp=xp)
+        elif kind == "real floating":
+            return dtype in {xp.float32, xp.float64}
+        elif kind == "complex floating":
+            # Some namespaces do not have complex dtypes, such as
+            # cupy.array_api and numpy.array_api. Collect the complex dtypes
+            # that do exist so that, e.g., complex128 still matches when
+            # complex64 is also present.
+            complex_dtypes = set()
+            if hasattr(xp, "complex64"):
+                complex_dtypes.add(xp.complex64)
+            if hasattr(xp, "complex128"):
+                complex_dtypes.add(xp.complex128)
+            return dtype in complex_dtypes
+        elif kind == "numeric":
+            return isdtype(
+                dtype, ("integral", "real floating", "complex floating"), xp=xp
+            )
+        else:
+            raise ValueError(f"Unrecognized data type kind: {kind!r}")
+    else:
+        return dtype == kind
+
+
 class _ArrayAPIWrapper:
     """sklearn specific Array API compatibility wrapper
 
@@ -68,37 +137,92 @@ def isdtype(self, dtype, kind):
         Included in the v2022.12 of the Array API spec.
         https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
         """
-        if isinstance(kind, tuple):
-            return any(self._isdtype_single(dtype, k) for k in kind)
-        else:
-            return self._isdtype_single(dtype, kind)
-
-    def _isdtype_single(self, dtype, kind):
-        xp = self._namespace
-        if isinstance(kind, str):
-            if kind == "bool":
-                return dtype == xp.bool
-            elif kind == "signed integer":
-                return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
-            elif kind == "unsigned integer":
-                return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
-            elif kind == "integral":
-                return self.isdtype(dtype, ("signed integer", "unsigned integer"))
-            elif kind == "real floating":
-                return dtype in {xp.float32, xp.float64}
-            elif kind == "complex floating":
-                # cupy.array_api and numpy.array_cpi does not have copmlex
-                if xp.__name__ in {"cupy.array_api", "numpy.array_api"}:
-                    return False
-                return dtype in {xp.complex64, xp.float128}
-            elif kind == "numeric":
-                return self.isdtype(
-                    dtype, ("integral", "real floating", "complex floating")
-                )
-            else:
-                raise ValueError(f"Unrecognized data type kind: {kind!r}")
+        return isdtype(dtype, kind, xp=self._namespace)
+
+
+def _check_device_cpu(device):  # noqa
+    if device not in {"cpu", None}:
+        raise ValueError(f"Unsupported device for NumPy: {device!r}")
+
+
+def _accept_device_cpu(func):
+    @wraps(func)
+    def wrapped_func(*args, **kwargs):
+        _check_device_cpu(kwargs.pop("device", None))
+        return func(*args, **kwargs)
+
+    return wrapped_func
+
+
+class _NumPyApiWrapper:
+    """Array API compat wrapper for any numpy version
+
+    NumPy < 1.22 does not expose the numpy.array_api namespace. This
+    wrapper makes it possible to write code that uses the standard
+    Array API while working with any version of NumPy supported by
+    scikit-learn.
+
+    See the `get_namespace()` public function for more details.
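+
+    A minimal sketch of how this wrapper is meant to be used (assuming the
+    default configuration, where array API dispatching is disabled)::
+
+        >>> wrapper = _NumPyApiWrapper()
+        >>> X = wrapper.asarray([1.0, 2.0, 3.0])
+        >>> wrapper.concat([X, X]).shape
+        (6,)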
+ """ + + def __getattr__(self, name): + attr = getattr(numpy, name) + # accept device + if name in { + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + }: + return _accept_device_cpu(attr) + + return attr + + @property + def bool(self): + return numpy.bool_ + + def astype(self, x, dtype, *, copy=True, casting="unsafe"): + # astype is not defined in the top level NumPy namespace + return x.astype(dtype, copy=copy, casting=casting) + + def asarray(self, x, *, dtype=None, device=None, copy=None): # noqa + _check_device_cpu(device) + # Support copy in NumPy namespace + if copy is True: + return numpy.array(x, copy=True, dtype=dtype) else: - return dtype == kind + return numpy.asarray(x, dtype=dtype) + + def unique_inverse(self, x): + return numpy.unique(x, return_inverse=True) + + def unique_counts(self, x): + return numpy.unique(x, return_counts=True) + + def unique_values(self, x): + return numpy.unique(x) + + def concat(self, arrays, *, axis=None): + return numpy.concatenate(arrays, axis=axis) + + def isdtype(self, dtype, kind): + """Returns a boolean indicating whether a provided dtype is of type "kind". + + Included in the v2022.12 of the Array API spec. + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + """ + return isdtype(dtype, kind, xp=self) + + +_NUMPY_API_WRAPPER_INSTANCE = _NumPyApiWrapper() def get_namespace(*arrays): @@ -144,13 +268,41 @@ def get_namespace(*arrays): True of the arrays are containers that implement the Array API spec. """ array_api_dispatch = get_config()["array_api_dispatch"] + return _get_namespace(*arrays, array_api_dispatch=array_api_dispatch) + + +def _get_namespace(*arrays, array_api_dispatch): + """Helper method for get_namespace that dispatches with array_api_dispatch. + + Parameters + ---------- + *arrays : array objects + Array objects. + + array_api_dispatch : bool + If True, the array namespace is obtained from the array objects. + + Returns + ------- + namespace : module + Namespace shared by array objects. If any of the `arrays` are not arrays, + the namespace defaults to NumPy. + + is_array_api : bool + True of the arrays are containers that implement the Array API spec. 
+ """ if not array_api_dispatch: - return array_api_compat_numpy, False + return _NUMPY_API_WRAPPER_INSTANCE, False + + try: + import array_api_compat + except ImportError: + return _NUMPY_API_WRAPPER_INSTANCE, False try: namespace, is_array_api = array_api_compat.get_namespace(*arrays), True except TypeError: - return array_api_compat_numpy, False + return _NUMPY_API_WRAPPER_INSTANCE, False if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: namespace = _ArrayAPIWrapper(namespace) @@ -166,7 +318,7 @@ def _expit(X): return 1.0 / (1.0 + xp.exp(-X)) -def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None): +def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None): """Helper to support the order kwarg only for NumPy-backed arrays Memory layout parameter `order` is not exposed in the Array API standard, @@ -194,16 +346,15 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None): def _convert_to_numpy(array): """Convert X into a NumPy ndarray on the CPU.""" - with config_context(array_api_dispatch=True): - xp, _ = get_namespace(array) + xp, _ = _get_namespace(array, array_api_dispatch=True) xp_name = xp.__name__ - if xp_name in {"sklearn.externals._array_api_compat.torch", "torch"}: + if xp_name in {"array_api_compat.torch", "torch"}: return array.cpu().numpy() elif xp_name == "cupy.array_api": return array._array.get() - elif xp_name in {"sklearn.externals._array_api_compat.cupy", "cupy"}: + elif xp_name in {"array_api_compat.cupy", "cupy"}: return array.get() return numpy.asarray(array) @@ -229,8 +380,7 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): - with config_context(array_api_dispatch=True): - _, is_array_api = get_namespace(attribute) + _, is_array_api = _get_namespace(attribute, array_api_dispatch=True) if is_array_api: attribute = converter(attribute) setattr(new_estimator, key, attribute) diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 482d3ea818563..86fba721850fc 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -388,6 +388,13 @@ def set_random_state(estimator, random_state=0): estimator.set_params(random_state=random_state) +try: + import array_api_compat # noqa + + ARRAY_API_COMPAT_INSTALLED = True +except ImportError: + ARRAY_API_COMPAT_INSTALLED = False + try: import pytest @@ -400,6 +407,10 @@ def set_random_state(estimator, random_state=0): skip_if_no_parallel = pytest.mark.skipif( not joblib.parallel.mp, reason="joblib is in serial mode" ) + skip_if_no_array_api_compat = pytest.mark.skipif( + not ARRAY_API_COMPAT_INSTALLED, + reason="requires array_api_compat installed", + ) # Decorator for tests involving both BLAS calls and multiprocessing. 
# diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 65d56fa77713d..10f8f4a9f4320 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -4,13 +4,14 @@ from sklearn.base import BaseEstimator from sklearn.utils._array_api import get_namespace +from sklearn.utils._array_api import _NumPyApiWrapper from sklearn.utils._array_api import _ArrayAPIWrapper from sklearn.utils._array_api import _asarray_with_order from sklearn.utils._array_api import _convert_to_numpy from sklearn.utils._array_api import _estimator_with_converted_arrays +from sklearn.utils._testing import skip_if_no_array_api_compat -import sklearn.externals._array_api_compat.numpy as array_api_compat_numpy from sklearn._config import config_context pytestmark = pytest.mark.filterwarnings( @@ -18,19 +19,39 @@ ) -def test_get_namespace_ndarray(): +@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]]) +def test_get_namespace_ndarray_default(X): + """Check that get_namespace returns NumPy wrapper""" + xp_out, is_array_api = get_namespace(X) + assert isinstance(xp_out, _NumPyApiWrapper) + assert not is_array_api + + +def test_get_namespace_ndarray_creation_device(): + """Check expected behavior with device and creation functions.""" + X = numpy.asarray([1, 2, 3]) + xp_out, _ = get_namespace(X) + + full_array = xp_out.full(10, fill_value=2.0, device="cpu") + assert_allclose(full_array, [2.0] * 10) + + with pytest.raises(ValueError, match="Unsupported device"): + xp_out.zeros(10, device="cuda") + + +def test_get_namespace_ndarray_with_dispatch(): """Test get_namespace on NumPy ndarrays.""" - pytest.importorskip("numpy.array_api") + array_api_compat = pytest.importorskip("array_api_compat") X_np = numpy.asarray([[1, 2, 3]]) - for array_api_dispatch in [True, False]: - with config_context(array_api_dispatch=array_api_dispatch): - xp_out, is_array_api = get_namespace(X_np) - assert is_array_api == array_api_dispatch - assert xp_out is array_api_compat_numpy + with config_context(array_api_dispatch=True): + xp_out, is_array_api = get_namespace(X_np) + assert is_array_api + assert xp_out is array_api_compat.numpy +@skip_if_no_array_api_compat def test_get_namespace_array_api(): """Test get_namespace for ArrayAPI arrays.""" xp = pytest.importorskip("numpy.array_api") @@ -42,12 +63,8 @@ def test_get_namespace_array_api(): assert is_array_api assert isinstance(xp_out, _ArrayAPIWrapper) - # check errors - with pytest.raises(TypeError, match="Multiple namespaces"): - get_namespace(X_np, X_xp) - xp_out, is_array_api = get_namespace(1) - assert xp_out == array_api_compat_numpy + assert isinstance(xp_out, _NumPyApiWrapper) assert not is_array_api @@ -115,7 +132,9 @@ def test_array_api_wrapper_take(): xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0) -@pytest.mark.parametrize("is_array_api", [True, False]) +@pytest.mark.parametrize( + "is_array_api", [pytest.param(True, marks=skip_if_no_array_api_compat), False] +) def test_asarray_with_order(is_array_api): """Test _asarray_with_order passes along order for NumPy arrays.""" if is_array_api: @@ -124,7 +143,7 @@ def test_asarray_with_order(is_array_api): xp = numpy X = xp.asarray([1.2, 3.4, 5.1]) - X_new = _asarray_with_order(X, order="F") + X_new = _asarray_with_order(X, order="F", xp=xp) X_new_np = numpy.asarray(X_new) assert X_new_np.flags["F_CONTIGUOUS"] @@ -145,6 +164,7 @@ def test_asarray_with_order_ignored(): assert not X_new_np.flags["F_CONTIGUOUS"] +@skip_if_no_array_api_compat 
@pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"]) def test_convert_to_numpy_gpu(library): """Check convert_to_numpy for GPU backed libraries.""" @@ -209,7 +229,9 @@ def test_convert_estimator_to_array_api(): assert hasattr(new_est.X_, "__array_namespace__") -@pytest.mark.parametrize("array_api_dispatch", [True, False]) +@pytest.mark.parametrize( + "array_api_dispatch", [pytest.param(True, marks=skip_if_no_array_api_compat), False] +) def test_get_namespace_array_api_isdtype(array_api_dispatch): """Test isdtype implementation from _ArrayAPIWrapper and array_api_compat.""" xp = pytest.importorskip("numpy.array_api") @@ -235,7 +257,9 @@ def test_get_namespace_array_api_isdtype(array_api_dispatch): assert xp_out.isdtype(xp_out.uint32, "numeric") -@pytest.mark.parametrize("array_api_dispatch", [True, False]) +@pytest.mark.parametrize( + "array_api_dispatch", [pytest.param(True, marks=skip_if_no_array_api_compat), False] +) def test_get_namespace_list(array_api_dispatch): """Test get_namespace for lists.""" @@ -243,4 +267,4 @@ def test_get_namespace_list(array_api_dispatch): with config_context(array_api_dispatch=array_api_dispatch): xp_out, is_array = get_namespace(X) assert not is_array - assert xp_out is array_api_compat_numpy + assert isinstance(xp_out, _NumPyApiWrapper) From 2ae9363c69a90435cf5576fa0bbc8e94d4b1fd4c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 13:44:06 -0400 Subject: [PATCH 08/42] FIX Fixes issue with string or python dtype args --- build_tools/azure/install.sh | 4 +++ sklearn/_config.py | 27 ++++++++++++++++++ sklearn/utils/_array_api.py | 53 ++++++++++++++++++++++++++---------- sklearn/utils/_testing.py | 12 ++++---- sklearn/utils/validation.py | 2 ++ 5 files changed, 78 insertions(+), 20 deletions(-) diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index 5238cd1121d2e..19e6dae4917b5 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -54,6 +54,10 @@ python_environment_install_and_activate() { conda-lock install --name $VIRTUALENV $LOCK_FILE source activate $VIRTUALENV + # TODO: Remove when array_api_compat ships a new release with latest changes + # install development feature of array_api_compat for testing purposes + pip install git+https://github.com/data-apis/array-api-compat + elif [[ "$DISTRIB" == "ubuntu" || "$DISTRIB" == "debian-32" ]]; then python3 -m virtualenv --system-site-packages --python=python3 $VIRTUALENV source $VIRTUALENV/bin/activate diff --git a/sklearn/_config.py b/sklearn/_config.py index e4c398c9c5444..213d34c36a99f 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -27,6 +27,32 @@ def _get_threadlocal_config(): return _threadlocal.global_config +def _check_array_api_dispatch(array_api_dispatch): + """Check that array_api_compat is installed and NumPy version is compatible. + + array_api_compat follows NEP29, which has a higher minimum NumPy version than + scikit-learn. + """ + if array_api_dispatch: + try: + import array_api_compat # noqa + except ImportError: + raise ImportError( + "array_api_compat is required when array_api_dispatch=True" + ) + + import numpy + from .utils.fixes import parse_version + + numpy_version = parse_version(numpy.__version__) + min_numpy_version = "1.21" + if numpy_version < parse_version(min_numpy_version): + raise ImportError( + f"NumPy must be newer than {min_numpy_version} when" + " array_api_dispatch=True" + ) + + def get_config(): """Retrieve current values for configuration set by :func:`set_config`. 
@@ -154,6 +180,7 @@ def set_config( if enable_cython_pairwise_dist is not None: local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist if array_api_dispatch is not None: + # _check_array_api_dispatch(array_api_dispatch) local_config["array_api_dispatch"] = array_api_dispatch if transform_output is not None: local_config["transform_output"] = transform_output diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 8b7d0897a1a31..08dff9621e241 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -165,24 +165,49 @@ class _NumPyApiWrapper: See the `get_namespace()` public function for more details. """ + # Creation functions in spec: + # https://data-apis.org/array-api/latest/API_specification/creation_functions.html + _CREATION_FUNCS = { + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + } + # Data types in spec + # https://data-apis.org/array-api/latest/API_specification/data_types.html + _DTYPES = { + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + } + def __getattr__(self, name): attr = getattr(numpy, name) - # accept device - if name in { - "arange", - "empty", - "empty_like", - "eye", - "full", - "full_like", - "linspace", - "ones", - "ones_like", - "zeros", - "zeros_like", - }: + + # Support device kwargs and make sure they are on the CPU + if name in self._CREATION_FUNCS: return _accept_device_cpu(attr) + # Convert to dtype objects + if name in self._DTYPES: + return numpy.dtype(attr) + return attr @property diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 86fba721850fc..128806269512f 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -54,6 +54,7 @@ _IS_32BIT, _in_unstable_openblas_configuration, ) +from sklearn._config import _check_array_api_dispatch from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import ( check_array, @@ -389,11 +390,10 @@ def set_random_state(estimator, random_state=0): try: - import array_api_compat # noqa - - ARRAY_API_COMPAT_INSTALLED = True + _check_array_api_dispatch(True) + ARRAY_API_COMPAT_FUNCTIONAL = True except ImportError: - ARRAY_API_COMPAT_INSTALLED = False + ARRAY_API_COMPAT_FUNCTIONAL = False try: import pytest @@ -408,8 +408,8 @@ def set_random_state(estimator, random_state=0): not joblib.parallel.mp, reason="joblib is in serial mode" ) skip_if_no_array_api_compat = pytest.mark.skipif( - not ARRAY_API_COMPAT_INSTALLED, - reason="requires array_api_compat installed", + not ARRAY_API_COMPAT_FUNCTIONAL, + reason="requires array_api_compat installed and a new enough version of NumPy", ) # Decorator for tests involving both BLAS calls and multiprocessing. diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 38c5a9acde808..f5d3ffc522a30 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -758,6 +758,8 @@ def check_array( # store whether originally we wanted numeric dtype dtype_numeric = isinstance(dtype, str) and dtype == "numeric" + if dtype is int or dtype == "int": + dtype = xp.int64 dtype_orig = getattr(array, "dtype", None) if not is_array_api and not hasattr(dtype_orig, "kind"): From 42b04f808bd16e3fa70bdeaf5da53b566da4fe00 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 29 Mar 2023 13:44:39 -0400 Subject: [PATCH 09/42] CLN Rename is_array_api to is_array_api_compliant --- sklearn/discriminant_analysis.py | 10 +++++----- sklearn/utils/_array_api.py | 15 +++++++++------ sklearn/utils/extmath.py | 2 +- sklearn/utils/multiclass.py | 16 ++++++++-------- sklearn/utils/tests/test_array_api.py | 23 ++++++++++++----------- sklearn/utils/validation.py | 4 ++-- 6 files changed, 37 insertions(+), 33 deletions(-) diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index 52a4b495f8c23..066415e3321c5 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -107,11 +107,11 @@ def _class_means(X, y): means : array-like of shape (n_classes, n_features) Class means. """ - xp, is_array_api = get_namespace(X) + xp, is_array_api_compliant = get_namespace(X) classes, y = xp.unique_inverse(y) means = xp.zeros((classes.shape[0], X.shape[1]), device=device(X), dtype=X.dtype) - if is_array_api: + if is_array_api_compliant: for i in range(classes.shape[0]): means[i, :] = xp.mean(X[y == i], axis=0) else: @@ -483,9 +483,9 @@ def _solve_svd(self, X, y): y : array-like of shape (n_samples,) or (n_samples, n_targets) Target values. """ - xp, is_array_api = get_namespace(X) + xp, is_array_api_compliant = get_namespace(X) - if is_array_api: + if is_array_api_compliant: svd = xp.linalg.svd else: svd = scipy.linalg.svd @@ -688,7 +688,7 @@ def predict_proba(self, X): Estimated probabilities. """ check_is_fitted(self) - xp, is_array_api = get_namespace(X) + xp, is_array_api_compliant = get_namespace(X) decision = self.decision_function(X) if size(self.classes_) == 2: proba = _expit(decision) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 08dff9621e241..5ca6e1c729e69 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -289,7 +289,7 @@ def get_namespace(*arrays): Namespace shared by array objects. If any of the `arrays` are not arrays, the namespace defaults to NumPy. - is_array_api : bool + is_array_api_compliant : bool True of the arrays are containers that implement the Array API spec. """ array_api_dispatch = get_config()["array_api_dispatch"] @@ -313,7 +313,7 @@ def _get_namespace(*arrays, array_api_dispatch): Namespace shared by array objects. If any of the `arrays` are not arrays, the namespace defaults to NumPy. - is_array_api : bool + is_array_api_compliant : bool True of the arrays are containers that implement the Array API spec. 
""" if not array_api_dispatch: @@ -325,14 +325,17 @@ def _get_namespace(*arrays, array_api_dispatch): return _NUMPY_API_WRAPPER_INSTANCE, False try: - namespace, is_array_api = array_api_compat.get_namespace(*arrays), True + namespace, is_array_api_compliant = ( + array_api_compat.get_namespace(*arrays), + True, + ) except TypeError: return _NUMPY_API_WRAPPER_INSTANCE, False if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: namespace = _ArrayAPIWrapper(namespace) - return namespace, is_array_api + return namespace, is_array_api_compliant def _expit(X): @@ -405,8 +408,8 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): - _, is_array_api = _get_namespace(attribute, array_api_dispatch=True) - if is_array_api: + _, is_array_api_compliant = _get_namespace(attribute, array_api_dispatch=True) + if is_array_api_compliant: attribute = converter(attribute) setattr(new_estimator, key, attribute) return new_estimator diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index cf933cef783c4..e18e4f800b302 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -880,7 +880,7 @@ def softmax(X, copy=True): out : ndarray of shape (M, N) Softmax function evaluated at every point in x. """ - xp, is_array_api = get_namespace(X) + xp, is_array_api_compliant = get_namespace(X) if copy: X = xp.asarray(X, copy=True) max_prob = xp.reshape(xp.max(X, axis=1), (-1, 1)) diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 72bf499a6e31d..02725f8986948 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -21,8 +21,8 @@ def _unique_multiclass(y): - xp, is_array_api = get_namespace(y) - if hasattr(y, "__array__") or is_array_api: + xp, is_array_api_compliant = get_namespace(y) + if hasattr(y, "__array__") or is_array_api_compliant: return xp.unique_values(xp.asarray(y)) else: return set(y) @@ -73,7 +73,7 @@ def unique_labels(*ys): >>> unique_labels([1, 2, 10], [5, 11]) array([ 1, 2, 5, 10, 11]) """ - xp, is_array_api = get_namespace(*ys) + xp, is_array_api_compliant = get_namespace(*ys) if not ys: raise ValueError("No argument has been passed.") # Check that we don't mix label format @@ -106,7 +106,7 @@ def unique_labels(*ys): if not _unique_labels: raise ValueError("Unknown label type: %s" % repr(ys)) - if is_array_api: + if is_array_api_compliant: # array_api does not allow for mixed dtypes unique_ys = xp.concat([_unique_labels(y) for y in ys]) return xp.unique_values(unique_ys) @@ -151,8 +151,8 @@ def is_multilabel(y): >>> is_multilabel(np.array([[1, 0, 0]])) True """ - xp, is_array_api = get_namespace(y) - if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api: + xp, is_array_api_compliant = get_namespace(y) + if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api_compliant: # DeprecationWarning will be replaced by ValueError, see NEP 34 # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html check_y_kwargs = dict( @@ -290,11 +290,11 @@ def type_of_target(y, input_name=""): >>> type_of_target(np.array([[0, 1], [1, 1]])) 'multilabel-indicator' """ - xp, is_array_api = get_namespace(y) + xp, is_array_api_compliant = get_namespace(y) valid = ( (isinstance(y, Sequence) or issparse(y) or hasattr(y, "__array__")) and not isinstance(y, str) - or is_array_api + or is_array_api_compliant ) if not valid: diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 
10f8f4a9f4320..b13bbeb49c398 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -22,9 +22,9 @@ @pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]]) def test_get_namespace_ndarray_default(X): """Check that get_namespace returns NumPy wrapper""" - xp_out, is_array_api = get_namespace(X) + xp_out, is_array_api_compliant = get_namespace(X) assert isinstance(xp_out, _NumPyApiWrapper) - assert not is_array_api + assert not is_array_api_compliant def test_get_namespace_ndarray_creation_device(): @@ -46,8 +46,8 @@ def test_get_namespace_ndarray_with_dispatch(): X_np = numpy.asarray([[1, 2, 3]]) with config_context(array_api_dispatch=True): - xp_out, is_array_api = get_namespace(X_np) - assert is_array_api + xp_out, is_array_api_compliant = get_namespace(X_np) + assert is_array_api_compliant assert xp_out is array_api_compat.numpy @@ -59,13 +59,13 @@ def test_get_namespace_array_api(): X_np = numpy.asarray([[1, 2, 3]]) X_xp = xp.asarray(X_np) with config_context(array_api_dispatch=True): - xp_out, is_array_api = get_namespace(X_xp) - assert is_array_api + xp_out, is_array_api_compliant = get_namespace(X_xp) + assert is_array_api_compliant assert isinstance(xp_out, _ArrayAPIWrapper) - xp_out, is_array_api = get_namespace(1) + xp_out, is_array_api_compliant = get_namespace(1) assert isinstance(xp_out, _NumPyApiWrapper) - assert not is_array_api + assert not is_array_api_compliant class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper): @@ -133,11 +133,12 @@ def test_array_api_wrapper_take(): @pytest.mark.parametrize( - "is_array_api", [pytest.param(True, marks=skip_if_no_array_api_compat), False] + "is_array_api_compliant", + [pytest.param(True, marks=skip_if_no_array_api_compat), False], ) -def test_asarray_with_order(is_array_api): +def test_asarray_with_order(is_array_api_compliant): """Test _asarray_with_order passes along order for NumPy arrays.""" - if is_array_api: + if is_array_api_compliant: xp = pytest.importorskip("numpy.array_api") else: xp = numpy diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index f5d3ffc522a30..ed010e9c44c37 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -750,7 +750,7 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html" ) - xp, is_array_api = get_namespace(array) + xp, is_array_api_compliant = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -762,7 +762,7 @@ def check_array( dtype = xp.int64 dtype_orig = getattr(array, "dtype", None) - if not is_array_api and not hasattr(dtype_orig, "kind"): + if not is_array_api_compliant and not hasattr(dtype_orig, "kind"): # not a data type (e.g. a column named dtype in a pandas DataFrame) dtype_orig = None From 54120774da1708edce2c739e9c844b3e9d208428 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 29 Mar 2023 13:58:30 -0400 Subject: [PATCH 10/42] CLN Improve support for dtypes that are not dtype objects --- sklearn/utils/validation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index ed010e9c44c37..867b5ffbf16fe 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -758,8 +758,6 @@ def check_array( # store whether originally we wanted numeric dtype dtype_numeric = isinstance(dtype, str) and dtype == "numeric" - if dtype is int or dtype == "int": - dtype = xp.int64 dtype_orig = getattr(array, "dtype", None) if not is_array_api_compliant and not hasattr(dtype_orig, "kind"): @@ -832,6 +830,9 @@ def check_array( # Since we converted here, we do not need to convert again later dtype = None + if dtype is not None and _is_numpy_namespace(xp): + dtype = np.dtype(dtype) + if force_all_finite not in (True, False, "allow-nan"): raise ValueError( 'force_all_finite should be a bool or "allow-nan". Got {!r} instead'.format( From e5c2c41ae0b4c2bd977705f0fecfa0c412737262 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 14:03:02 -0400 Subject: [PATCH 11/42] CLN Reduce diff --- pyproject.toml | 11 ----------- setup.cfg | 4 +++- sklearn/tests/test_common.py | 2 +- sklearn/tests/test_docstring_parameters.py | 3 --- sklearn/utils/tests/test_array_api.py | 1 - 5 files changed, 4 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e3683752ba8b..fbb1d53ef6602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,14 +38,3 @@ exclude = ''' | asv_benchmarks/env )/ ''' - -[tool.mypy] -ignore_missing_imports = true -allow_redefinition = true -exclude = [ - "sklearn/externals" -] - -[[tool.mypy.overrides]] -module = "sklearn.externals.*" -follow_imports = "skip" diff --git a/setup.cfg b/setup.cfg index 637e0ba0675f2..3ed576cedf92f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,7 +73,9 @@ per-file-ignores = examples/*: E402 doc/conf.py: E402 - +[mypy] +ignore_missing_imports = True +allow_redefinition = True [check-manifest] # ignore files missing in VCS diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 96cf0835979a1..6ef0eaa433d20 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -202,7 +202,7 @@ def test_import_all_consistency(): ) submods = [modname for _, modname, _ in pkgs] for modname in submods + ["sklearn"]: - if ".tests." in modname or ".externals._array_api_compat." in modname: + if ".tests." in modname: continue if IS_PYPY and ( "_svmlight_format_io" in modname diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index aa043eeff41c7..8bf3e5dd7b24a 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -162,9 +162,6 @@ def test_tabs(): ): continue - if ".externals._array_api_compat." 
in modname: - continue - # because we don't import mod = importlib.import_module(modname) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index b13bbeb49c398..235d57b1d6287 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -6,7 +6,6 @@ from sklearn.utils._array_api import get_namespace from sklearn.utils._array_api import _NumPyApiWrapper from sklearn.utils._array_api import _ArrayAPIWrapper - from sklearn.utils._array_api import _asarray_with_order from sklearn.utils._array_api import _convert_to_numpy from sklearn.utils._array_api import _estimator_with_converted_arrays From 0efbb85479026d894cef222574ddaf73e3bf6506 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 16:33:27 -0400 Subject: [PATCH 12/42] ENH Check array_api_compat installation --- sklearn/_config.py | 6 ++-- sklearn/tests/test_config.py | 44 +++++++++++++++++++++++++++ sklearn/utils/_array_api.py | 9 ++---- sklearn/utils/tests/test_array_api.py | 2 ++ 4 files changed, 52 insertions(+), 9 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 213d34c36a99f..1bf20dee2ad0e 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -3,6 +3,7 @@ import os from contextlib import contextmanager as contextmanager import threading +import numpy _global_config = { "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)), @@ -41,14 +42,13 @@ def _check_array_api_dispatch(array_api_dispatch): "array_api_compat is required when array_api_dispatch=True" ) - import numpy from .utils.fixes import parse_version numpy_version = parse_version(numpy.__version__) min_numpy_version = "1.21" if numpy_version < parse_version(min_numpy_version): raise ImportError( - f"NumPy must be newer than {min_numpy_version} when" + f"NumPy must be {min_numpy_version} or newer when" " array_api_dispatch=True" ) @@ -180,7 +180,7 @@ def set_config( if enable_cython_pairwise_dist is not None: local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist if array_api_dispatch is not None: - # _check_array_api_dispatch(array_api_dispatch) + _check_array_api_dispatch(array_api_dispatch) local_config["array_api_dispatch"] = array_api_dispatch if transform_output is not None: local_config["transform_output"] = transform_output diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index bcc4c233e7ea3..7b5ebddef90bb 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -1,9 +1,11 @@ +import builtins import time from concurrent.futures import ThreadPoolExecutor import pytest from sklearn import get_config, set_config, config_context +import sklearn from sklearn.utils.parallel import delayed, Parallel @@ -145,3 +147,45 @@ def test_config_threadsafe(): ] assert items == [False, True, False, True] + + +def test_config_array_api_dispatch_error(monkeypatch): + """Check error is raised when array_api_compat is not installed.""" + + # Hide array_api_compat import + orig_import = builtins.__import__ + + def mocked_import(name, *args, **kwargs): + if name == "array_api_compat": + raise ImportError + return orig_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + with pytest.raises(ImportError, match="array_api_compat is required"): + with config_context(array_api_dispatch=True): + pass + + with pytest.raises(ImportError, match="array_api_compat is required"): + set_config(array_api_dispatch=True) + + +def 
test_config_array_api_dispatch_error_numpy(monkeypatch): + """Check error when NumPy is too old""" + # Pretend that array_api_compat is installed. + orig_import = builtins.__import__ + + def mocked_import(name, *args, **kwargs): + if name == "array_api_compat": + return object() + return orig_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(sklearn._config.numpy, "__version__", "1.20") + + with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"): + with config_context(array_api_dispatch=True): + pass + + with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"): + set_config(array_api_dispatch=True) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 5ca6e1c729e69..944f229b71c83 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -297,7 +297,7 @@ def get_namespace(*arrays): def _get_namespace(*arrays, array_api_dispatch): - """Helper method for get_namespace that dispatches with array_api_dispatch. + """Helper method for get_namespace that dispatches based on array_api_dispatch. Parameters ---------- @@ -321,15 +321,12 @@ def _get_namespace(*arrays, array_api_dispatch): try: import array_api_compat - except ImportError: - return _NUMPY_API_WRAPPER_INSTANCE, False - try: namespace, is_array_api_compliant = ( array_api_compat.get_namespace(*arrays), True, ) - except TypeError: + except (TypeError, ImportError): return _NUMPY_API_WRAPPER_INSTANCE, False if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: @@ -409,7 +406,7 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): _, is_array_api_compliant = _get_namespace(attribute, array_api_dispatch=True) - if is_array_api_compliant: + if is_array_api_compliant or isinstance(attribute, numpy.ndarray): attribute = converter(attribute) setattr(new_estimator, key, attribute) return new_estimator diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 235d57b1d6287..69bc1fc064aa3 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -199,6 +199,7 @@ def fit(self, X, y=None): return self +@skip_if_no_array_api_compat @pytest.mark.parametrize( "array_namespace, converter", [ @@ -218,6 +219,7 @@ def test_convert_estimator_to_ndarray(array_namespace, converter): assert isinstance(new_est.X_, numpy.ndarray) +@skip_if_no_array_api_compat def test_convert_estimator_to_array_api(): """Convert estimator attributes to ArrayAPI arrays.""" xp = pytest.importorskip("numpy.array_api") From 075ce99ded17f6cdebc94e8970937ac5d2036650 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 16:35:52 -0400 Subject: [PATCH 13/42] DOC Add array_api_compat requirement in user guide --- doc/modules/array_api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 7b4e3453badb1..f9a6e1058c473 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -12,6 +12,8 @@ Array API support (experimental) The `Array API `_ specification defines a standard API for all array manipulation libraries with a NumPy-like API. +Scikit-learn's Array API support requires +`array-api-compat `__ to be installed. 
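+
+Enabling `array_api_dispatch` without this dependency installed is expected to
+fail with an informative error. A sketch of the behavior::
+
+    >>> from sklearn import set_config
+    >>> set_config(array_api_dispatch=True)  # doctest: +SKIP
+    Traceback (most recent call last):
+    ...
+    ImportError: array_api_compat is required when array_api_dispatch=True
+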
Some scikit-learn estimators that primarily rely on NumPy (as opposed to using Cython) to implement the algorithmic logic of their `fit`, `predict` or From 5f6ad0e46b3c7c89c7a54ca1687c10da84c05e18 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 16:59:20 -0400 Subject: [PATCH 14/42] CLN Refactor version check --- sklearn/_config.py | 28 ++--------------- sklearn/tests/test_config.py | 2 +- sklearn/utils/_array_api.py | 43 +++++++++++++++++++++++++++ sklearn/utils/_testing.py | 2 +- sklearn/utils/tests/test_array_api.py | 9 ++++++ 5 files changed, 56 insertions(+), 28 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 1bf20dee2ad0e..025e601329d7e 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -3,7 +3,6 @@ import os from contextlib import contextmanager as contextmanager import threading -import numpy _global_config = { "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)), @@ -28,31 +27,6 @@ def _get_threadlocal_config(): return _threadlocal.global_config -def _check_array_api_dispatch(array_api_dispatch): - """Check that array_api_compat is installed and NumPy version is compatible. - - array_api_compat follows NEP29, which has a higher minimum NumPy version than - scikit-learn. - """ - if array_api_dispatch: - try: - import array_api_compat # noqa - except ImportError: - raise ImportError( - "array_api_compat is required when array_api_dispatch=True" - ) - - from .utils.fixes import parse_version - - numpy_version = parse_version(numpy.__version__) - min_numpy_version = "1.21" - if numpy_version < parse_version(min_numpy_version): - raise ImportError( - f"NumPy must be {min_numpy_version} or newer when" - " array_api_dispatch=True" - ) - - def get_config(): """Retrieve current values for configuration set by :func:`set_config`. @@ -180,6 +154,8 @@ def set_config( if enable_cython_pairwise_dist is not None: local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist if array_api_dispatch is not None: + from .utils._array_api import _check_array_api_dispatch + _check_array_api_dispatch(array_api_dispatch) local_config["array_api_dispatch"] = array_api_dispatch if transform_output is not None: diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 7b5ebddef90bb..73356e92119a1 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -181,7 +181,7 @@ def mocked_import(name, *args, **kwargs): return orig_import(name, *args, **kwargs) monkeypatch.setattr(builtins, "__import__", mocked_import) - monkeypatch.setattr(sklearn._config.numpy, "__version__", "1.20") + monkeypatch.setattr(sklearn.utils._array_api.numpy, "__version__", "1.20") with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"): with config_context(array_api_dispatch=True): diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 944f229b71c83..04d0069e05fe8 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -6,6 +6,31 @@ import scipy.special as special from .._config import get_config +from .fixes import parse_version + + +def _check_array_api_dispatch(array_api_dispatch): + """Check that array_api_compat is installed and NumPy version is compatible. + + array_api_compat follows NEP29, which has a higher minimum NumPy version than + scikit-learn. 
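+
+    A sketch of the intended failure mode, assuming ``array_api_compat`` is
+    not installed in the environment::
+
+        >>> _check_array_api_dispatch(True)  # doctest: +SKIP
+        Traceback (most recent call last):
+        ...
+        ImportError: array_api_compat is required to dispatch arrays using the API specification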
+ """ + if array_api_dispatch: + try: + import array_api_compat # noqa + except ImportError: + raise ImportError( + "array_api_compat is required to dispatch arrays using the API" + " specification" + ) + + numpy_version = parse_version(numpy.__version__) + min_numpy_version = "1.21" + if numpy_version < parse_version(min_numpy_version): + raise ImportError( + f"NumPy must be {min_numpy_version} or newer to dispatch array using" + " the API specification" + ) def device(x): @@ -246,6 +271,22 @@ def isdtype(self, dtype, kind): """ return isdtype(dtype, kind, xp=self) + def reshape(self, x, shape, *, copy=None): + """Gives a new shape to an array without changing its data. + + The Array API specification requires shape to be a tuple. + https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html + """ + if not isinstance(shape, tuple): + raise TypeError("shape must be a tuple") + + if copy is True: + x = x.copy() + elif copy is False: + x.shape = shape + return x + return numpy.reshape(x, shape) + _NUMPY_API_WRAPPER_INSTANCE = _NumPyApiWrapper() @@ -319,6 +360,8 @@ def _get_namespace(*arrays, array_api_dispatch): if not array_api_dispatch: return _NUMPY_API_WRAPPER_INSTANCE, False + _check_array_api_dispatch(array_api_dispatch) + try: import array_api_compat diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 128806269512f..9c3fd05d34efe 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -54,7 +54,7 @@ _IS_32BIT, _in_unstable_openblas_configuration, ) -from sklearn._config import _check_array_api_dispatch +from sklearn.utils._array_api import _check_array_api_dispatch from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import ( check_array, diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 69bc1fc064aa3..62aae66f61fc1 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -270,3 +270,12 @@ def test_get_namespace_list(array_api_dispatch): xp_out, is_array = get_namespace(X) assert not is_array assert isinstance(xp_out, _NumPyApiWrapper) + + +def test_error_when_reshape_is_not_a_tuple(): + """Raise error in reshape when shape is not a tuple.""" + X = numpy.asarray([[1, 2, 3], [3, 4, 5]]) + xp, _ = get_namespace(X) + + with pytest.raises(TypeError, match="shape must be a tuple"): + xp.reshape(X, -1) From 39d45896e67a8d74c5cc084443969a0369da9709 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 Mar 2023 17:01:48 -0400 Subject: [PATCH 15/42] CLN Reduce complexity by removing private helper --- sklearn/utils/_array_api.py | 31 +++++-------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 04d0069e05fe8..4ba3b2e5ed39b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -5,7 +5,7 @@ import numpy import scipy.special as special -from .._config import get_config +from .._config import get_config, config_context from .fixes import parse_version @@ -334,29 +334,6 @@ def get_namespace(*arrays): True of the arrays are containers that implement the Array API spec. """ array_api_dispatch = get_config()["array_api_dispatch"] - return _get_namespace(*arrays, array_api_dispatch=array_api_dispatch) - - -def _get_namespace(*arrays, array_api_dispatch): - """Helper method for get_namespace that dispatches based on array_api_dispatch. 
- - Parameters - ---------- - *arrays : array objects - Array objects. - - array_api_dispatch : bool - If True, the array namespace is obtained from the array objects. - - Returns - ------- - namespace : module - Namespace shared by array objects. If any of the `arrays` are not arrays, - the namespace defaults to NumPy. - - is_array_api_compliant : bool - True of the arrays are containers that implement the Array API spec. - """ if not array_api_dispatch: return _NUMPY_API_WRAPPER_INSTANCE, False @@ -414,7 +391,8 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None): def _convert_to_numpy(array): """Convert X into a NumPy ndarray on the CPU.""" - xp, _ = _get_namespace(array, array_api_dispatch=True) + with config_context(array_api_dispatch=True): + xp, _ = get_namespace(array) xp_name = xp.__name__ @@ -448,7 +426,8 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): - _, is_array_api_compliant = _get_namespace(attribute, array_api_dispatch=True) + with config_context(array_api_dispatch=True): + _, is_array_api_compliant = get_namespace(attribute) if is_array_api_compliant or isinstance(attribute, numpy.ndarray): attribute = converter(attribute) setattr(new_estimator, key, attribute) From d15ef65f0791a3ab9630a7238788b8383e67628a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 30 Mar 2023 10:41:47 -0400 Subject: [PATCH 16/42] TST Skip test if array_api_compat is not installed --- sklearn/utils/tests/test_validation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 7dcf77913e7c6..06adc9c493427 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -67,6 +67,7 @@ from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning from sklearn.utils._testing import TempMemmap +from sklearn.utils._testing import skip_if_no_array_api_compat def test_as_float_array(): @@ -1839,6 +1840,7 @@ def test_pandas_array_returns_ndarray(input_values): assert_allclose(result, input_values) +@skip_if_no_array_api_compat @pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"]) def test_check_array_array_api_has_non_finite(array_namespace): """Checks that Array API arrays checks non-finite correctly.""" From 0e0f4592f9388b110f8960edcc089025617f8063 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 30 Mar 2023 10:43:55 -0400 Subject: [PATCH 17/42] DOC Adds docstring about global config --- sklearn/utils/_array_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 4ba3b2e5ed39b..cf09eeea6deb8 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -332,6 +332,7 @@ def get_namespace(*arrays): is_array_api_compliant : bool True of the arrays are containers that implement the Array API spec. + Always False when array_api_dispatch=False. """ array_api_dispatch = get_config()["array_api_dispatch"] if not array_api_dispatch: From deb22e811db448696702b2393ccc693b2b44b951 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 30 Mar 2023 11:35:50 -0400 Subject: [PATCH 18/42] TST Add skip for compat --- sklearn/tests/test_discriminant_analysis.py | 6 +++--- sklearn/utils/_testing.py | 2 +- sklearn/utils/tests/test_array_api.py | 19 +++++++++++-------- sklearn/utils/tests/test_validation.py | 4 ++-- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index 76b17c944e7af..4924609db0b8a 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -13,7 +13,7 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._array_api import _convert_to_numpy from sklearn.utils._testing import _convert_container -from sklearn.utils._testing import skip_if_no_array_api_compat +from sklearn.utils._testing import skip_if_array_api_compat_not_configured from sklearn.datasets import make_blobs from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -677,7 +677,7 @@ def test_get_feature_names_out(): assert_array_equal(names_out, expected_names_out) -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured @pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"]) def test_lda_array_api(array_namespace): """Check that the array_api Array gives the same results as ndarrays.""" @@ -735,7 +735,7 @@ def test_lda_array_api(array_namespace): ) -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured @pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.parametrize("dtype", ["float32", "float64"]) def test_lda_array_torch(device, dtype): diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 9c3fd05d34efe..efd5aaee40efb 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -407,7 +407,7 @@ def set_random_state(estimator, random_state=0): skip_if_no_parallel = pytest.mark.skipif( not joblib.parallel.mp, reason="joblib is in serial mode" ) - skip_if_no_array_api_compat = pytest.mark.skipif( + skip_if_array_api_compat_not_configured = pytest.mark.skipif( not ARRAY_API_COMPAT_FUNCTIONAL, reason="requires array_api_compat installed and a new enough version of NumPy", ) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 62aae66f61fc1..9cc481114b386 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -9,7 +9,7 @@ from sklearn.utils._array_api import _asarray_with_order from sklearn.utils._array_api import _convert_to_numpy from sklearn.utils._array_api import _estimator_with_converted_arrays -from sklearn.utils._testing import skip_if_no_array_api_compat +from sklearn.utils._testing import skip_if_array_api_compat_not_configured from sklearn._config import config_context @@ -38,6 +38,7 @@ def test_get_namespace_ndarray_creation_device(): xp_out.zeros(10, device="cuda") +@skip_if_array_api_compat_not_configured def test_get_namespace_ndarray_with_dispatch(): """Test get_namespace on NumPy ndarrays.""" array_api_compat = pytest.importorskip("array_api_compat") @@ -50,7 +51,7 @@ def test_get_namespace_ndarray_with_dispatch(): assert xp_out is array_api_compat.numpy -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured def test_get_namespace_array_api(): """Test get_namespace for ArrayAPI arrays.""" xp = pytest.importorskip("numpy.array_api") @@ -133,7 +134,7 @@ def test_array_api_wrapper_take(): @pytest.mark.parametrize( 
"is_array_api_compliant", - [pytest.param(True, marks=skip_if_no_array_api_compat), False], + [pytest.param(True, marks=skip_if_array_api_compat_not_configured), False], ) def test_asarray_with_order(is_array_api_compliant): """Test _asarray_with_order passes along order for NumPy arrays.""" @@ -164,7 +165,7 @@ def test_asarray_with_order_ignored(): assert not X_new_np.flags["F_CONTIGUOUS"] -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured @pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"]) def test_convert_to_numpy_gpu(library): """Check convert_to_numpy for GPU backed libraries.""" @@ -199,7 +200,7 @@ def fit(self, X, y=None): return self -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured @pytest.mark.parametrize( "array_namespace, converter", [ @@ -219,7 +220,7 @@ def test_convert_estimator_to_ndarray(array_namespace, converter): assert isinstance(new_est.X_, numpy.ndarray) -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured def test_convert_estimator_to_array_api(): """Convert estimator attributes to ArrayAPI arrays.""" xp = pytest.importorskip("numpy.array_api") @@ -232,7 +233,8 @@ def test_convert_estimator_to_array_api(): @pytest.mark.parametrize( - "array_api_dispatch", [pytest.param(True, marks=skip_if_no_array_api_compat), False] + "array_api_dispatch", + [pytest.param(True, marks=skip_if_array_api_compat_not_configured), False], ) def test_get_namespace_array_api_isdtype(array_api_dispatch): """Test isdtype implementation from _ArrayAPIWrapper and array_api_compat.""" @@ -260,7 +262,8 @@ def test_get_namespace_array_api_isdtype(array_api_dispatch): @pytest.mark.parametrize( - "array_api_dispatch", [pytest.param(True, marks=skip_if_no_array_api_compat), False] + "array_api_dispatch", + [pytest.param(True, marks=skip_if_array_api_compat_not_configured), False], ) def test_get_namespace_list(array_api_dispatch): """Test get_namespace for lists.""" diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 06adc9c493427..4a765d1404794 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -67,7 +67,7 @@ from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning from sklearn.utils._testing import TempMemmap -from sklearn.utils._testing import skip_if_no_array_api_compat +from sklearn.utils._testing import skip_if_array_api_compat_not_configured def test_as_float_array(): @@ -1840,7 +1840,7 @@ def test_pandas_array_returns_ndarray(input_values): assert_allclose(result, input_values) -@skip_if_no_array_api_compat +@skip_if_array_api_compat_not_configured @pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"]) def test_check_array_array_api_has_non_finite(array_namespace): """Checks that Array API arrays checks non-finite correctly.""" From 4d1ef4bbf1d5f196c126eb60f7da1736091da41c Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 30 Mar 2023 11:37:46 -0400 Subject: [PATCH 19/42] CLN Be less recursive --- sklearn/utils/_array_api.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index cf09eeea6deb8..590b765f0117c 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -91,7 +91,10 @@ def _isdtype_single(dtype, kind, *, xp): elif kind == "unsigned integer": return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64} elif kind == "integral": - return isdtype(dtype, ("signed integer", "unsigned integer"), xp=xp) + return any( + _isdtype_single(dtype, k, xp=xp) + for k in ("signed integer", "unsigned integer") + ) elif kind == "real floating": return dtype in {xp.float32, xp.float64} elif kind == "complex floating": @@ -103,8 +106,9 @@ def _isdtype_single(dtype, kind, *, xp): return dtype == xp.complex128 return False elif kind == "numeric": - return isdtype( - dtype, ("integral", "real floating", "complex floating"), xp=xp + return any( + _isdtype_single(dtype, k, xp=xp) + for k in ("integral", "real floating", "complex floating") ) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") From 5229052ce650d36deae6f83f659acd44c71578ed Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 30 Mar 2023 13:16:42 -0400 Subject: [PATCH 20/42] ENH Fixes isdtype implementation --- sklearn/utils/_array_api.py | 17 +++------ sklearn/utils/tests/test_array_api.py | 50 +++++++++++++++------------ 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 590b765f0117c..3eda6647cdcec 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -100,11 +100,12 @@ def _isdtype_single(dtype, kind, *, xp): elif kind == "complex floating": # Some name spaces do not have complex, such as cupy.array_api # and numpy.array_api + complex_dtypes = set() if hasattr(xp, "complex64"): - return dtype == xp.complex64 + complex_dtypes.add(xp.complex64) if hasattr(xp, "complex128"): - return dtype == xp.complex128 - return False + complex_dtypes.add(xp.complex128) + return dtype in complex_dtypes elif kind == "numeric": return any( _isdtype_single(dtype, k, xp=xp) @@ -161,11 +162,6 @@ def take(self, X, indices, *, axis=0): return self._namespace.stack(selected, axis=axis) def isdtype(self, dtype, kind): - """Returns a boolean indicating whether a provided dtype is of type "kind". - - Included in the v2022.12 of the Array API spec. - https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html - """ return isdtype(dtype, kind, xp=self._namespace) @@ -268,11 +264,6 @@ def concat(self, arrays, *, axis=None): return numpy.concatenate(arrays, axis=axis) def isdtype(self, dtype, kind): - """Returns a boolean indicating whether a provided dtype is of type "kind". - - Included in the v2022.12 of the Array API spec. 
- https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html - """ return isdtype(dtype, kind, xp=self) def reshape(self, x, shape, *, copy=None): diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 9cc481114b386..c69ff6bbf937f 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -232,33 +232,39 @@ def test_convert_estimator_to_array_api(): assert hasattr(new_est.X_, "__array_namespace__") -@pytest.mark.parametrize( - "array_api_dispatch", - [pytest.param(True, marks=skip_if_array_api_compat_not_configured), False], -) -def test_get_namespace_array_api_isdtype(array_api_dispatch): - """Test isdtype implementation from _ArrayAPIWrapper and array_api_compat.""" - xp = pytest.importorskip("numpy.array_api") +@pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyApiWrapper]) +def test_get_namespace_array_api_isdtype(wrapper): + """Test isdtype implementation from _ArrayAPIWrapper and _NumPyApiWrapper.""" - X_xp = xp.asarray([[1, 2, 3]]) - with config_context(array_api_dispatch=array_api_dispatch): - xp_out, _ = get_namespace(X_xp) - assert xp_out.isdtype(xp_out.float32, "real floating") - assert xp_out.isdtype(xp_out.float64, "real floating") - assert not xp_out.isdtype(xp_out.int32, "real floating") + if wrapper == _ArrayAPIWrapper: + xp_ = pytest.importorskip("numpy.array_api") + xp = _ArrayAPIWrapper(xp_) + else: + xp = _NumPyApiWrapper() + + assert xp.isdtype(xp.float32, "real floating") + assert xp.isdtype(xp.float64, "real floating") + assert not xp.isdtype(xp.int32, "real floating") + + assert xp.isdtype(xp.bool, "bool") + assert not xp.isdtype(xp.float32, "bool") + + assert xp.isdtype(xp.int16, "signed integer") + assert not xp.isdtype(xp.uint32, "signed integer") - assert xp_out.isdtype(xp_out.bool, "bool") - assert not xp_out.isdtype(xp_out.float32, "bool") + assert xp.isdtype(xp.uint16, "unsigned integer") + assert not xp.isdtype(xp.int64, "unsigned integer") - assert xp_out.isdtype(xp_out.int16, "signed integer") - assert not xp_out.isdtype(xp_out.uint32, "signed integer") + assert xp.isdtype(xp.int64, "numeric") + assert xp.isdtype(xp.float32, "numeric") + assert xp.isdtype(xp.uint32, "numeric") - assert xp_out.isdtype(xp_out.uint16, "unsigned integer") - assert not xp_out.isdtype(xp_out.int64, "unsigned integer") + assert not xp.isdtype(xp.float32, "complex floating") - assert xp_out.isdtype(xp_out.int64, "numeric") - assert xp_out.isdtype(xp_out.float32, "numeric") - assert xp_out.isdtype(xp_out.uint32, "numeric") + if wrapper == _NumPyApiWrapper: + assert not xp.isdtype(xp.int8, "complex floating") + assert xp.isdtype(xp.complex64, "complex floating") + assert xp.isdtype(xp.complex128, "complex floating") @pytest.mark.parametrize( From 2bf884545c82cc40ebd832f1e329d463081f5d5c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 30 Mar 2023 13:19:06 -0400 Subject: [PATCH 21/42] DOC Fixes docstring --- sklearn/utils/_array_api.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 3eda6647cdcec..5efef3fd6654d 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -34,18 +34,17 @@ def _check_array_api_dispatch(array_api_dispatch): def device(x): - """ - Hardware device the array data resides on. + """Hardware device the array data resides on. 
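A condensed sketch of the dtype kinds the reworked `isdtype` understands, assuming `numpy.array_api` is importable; a tuple of kinds matches if any single kind does::

    import numpy.array_api as xp
    from sklearn.utils._array_api import isdtype

    assert isdtype(xp.int16, "signed integer", xp=xp)
    assert isdtype(xp.uint32, ("integral", "real floating"), xp=xp)
    assert isdtype(xp.float64, "numeric", xp=xp)
    # numpy.array_api exposes no complex dtypes, so this is simply
    # False rather than an AttributeError, thanks to the hasattr checks.
    assert not isdtype(xp.float32, "complex floating", xp=xp)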
Parameters ---------- - x: array - array instance from NumPy or an array API compatible library. + x : array + Array instance from NumPy or an array API compatible library. Returns ------- - out: device - a ``device`` object (see the "Device Support" section of the array API spec). + out : device + `device` object (see the "Device Support" section of the array API spec). """ if isinstance(x, (numpy.ndarray, numpy.generic)): return "cpu" @@ -53,8 +52,17 @@ def device(x): def size(x): - """ - Return the total number of elements of x + """Return the total number of elements of x. + + Parameters + ---------- + x : array + Array instance from NumPy or an array API compatible library. + + Returns + ------- + out : int + Total number of elements. """ if None in x.shape: return None From 0f691a680854ccabc419f2127126154c88e91c49 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 30 Mar 2023 13:47:33 -0400 Subject: [PATCH 22/42] CLN Remove unneeded name --- sklearn/utils/_array_api.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 5efef3fd6654d..f522f1a667a8f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -70,12 +70,7 @@ def size(x): def _is_numpy_namespace(xp): - return xp.__name__ in { - "numpy", - "_NumPyApiWrapper", - "array_api_compat.numpy", - "numpy.array_api", - } + return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"} def isdtype(dtype, kind, *, xp): From d51f9749dc8bb65cb9f86ab88af335865d7ce13f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 30 Mar 2023 14:01:05 -0400 Subject: [PATCH 23/42] TST Adds skip test --- sklearn/utils/tests/test_array_api.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index c69ff6bbf937f..59edd584cd33f 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -132,16 +132,10 @@ def test_array_api_wrapper_take(): xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0) -@pytest.mark.parametrize( - "is_array_api_compliant", - [pytest.param(True, marks=skip_if_array_api_compat_not_configured), False], -) -def test_asarray_with_order(is_array_api_compliant): +@pytest.mark.parametrize("array_api", ["numpy", "numpy.array_api"]) +def test_asarray_with_order(array_api): """Test _asarray_with_order passes along order for NumPy arrays.""" - if is_array_api_compliant: - xp = pytest.importorskip("numpy.array_api") - else: - xp = numpy + xp = pytest.importorskip(array_api) X = xp.asarray([1.2, 3.4, 5.1]) X_new = _asarray_with_order(X, order="F", xp=xp) @@ -183,6 +177,7 @@ def test_convert_to_numpy_gpu(library): assert_allclose(X_cpu, expected_output) +@skip_if_array_api_compat_not_configured def test_convert_to_numpy_cpu(): """Check convert_to_numpy for PyTorch CPU arrays.""" torch = pytest.importorskip("torch") From d2d6e9c7cacd4d38cdc1e4daf5c4704bee406c98 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
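Both helpers documented here behave trivially for NumPy inputs; a small sketch::

    import numpy
    from sklearn.utils._array_api import device, size

    X = numpy.ones((2, 3))
    device(X)  # "cpu": NumPy arrays and scalars always live on the CPU
    size(X)    # 6, i.e. math.prod(X.shape)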
Fan" Date: Thu, 30 Mar 2023 14:04:42 -0400 Subject: [PATCH 24/42] ENH Allows xp to be passed into _convert_to_numpy --- sklearn/tests/test_discriminant_analysis.py | 8 ++++---- sklearn/utils/_array_api.py | 4 ++-- sklearn/utils/tests/test_array_api.py | 5 ++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index 4924609db0b8a..d780e9a31bf41 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -703,7 +703,7 @@ def test_lda_array_api(array_namespace): lda_xp_param = getattr(lda_xp, key) assert hasattr(lda_xp_param, "__array_namespace__") - lda_xp_param_np = _convert_to_numpy(lda_xp_param) + lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=xp) assert_allclose( attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3 ) @@ -725,7 +725,7 @@ def test_lda_array_api(array_namespace): result_xp, "__array_namespace__" ), f"{method} did not output an array_namespace" - result_xp_np = _convert_to_numpy(result_xp) + result_xp_np = _convert_to_numpy(result_xp, xp=xp) assert_allclose( result, @@ -764,7 +764,7 @@ def test_lda_array_torch(device, dtype): assert isinstance(lda_xp_param, torch.Tensor) assert lda_xp_param.device.type == device - lda_xp_param_np = _convert_to_numpy(lda_xp_param) + lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=torch) assert_allclose( attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3 ) @@ -785,7 +785,7 @@ def test_lda_array_torch(device, dtype): assert isinstance(result_xp, torch.Tensor) assert result_xp.device.type == device - result_xp_np = _convert_to_numpy(result_xp) + result_xp_np = _convert_to_numpy(result_xp, xp=torch) assert_allclose( result, diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index f522f1a667a8f..8f37e55bc472b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -388,9 +388,9 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None): return xp.asarray(array, dtype=dtype, copy=copy) -def _convert_to_numpy(array): +def _convert_to_numpy(array, xp=None): """Convert X into a NumPy ndarray on the CPU.""" - with config_context(array_api_dispatch=True): + if xp is None: xp, _ = get_namespace(array) xp_name = xp.__name__ diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 59edd584cd33f..a6ad796eb4135 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -172,18 +172,17 @@ def test_convert_to_numpy_gpu(library): else: X_gpu = xp.asarray([1.0, 2.0, 3.0]) - X_cpu = _convert_to_numpy(X_gpu) + X_cpu = _convert_to_numpy(X_gpu, xp=xp) expected_output = numpy.asarray([1.0, 2.0, 3.0]) assert_allclose(X_cpu, expected_output) -@skip_if_array_api_compat_not_configured def test_convert_to_numpy_cpu(): """Check convert_to_numpy for PyTorch CPU arrays.""" torch = pytest.importorskip("torch") X_torch = torch.asarray([1.0, 2.0, 3.0], device="cpu") - X_cpu = _convert_to_numpy(X_torch) + X_cpu = _convert_to_numpy(X_torch, xp=torch) expected_output = numpy.asarray([1.0, 2.0, 3.0]) assert_allclose(X_cpu, expected_output) From de5633e073971cf44d1b140ef64f153287dd6374 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 31 Mar 2023 11:54:53 -0400 Subject: [PATCH 25/42] CLN Remove None in x.shape --- sklearn/utils/_array_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 8f37e55bc472b..856a4c126ad45 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -64,8 +64,6 @@ def size(x): out : int Total number of elements. """ - if None in x.shape: - return None return math.prod(x.shape) From 3d2553413a2b79c18ba6f8371d15f904cef68604 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 31 Mar 2023 11:59:42 -0400 Subject: [PATCH 26/42] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/utils/_array_api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 856a4c126ad45..ce9bf2655a8ee 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -337,6 +337,10 @@ def get_namespace(*arrays): _check_array_api_dispatch(array_api_dispatch) try: + # array-api-compat is a required dependency of scikit-learn only when + # configuring `array_api_dispatch=True`. Its import should therefore be + # protected by _check_array_api_dispatch to display an informative error + # message in case it is missing. import array_api_compat namespace, is_array_api_compliant = ( From bd90dc3b8c0605df5f3755e5c9103a778083d4ef Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 31 Mar 2023 12:28:06 -0400 Subject: [PATCH 27/42] CLN Simplify get_namespace logic --- sklearn/utils/_array_api.py | 26 ++++++++++++-------------- sklearn/utils/tests/test_array_api.py | 19 ++----------------- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index ce9bf2655a8ee..a87688bb98115 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -336,19 +336,13 @@ def get_namespace(*arrays): _check_array_api_dispatch(array_api_dispatch) - try: - # array-api-compat is a required dependency of scikit-learn only when - # configuring `array_api_dispatch=True`. Its import should therefore be - # protected by _check_array_api_dispatch to display an informative error - # message in case it is missing. - import array_api_compat - - namespace, is_array_api_compliant = ( - array_api_compat.get_namespace(*arrays), - True, - ) - except (TypeError, ImportError): - return _NUMPY_API_WRAPPER_INSTANCE, False + # array-api-compat is a required dependency of scikit-learn only when + # configuring `array_api_dispatch=True`. Its import should therefore be + # protected by _check_array_api_dispatch to display an informative error + # message in case it is missing. 
+ import array_api_compat + + namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: namespace = _ArrayAPIWrapper(namespace) @@ -428,7 +422,11 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): with config_context(array_api_dispatch=True): - _, is_array_api_compliant = get_namespace(attribute) + try: + _, is_array_api_compliant = get_namespace(attribute) + except TypeError: + is_array_api_compliant = False + if is_array_api_compliant or isinstance(attribute, numpy.ndarray): attribute = converter(attribute) setattr(new_estimator, key, attribute) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index a6ad796eb4135..f6a3762259156 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -63,9 +63,8 @@ def test_get_namespace_array_api(): assert is_array_api_compliant assert isinstance(xp_out, _ArrayAPIWrapper) - xp_out, is_array_api_compliant = get_namespace(1) - assert isinstance(xp_out, _NumPyApiWrapper) - assert not is_array_api_compliant + with pytest.raises(TypeError): + xp_out, is_array_api_compliant = get_namespace(X_xp, X_np) class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper): @@ -261,20 +260,6 @@ def test_get_namespace_array_api_isdtype(wrapper): assert xp.isdtype(xp.complex128, "complex floating") -@pytest.mark.parametrize( - "array_api_dispatch", - [pytest.param(True, marks=skip_if_array_api_compat_not_configured), False], -) -def test_get_namespace_list(array_api_dispatch): - """Test get_namespace for lists.""" - - X = [1, 2, 3] - with config_context(array_api_dispatch=array_api_dispatch): - xp_out, is_array = get_namespace(X) - assert not is_array - assert isinstance(xp_out, _NumPyApiWrapper) - - def test_error_when_reshape_is_not_a_tuple(): """Raise error in reshape when shape is not a tuple.""" X = numpy.asarray([[1, 2, 3], [3, 4, 5]]) From 32f11045144612b2fe16be5da165668511be4ba5 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
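After this simplification `get_namespace` no longer falls back silently when dispatch is enabled: mixing array types now raises, as the updated test asserts. A sketch, assuming `numpy.array_api` and `array_api_compat` are installed::

    import numpy
    import numpy.array_api as xp
    from sklearn import config_context
    from sklearn.utils._array_api import get_namespace

    X_np = numpy.asarray([1.0, 2.0])
    X_xp = xp.asarray(X_np)
    with config_context(array_api_dispatch=True):
        # array_api_compat.get_namespace raises TypeError on mixed inputs
        get_namespace(X_xp, X_np)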
Fan" Date: Fri, 31 Mar 2023 12:38:25 -0400 Subject: [PATCH 28/42] TST Improve coverage --- sklearn/utils/_array_api.py | 7 +++---- sklearn/utils/tests/test_array_api.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index a87688bb98115..85f18e69801a0 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -274,13 +274,12 @@ def reshape(self, x, shape, *, copy=None): https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html """ if not isinstance(shape, tuple): - raise TypeError("shape must be a tuple") + raise TypeError( + f"shape must be a tuple, got {shape!r} of type {type(shape)}" + ) if copy is True: x = x.copy() - elif copy is False: - x.shape = shape - return x return numpy.reshape(x, shape) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index f6a3762259156..440bd093bc5ac 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -260,10 +260,16 @@ def test_get_namespace_array_api_isdtype(wrapper): assert xp.isdtype(xp.complex128, "complex floating") -def test_error_when_reshape_is_not_a_tuple(): - """Raise error in reshape when shape is not a tuple.""" - X = numpy.asarray([[1, 2, 3], [3, 4, 5]]) - xp, _ = get_namespace(X) +def test_reshape_behavior(): + """Check reshape behavior with copy and is strict with non-tuple shape.""" + xp = _NumPyApiWrapper() + X = xp.asarray([[1, 2, 3], [3, 4, 5]]) + + X_no_copy = xp.reshape(X, (-1,), copy=False) + assert X_no_copy.base is X.base + + X_copy = xp.reshape(X, (6, 1), copy=True) + assert X_copy.base is not X.base with pytest.raises(TypeError, match="shape must be a tuple"): xp.reshape(X, -1) From c54a0727bc214b1f0e51fd317aa2ca7c50f64e90 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 31 Mar 2023 13:15:20 -0400 Subject: [PATCH 29/42] CI Adds pytorch to CI --- ...latest_conda_forge_mkl_linux-64_conda.lock | 43 ++++++++++--------- ...t_conda_forge_mkl_linux-64_environment.yml | 2 + .../update_environments_and_lock_files.py | 4 +- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index f3b23884d3746..915bb9ac8aa1f 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. 
# platform: linux-64 -# input_hash: 074039c4568f960fb05dba525d0bfb9c74b80f243b83b0abdc0cdf9eafa1ef94 +# input_hash: 11b2d8c50a1abca3839d7665afc331d1b3d8f3088d7c5ff50a62cdffdd90b072 @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2022.12.7-ha878542_0.conda#ff9f73d45c4a07d6f424495288a26080 @@ -11,9 +11,8 @@ https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.ta https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda#7aca3059a1729aa76c597603f10b0dd3 https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-12.2.0-h337968e_19.tar.bz2#164b4b1acaedc47ee7e658ae6b308ca3 https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-12.2.0-h46fd767_19.tar.bz2#1030b1f38c129f2634eae026f704fe60 -https://conda.anaconda.org/conda-forge/linux-64/mkl-include-2022.1.0-h84fe81f_915.tar.bz2#2dcd1acca05c11410d4494d7fc7dfa2a https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.11-3_cp311.conda#c2e2630ddb68cf52eec74dc7dfab20b5 -https://conda.anaconda.org/conda-forge/noarch/tzdata-2022g-h191b570_0.conda#51fc4fcfb19f5d95ffc8c339db5068e8 +https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda#939e3e74d8be4dac89ce83b20de2492a https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2#f766549260d6815b0c52253f1fb1bb29 https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-12.2.0-h69a702a_19.tar.bz2#cd7a806282c16e1f2d39a7e80d3a3e0d https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2#fee5683a3f04bd15cbd8318b096a27ab @@ -22,7 +21,6 @@ https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-12.2.0-h65d4601_19.tar https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.8-h166bdaf_0.tar.bz2#be733e69048951df1e4b4b7bb8c7666f https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 -https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-h27087fc_0.tar.bz2#c4fbad8d4bddeb3c085f18cbf97fbfad https://conda.anaconda.org/conda-forge/linux-64/fftw-3.3.10-nompi_hf0379b8_106.conda#d7407e695358f068a2a7f8295cde0567 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 @@ -33,7 +31,8 @@ https://conda.anaconda.org/conda-forge/linux-64/lame-3.100-h166bdaf_1003.tar.bz2 https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2#76bbff344f0134279f225174e9064c8f https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.0.9-h166bdaf_8.tar.bz2#9194c9bf9428035a05352d031462eae4 https://conda.anaconda.org/conda-forge/linux-64/libdb-6.2.32-h9c3ff4c_0.tar.bz2#3f3258d8f841fbac63b36b75bdac1afd -https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.17-h0b41bf4_0.conda#5cc781fd91968b11a8a7fdbee0982676 +https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.18-h0b41bf4_0.conda#6aa9c9de5542ecb07fdda9ca626252d8 +https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 
https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-h166bdaf_0.tar.bz2#b62b52da46c39ee2bc3c162ac7f1804d @@ -42,7 +41,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.0-h7f98852_0.tar.bz2# https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2#6e8cc2173440d77708196c5b93771680 https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2#15345e56d527b330e1cacbdf58676e8f https://conda.anaconda.org/conda-forge/linux-64/libtool-2.4.7-h27087fc_0.conda#f204c8ba400ec475452737094fb81d52 -https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.32.1-h7f98852_1000.tar.bz2#772d69f030955d9646d3d0eaf21d859d +https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda#40b61aab5c7ba9ff276c41cfffe6b80b https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.0-h0b41bf4_0.conda#0d4a7508d8c6c65314f2b9c1f56ad408 https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-h166bdaf_4.tar.bz2#f3f9de449d32ca9b9c66a22863c96f41 https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 @@ -52,6 +51,7 @@ https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.0-h0b41bf4_0.conda#2d833be81a21128e317325a01326d36f https://conda.anaconda.org/conda-forge/linux-64/pixman-0.40.0-h36c2ea0_0.tar.bz2#660e72c82f2e75a6b3fe6a6e75c79f19 https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2#22dad4df6e8630e8dff2428f6f6a7036 +https://conda.anaconda.org/conda-forge/linux-64/sleef-3.5.1-h9b69904_2.tar.bz2#6e016cf4c525d04a7bd038cee53ad3fd https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.38-h0b41bf4_0.conda#9ac34337e5101a87e5d91da05d84aa48 https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2#4b230e8381279d76131116660f5a241a https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.0.10-h7f98852_0.tar.bz2#d6b0b50b49eccfe0be0373be628be0f3 @@ -62,6 +62,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-xextproto-7.3.0-h0b41bf4_10 https://conda.anaconda.org/conda-forge/linux-64/xorg-xf86vidmodeproto-2.3.1-h7f98852_1002.tar.bz2#3ceea9668625c18f19530de98b15d5b0 https://conda.anaconda.org/conda-forge/linux-64/xorg-xproto-7.0.31-h7f98852_1007.tar.bz2#b4a4381d54784606820704f7b5f05a15 https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2#2161070d867d1b1204ea749c8eec4ef0 +https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9b5aca60558d02ddaa09d599e55920 https://conda.anaconda.org/conda-forge/linux-64/jack-1.9.22-h11f4161_0.conda#504fa9e712b99494a9cf4630e3ca7d78 https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.0.9-h166bdaf_8.tar.bz2#4ae4d7795d33e02bd20f6b23d91caf82 https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.0.9-h166bdaf_8.tar.bz2#04bac51ba35ea023dc48af73c1c88c25 @@ -71,6 +72,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.10-h28343ad_4.tar.b https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.2-h27087fc_0.tar.bz2#7daf72d8e2a8e848e11d63ed6d1026e0 https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.46-h620e276_0.conda#27e745f6f2e4b757e95dd7225fbe6bdb https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda#e1c890aebdebbfbf87e2c917187b4416 
+https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-3.21.12-h3eb15da_0.conda#4b36c68184c6c85d88c6e595a32a1ede https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.40.0-h753d276_0.tar.bz2#2e5f9a37d487e1019fd4d8113adb2f9f https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.13-h7f98852_1004.tar.bz2#b3653fdc58d03face9724f602218a904 @@ -91,13 +93,13 @@ https://conda.anaconda.org/conda-forge/linux-64/libglib-2.74.1-h606061b_1.tar.bz https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.9.0-hd6dc26d_0.conda#ab9d052373c9376c0ebcff4dfef3d296 https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hadd5161_1.conda#17d91085ccf5934ce652cb448d0cb65a https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.0-hb75c966_0.conda#c648d19cd9c8625898d5d370414de7c7 -https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.5.0-hddfeb54_5.conda#d2343e6594c2a4a654a475a6131ef20d +https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.5.0-ha587672_6.conda#4e5ee4b062c21519efbee7e2ae608748 https://conda.anaconda.org/conda-forge/linux-64/libudev1-253-h0b41bf4_1.conda#bb38b19a41bb94e8a19dbfb062d499c7 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.5.0-h79f4944_1.conda#04a39cdd663f295653fc143851830563 https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-16.0.0-h417c0b6_0.conda#8ac4c157172ea816f5c9a0dc33df69d8 https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.32-hd7da12d_1.conda#a69fa6f218cfed8e2d61753eeacaf034 https://conda.anaconda.org/conda-forge/linux-64/nss-3.89-he45b914_0.conda#2745719a58eeaab6657256a3f142f099 -https://conda.anaconda.org/conda-forge/linux-64/python-3.11.0-he550d4f_1_cpython.conda#8d14fc2aa12db370a443753c8230be1e +https://conda.anaconda.org/conda-forge/linux-64/python-3.11.1-h2755cc3_0_cpython.conda#2b276315a584c0237e384829ef95fae3 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-h166bdaf_0.tar.bz2#384e7fcb3cd162ba3e4aed4b687df566 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h166bdaf_0.tar.bz2#637054603bb7594302e3bf83f0a99879 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-h166bdaf_0.tar.bz2#732e22f1741bccea861f5668cf7342a7 @@ -132,8 +134,8 @@ https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b46 https://conda.anaconda.org/conda-forge/noarch/pycparser-2.21-pyhd8ed1ab_0.tar.bz2#076becd9e05608f8dc72757d5f3a91ff https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.0.9-pyhd8ed1ab_0.tar.bz2#e8fbc1b54b25f4b08281467bc13b70cc https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 -https://conda.anaconda.org/conda-forge/noarch/pytz-2022.7.1-pyhd8ed1ab_0.conda#f59d49a7b464901cf714b9e7984d01a2 -https://conda.anaconda.org/conda-forge/noarch/setuptools-67.6.0-pyhd8ed1ab_0.conda#e18ed61c37145bb9b48d1d98801960f7 +https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3-pyhd8ed1ab_0.conda#d3076b483092a435832603243567bc31 +https://conda.anaconda.org/conda-forge/noarch/setuptools-67.6.1-pyhd8ed1ab_0.conda#6c443cccff3daa3d83b2b807b0a298ce https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2 https://conda.anaconda.org/conda-forge/linux-64/tbb-2021.8.0-hf52228f_0.conda#b4188d0c54ead87b3c6bc9cb07281f40 
https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.1.0-pyh8a188c0_0.tar.bz2#a2995ee828f65687ac5b1e71a2ab1e0c @@ -147,24 +149,24 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.10-h7f98852_ https://conda.anaconda.org/conda-forge/linux-64/cairo-1.16.0-ha61ee94_1014.tar.bz2#d1a88f3ed5b52e1024b80d4bcd26a7a0 https://conda.anaconda.org/conda-forge/linux-64/cffi-1.15.1-py311h409f033_3.conda#9025d0786dbbe4bc91fd8e85502decce https://conda.anaconda.org/conda-forge/linux-64/coverage-7.2.2-py311h2582759_0.conda#41f2bca794aa1b4e70c1a23f8ec2dfa5 -https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.39.2-py311h2582759_0.conda#f227528f25c3d45717f71774222a2200 +https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.39.3-py311h2582759_0.conda#55741f37ab19d949b8e7316cfe286824 https://conda.anaconda.org/conda-forge/linux-64/glib-2.74.1-h6239696_1.tar.bz2#f3220a9e9d3abcbfca43419a219df7e4 https://conda.anaconda.org/conda-forge/noarch/joblib-1.2.0-pyhd8ed1ab_0.tar.bz2#7583652522d71ad78ba536bba06940eb https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_had23c3d_1.conda#36c65ed73b7c92589bd9562ef8a6023d -https://conda.anaconda.org/conda-forge/linux-64/mkl-2022.1.0-h84fe81f_915.tar.bz2#b9c8f925797a93dbff45e1626b025a6b +https://conda.anaconda.org/conda-forge/linux-64/mkl-2022.2.1-h84fe81f_16997.conda#a7ce56d5757f5b57e7daabe703ade5bb https://conda.anaconda.org/conda-forge/linux-64/pillow-9.4.0-py311h573f0d3_2.conda#7321881b545202cf9ab8bd24b4151dcb https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-h5195f5e_3.conda#caeb3302ef1dc8b342b20c710a86f8a9 https://conda.anaconda.org/conda-forge/noarch/pytest-7.2.2-pyhd8ed1ab_0.conda#60958b19354e0ec295b43f6ab5cfab86 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.7-py311hcafe171_0.conda#cf1adb3a0138cca4bbab415ddf2f57f1 https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.5.0-hd8ed1ab_0.conda#b3c594fde1a80a1fc3eb9cc4a5dfe392 +https://conda.anaconda.org/conda-forge/linux-64/blas-1.0-mkl.tar.bz2#349aef876b1d8c9dccae01de20d5b385 https://conda.anaconda.org/conda-forge/linux-64/brotlipy-0.7.0-py311hd4cff14_1005.tar.bz2#9bdac7084ecfc08338bae1b976535724 -https://conda.anaconda.org/conda-forge/linux-64/cryptography-39.0.2-py311h9b4c7bb_0.conda#f2e305191c0e28ef746a4e6b365807f9 +https://conda.anaconda.org/conda-forge/linux-64/cryptography-40.0.1-py311h9b4c7bb_0.conda#7638d1e31ce0a029211e12fc719e52e6 https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.0-h25f0c4b_2.conda#461541cb1b387c2a28ab6217f3d38502 https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-6.0.0-h8e241bc_0.conda#448fe40d2fed88ccf4d9ded37cbb2b38 https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-16_linux64_mkl.tar.bz2#85f61af03fd291dae33150ffe89dc09a -https://conda.anaconda.org/conda-forge/linux-64/mkl-devel-2022.1.0-ha770c72_916.tar.bz2#69ba49e445f87aea2cba343a71a35ca2 -https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.1.1-pyhd8ed1ab_0.conda#1d1a27f637808c76dd83e3f469aa6f7e +https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.2.0-pyhd8ed1ab_0.conda#f10c2cf447ca96f12a326b83c75b8e33 https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-daemon-16.1-ha8d29e2_3.conda#34d9d75ca896f5919c372a34e25f23ea https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.11.0-py311hcafe171_3.conda#0d79df2a96f6572fed2883374400b235 
https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.0.0-pyhd8ed1ab_0.tar.bz2#c9e3f8bfdb9bfc34aa1836a6ed4b25d7 @@ -173,20 +175,19 @@ https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.0-h4243ec0 https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-16_linux64_mkl.tar.bz2#361bf757b95488de76c4f123805742d3 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-16_linux64_mkl.tar.bz2#a2f166748917d6d6e4707841ca1f519e https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-16.1-hcb278e6_3.conda#8b452ab959166d91949af4c2d28f81db -https://conda.anaconda.org/conda-forge/noarch/pyopenssl-23.0.0-pyhd8ed1ab_0.conda#d41957700e83bbb925928764cb7f8878 +https://conda.anaconda.org/conda-forge/noarch/pyopenssl-23.1.1-pyhd8ed1ab_0.conda#0b34aa3ab7e7ccb1765a03dd9ed29938 https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e -https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-16_linux64_mkl.tar.bz2#44ccc4d4dca6a8d57fa17442bc64b5a1 https://conda.anaconda.org/conda-forge/linux-64/numpy-1.24.2-py311h8e6699e_0.conda#90db8cc0dfa20853329bfc6642f887aa https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h67dfc38_7.conda#f140acdc15a0196eef664f120c396266 https://conda.anaconda.org/conda-forge/noarch/urllib3-1.26.15-pyhd8ed1ab_0.conda#27db656619a55d727eaf5a6ece3d2fd6 -https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-16_linux64_mkl.tar.bz2#3f92c1c9e1c0e183462c5071aa02cae1 https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.0.7-py311ha3edf6b_0.conda#e7548e7f58965a2fe97a95950a5fedc6 -https://conda.anaconda.org/conda-forge/linux-64/pandas-1.5.3-py311h2872171_0.conda#a129a2aa7f5c2f45808399d60c3080f2 +https://conda.anaconda.org/conda-forge/linux-64/pandas-1.5.3-py311h2872171_1.conda#6bb03bf6d4fab68174eae8b06c3b6934 https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.7-py311ha74522f_3.conda#ad6dd0bed0cdf5f2d4eb2b989d6253b3 +https://conda.anaconda.org/conda-forge/linux-64/pytorch-1.13.1-cpu_py311h410fd25_1.conda#ddd2fadddf89e3dc3d541a2537fce010 https://conda.anaconda.org/conda-forge/noarch/requests-2.28.2-pyhd8ed1ab_0.conda#11d178fc55199482ee48d6812ea83983 -https://conda.anaconda.org/conda-forge/linux-64/blas-2.116-mkl.tar.bz2#c196a26abf6b4f132c88828ab7c2231c https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.7.1-py311h8597a09_0.conda#70c3b734ffe82c16b6d121aaa11929a8 -https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyha770c72_2.conda#2794d1b53ad3ddb1e61f27d1151ccfc7 +https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyha770c72_3.conda#5936894aade8240c867d292aa0d980c6 +https://conda.anaconda.org/conda-forge/linux-64/pytorch-cpu-1.13.1-cpu_py311hdb170b5_1.conda#a805d5f103e493f207613283d8acbbe1 https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.7.1-py311h38be061_0.conda#8fd462c8bcbba5a3affcb2d04e387476 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.10.1-py311h8e6699e_0.conda#a9dba1242a54275e4914a2540f4eb233 https://conda.anaconda.org/conda-forge/linux-64/pyamg-4.2.3-py311hcb41070_2.conda#bcf32a1a23df6e4ae047f90d401b7517 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml index ac132f87a2aba..823c8d1734d1e 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml @@ -20,3 +20,5 @@ dependencies: - pytest-cov - coverage - ccache + 
- pytorch=1.13 + - pytorch-cpu diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index ec375078bfab8..9e4d8da71e318 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -88,9 +88,11 @@ def remove_from(alist, to_remove): "folder": "build_tools/azure", "platform": "linux-64", "channel": "conda-forge", - "conda_dependencies": common_dependencies + ["ccache"], + "conda_dependencies": common_dependencies + + ["ccache", "pytorch", "pytorch-cpu"], "package_constraints": { "blas": "[build=mkl]", + "pytorch": "1.13", }, }, { From 75ca39003ccbe2e203d6beaaf9aabac15ef27096 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 31 Mar 2023 13:17:47 -0400 Subject: [PATCH 30/42] FIX Fixes test --- sklearn/utils/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 440bd093bc5ac..a1c6a600fac3d 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -266,7 +266,7 @@ def test_reshape_behavior(): X = xp.asarray([[1, 2, 3], [3, 4, 5]]) X_no_copy = xp.reshape(X, (-1,), copy=False) - assert X_no_copy.base is X.base + assert X_no_copy.base is X X_copy = xp.reshape(X, (6, 1), copy=True) assert X_copy.base is not X.base From 7b4e6c6ea8a80bdbc0315e49b108a41af653618a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 3 Apr 2023 14:11:09 -0400 Subject: [PATCH 31/42] FIX Fix merge --- sklearn/utils/_array_api.py | 52 ------------------------------------- 1 file changed, 52 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 58c2e396ab2ce..8571a7546677b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -118,58 +118,6 @@ def _isdtype_single(dtype, kind, *, xp): return dtype == kind -def _is_numpy_namespace(xp): - """Return True if xp is backed by NumPy.""" - return xp.__name__ in {"numpy", "numpy.array_api"} - - -def isdtype(dtype, kind, *, xp): - """Returns a boolean indicating whether a provided dtype is of type "kind". - - Included in the v2022.12 of the Array API spec. 
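The test correction above reflects NumPy's view semantics: reshaping a freshly created array yields a view whose `.base` is the array itself, whereas the previous assertion compared against `X.base`, which is `None` for an array that owns its data. A sketch in plain NumPy::

    import numpy

    X = numpy.asarray([[1, 2, 3], [3, 4, 5]])
    assert X.base is None                      # X owns its data
    X_view = numpy.reshape(X, (-1,))
    assert X_view.base is X                    # a view references X itself
    assert numpy.reshape(X, (6, 1)).base is X  # still a view until copied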
- https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html - """ - if isinstance(kind, tuple): - return any(_isdtype_single(dtype, k, xp=xp) for k in kind) - else: - return _isdtype_single(dtype, kind, xp=xp) - - -def _isdtype_single(dtype, kind, *, xp): - if isinstance(kind, str): - if kind == "bool": - return dtype == xp.bool - elif kind == "signed integer": - return dtype in {xp.int8, xp.int16, xp.int32, xp.int64} - elif kind == "unsigned integer": - return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64} - elif kind == "integral": - return any( - _isdtype_single(dtype, k, xp=xp) - for k in ("signed integer", "unsigned integer") - ) - elif kind == "real floating": - return dtype in {xp.float32, xp.float64} - elif kind == "complex floating": - # Some name spaces do not have complex, such as cupy.array_api - # and numpy.array_api - complex_dtypes = set() - if hasattr(xp, "complex64"): - complex_dtypes.add(xp.complex64) - if hasattr(xp, "complex128"): - complex_dtypes.add(xp.complex128) - return dtype in complex_dtypes - elif kind == "numeric": - return any( - _isdtype_single(dtype, k, xp=xp) - for k in ("integral", "real floating", "complex floating") - ) - else: - raise ValueError(f"Unrecognized data type kind: {kind!r}") - else: - return dtype == kind - - class _ArrayAPIWrapper: """sklearn specific Array API compatibility wrapper From cc6982921a841c11065963f4afea09a2b3817b15 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 3 Apr 2023 14:13:12 -0400 Subject: [PATCH 32/42] CLN Minimize diff --- sklearn/utils/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 8571a7546677b..e60ba5608a14b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -68,6 +68,7 @@ def size(x): def _is_numpy_namespace(xp): + """Return True if xp is backed by NumPy.""" return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"} @@ -233,7 +234,6 @@ def __getattr__(self, name): # Convert to dtype objects if name in self._DTYPES: return numpy.dtype(attr) - return attr @property From 5e737e4540debc7f6c63273cae40a4582abf5170 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 4 Apr 2023 10:37:28 -0400 Subject: [PATCH 33/42] STY Fix linting --- sklearn/utils/_array_api.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 64192e6e560d6..d7bf8448f7020 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -282,21 +282,6 @@ def reshape(self, x, shape, *, copy=None): def isdtype(self, dtype, kind): return isdtype(dtype, kind, xp=self) - def reshape(self, x, shape, *, copy=None): - """Gives a new shape to an array without changing its data. - - The Array API specification requires shape to be a tuple. - https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html - """ - if not isinstance(shape, tuple): - raise TypeError( - f"shape must be a tuple, got {shape!r} of type {type(shape)}" - ) - - if copy is True: - x = x.copy() - return numpy.reshape(x, shape) - _NUMPY_API_WRAPPER_INSTANCE = _NumPyAPIWrapper() From 3682beeea88e2f7d8b6af5a86790419cf7b4eb2f Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Tue, 4 Apr 2023 11:57:41 -0400 Subject: [PATCH 34/42] CI Fix array api installation --- build_tools/azure/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index 19e6dae4917b5..1ee23ac471146 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -56,7 +56,7 @@ python_environment_install_and_activate() { # TODO: Remove when array_api_compat ships a new release with latest changes # install development feature of array_api_compat for testing purposes - pip install git+https://github.com/data-apis/array-api-compat + python -m pip install git+https://github.com/data-apis/array-api-compat elif [[ "$DISTRIB" == "ubuntu" || "$DISTRIB" == "debian-32" ]]; then python3 -m virtualenv --system-site-packages --python=python3 $VIRTUALENV From d88c38c598a1348b1ac6c4cde81c3e2d5411fe16 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 4 Apr 2023 17:37:53 -0400 Subject: [PATCH 35/42] Use array-api-compat from conda-forge --- build_tools/azure/install.sh | 4 -- ...latest_conda_forge_mkl_linux-64_conda.lock | 38 ++++++++----------- ...t_conda_forge_mkl_linux-64_environment.yml | 1 + .../update_environments_and_lock_files.py | 2 +- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index 1ee23ac471146..5238cd1121d2e 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -54,10 +54,6 @@ python_environment_install_and_activate() { conda-lock install --name $VIRTUALENV $LOCK_FILE source activate $VIRTUALENV - # TODO: Remove when array_api_compat ships a new release with latest changes - # install development feature of array_api_compat for testing purposes - python -m pip install git+https://github.com/data-apis/array-api-compat - elif [[ "$DISTRIB" == "ubuntu" || "$DISTRIB" == "debian-32" ]]; then python3 -m virtualenv --system-site-packages --python=python3 $VIRTUALENV source $VIRTUALENV/bin/activate diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index 915bb9ac8aa1f..ca6884ffd605e 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. 
# platform: linux-64 -# input_hash: 11b2d8c50a1abca3839d7665afc331d1b3d8f3088d7c5ff50a62cdffdd90b072 +# input_hash: ad4ff58ffa47068d7b421a5fa4734d67c8cf6cece6d6b0bd8595df15fb44b10c @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2022.12.7-ha878542_0.conda#ff9f73d45c4a07d6f424495288a26080 @@ -21,16 +21,13 @@ https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-12.2.0-h65d4601_19.tar https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.8-h166bdaf_0.tar.bz2#be733e69048951df1e4b4b7bb8c7666f https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2#d9c69a24ad678ffce24c6543a0176b00 https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2#a1fd65c7ccbf10880423d82bca54eb54 -https://conda.anaconda.org/conda-forge/linux-64/fftw-3.3.10-nompi_hf0379b8_106.conda#d7407e695358f068a2a7f8295cde0567 https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2#14947d8770185e5153fdd04d4673ed37 https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2#8c54672728e8ec6aa6db90cf2806d220 -https://conda.anaconda.org/conda-forge/linux-64/gstreamer-orc-0.4.33-h166bdaf_0.tar.bz2#879c93426c9d0b84a9de4513fbce5f4f -https://conda.anaconda.org/conda-forge/linux-64/icu-70.1-h27087fc_0.tar.bz2#87473a15119779e021c314249d4b4aed +https://conda.anaconda.org/conda-forge/linux-64/icu-72.1-hcb278e6_0.conda#7c8d20d847bb45f56bd941578fcfa146 https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.1-h166bdaf_0.tar.bz2#30186d27e2c9fa62b45fb1476b7200e3 https://conda.anaconda.org/conda-forge/linux-64/lame-3.100-h166bdaf_1003.tar.bz2#a8832b479f93521a9e7b5b743803be51 https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2#76bbff344f0134279f225174e9064c8f https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.0.9-h166bdaf_8.tar.bz2#9194c9bf9428035a05352d031462eae4 -https://conda.anaconda.org/conda-forge/linux-64/libdb-6.2.32-h9c3ff4c_0.tar.bz2#3f3258d8f841fbac63b36b75bdac1afd https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.18-h0b41bf4_0.conda#6aa9c9de5542ecb07fdda9ca626252d8 https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda#6305a3dd2752c76335295da4e581f2fd https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2#d645c6d2ac96843a2bfaccd2d62b3ac3 @@ -40,7 +37,6 @@ https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-2.1.5.1-h0b41bf4_0 https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.0-h7f98852_0.tar.bz2#39b1328babf85c7c3a61636d9cd50206 https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2#6e8cc2173440d77708196c5b93771680 https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2#15345e56d527b330e1cacbdf58676e8f -https://conda.anaconda.org/conda-forge/linux-64/libtool-2.4.7-h27087fc_0.conda#f204c8ba400ec475452737094fb81d52 https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda#40b61aab5c7ba9ff276c41cfffe6b80b https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.0-h0b41bf4_0.conda#0d4a7508d8c6c65314f2b9c1f56ad408 https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-h166bdaf_4.tar.bz2#f3f9de449d32ca9b9c66a22863c96f41 @@ -63,7 +59,6 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-xf86vidmodeproto-2.3.1-h7f9 
https://conda.anaconda.org/conda-forge/linux-64/xorg-xproto-7.0.31-h7f98852_1007.tar.bz2#b4a4381d54784606820704f7b5f05a15 https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2#2161070d867d1b1204ea749c8eec4ef0 https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9b5aca60558d02ddaa09d599e55920 -https://conda.anaconda.org/conda-forge/linux-64/jack-1.9.22-h11f4161_0.conda#504fa9e712b99494a9cf4630e3ca7d78 https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.0.9-h166bdaf_8.tar.bz2#4ae4d7795d33e02bd20f6b23d91caf82 https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.0.9-h166bdaf_8.tar.bz2#04bac51ba35ea023dc48af73c1c88c25 https://conda.anaconda.org/conda-forge/linux-64/libcap-2.67-he9d0100_0.conda#d05556c80caffff164d17bdea0105a1a @@ -76,7 +71,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-3.21.12-h3eb15da_0.c https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.40.0-h753d276_0.tar.bz2#2e5f9a37d487e1019fd4d8113adb2f9f https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.13-h7f98852_1004.tar.bz2#b3653fdc58d03face9724f602218a904 -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.10.3-hca2bb57_4.conda#bb808b654bdc3c783deaf107a2ffb503 +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.10.3-hfdac1af_6.conda#7eecaadc2eaeef464c5fe17702f17c86 https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.32-ha901b37_1.conda#2c18a7a26ec0d0c23a917f37a65fc9a2 https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2#69e2c796349cd9b273890bee0febfe1b https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4 @@ -91,24 +86,24 @@ https://conda.anaconda.org/conda-forge/linux-64/krb5-1.20.1-h81ceb04_0.conda#89a https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2#f967fc95089cd247ceed56eda31de3a9 https://conda.anaconda.org/conda-forge/linux-64/libglib-2.74.1-h606061b_1.tar.bz2#ed5349aa96776e00b34eccecf4a948fe https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.9.0-hd6dc26d_0.conda#ab9d052373c9376c0ebcff4dfef3d296 -https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hadd5161_1.conda#17d91085ccf5934ce652cb448d0cb65a +https://conda.anaconda.org/conda-forge/linux-64/libllvm16-16.0.0-hadd5161_1.conda#b04b132afca2c002ad2c214845808f47 https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.0-hb75c966_0.conda#c648d19cd9c8625898d5d370414de7c7 https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.5.0-ha587672_6.conda#4e5ee4b062c21519efbee7e2ae608748 -https://conda.anaconda.org/conda-forge/linux-64/libudev1-253-h0b41bf4_1.conda#bb38b19a41bb94e8a19dbfb062d499c7 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.5.0-h79f4944_1.conda#04a39cdd663f295653fc143851830563 https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-16.0.0-h417c0b6_0.conda#8ac4c157172ea816f5c9a0dc33df69d8 https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.32-hd7da12d_1.conda#a69fa6f218cfed8e2d61753eeacaf034 https://conda.anaconda.org/conda-forge/linux-64/nss-3.89-he45b914_0.conda#2745719a58eeaab6657256a3f142f099 -https://conda.anaconda.org/conda-forge/linux-64/python-3.11.1-h2755cc3_0_cpython.conda#2b276315a584c0237e384829ef95fae3 +https://conda.anaconda.org/conda-forge/linux-64/python-3.11.2-h2755cc3_0_cpython.conda#1895d5e5122832e59184dd5d18c7ea1d 
https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-h166bdaf_0.tar.bz2#384e7fcb3cd162ba3e4aed4b687df566 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h166bdaf_0.tar.bz2#637054603bb7594302e3bf83f0a99879 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-h166bdaf_0.tar.bz2#732e22f1741bccea861f5668cf7342a7 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h166bdaf_0.tar.bz2#0a8e20a8aef954390b9481a527421a8c https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.4-h0b41bf4_0.conda#ea8fbfeb976ac49cbeb594e985393514 +https://conda.anaconda.org/conda-forge/noarch/array-api-compat-1.2-pyhd8ed1ab_0.conda#3d34f2f6987f8d098ab00198c170a77e https://conda.anaconda.org/conda-forge/noarch/attrs-22.2.0-pyh71513ae_0.conda#8b76db7818a4e401ed4486c4c1635cd9 https://conda.anaconda.org/conda-forge/linux-64/brotli-1.0.9-h166bdaf_8.tar.bz2#2ff08978892a3e8b954397c461f18418 https://conda.anaconda.org/conda-forge/noarch/certifi-2022.12.7-pyhd8ed1ab_0.conda#fb9addc3db06e56abe03e0e9f21a63e6 -https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-2.1.1-pyhd8ed1ab_0.tar.bz2#c1d5b294fbf9a795dec349a6f4d8be8e +https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.1.0-pyhd8ed1ab_0.conda#7fcff9f6f123696e940bda77bd4d6551 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2#3faab06a954c2a04039983f2c4a50d99 https://conda.anaconda.org/conda-forge/noarch/cycler-0.11.0-pyhd8ed1ab_0.tar.bz2#a50559fad0affdbb33729a68669ca1cb https://conda.anaconda.org/conda-forge/linux-64/cython-0.29.33-py311hcafe171_0.conda#3e792927e2e16119f8e6910cca25a063 @@ -121,7 +116,7 @@ https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2#3427 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.4-py311h4dd048b_1.tar.bz2#46d451f575392c01dc193069bd89766d https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-haa2dc70_1.conda#980d8aca0bc23ca73fa8caa3e7c84c28 -https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h3e3d535_1.conda#a3a0f7a6f0885f5e1e0ec691566afb77 +https://conda.anaconda.org/conda-forge/linux-64/libclang13-16.0.0-default_h9b593c0_1.conda#1131f987052770bb669e70f95b89eebe https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h36d4200_3.conda#c9f4416a34bc91e0eb029f912c68f81f https://conda.anaconda.org/conda-forge/linux-64/libpq-15.2-hb675445_0.conda#4654b17eccaba55b8581d6b9c77f53cc https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-253-h8c4010b_1.conda#9176b1e2cb8beca37a7510b0e801e38f @@ -134,6 +129,7 @@ https://conda.anaconda.org/conda-forge/noarch/py-1.11.0-pyh6c4a22f_0.tar.bz2#b46 https://conda.anaconda.org/conda-forge/noarch/pycparser-2.21-pyhd8ed1ab_0.tar.bz2#076becd9e05608f8dc72757d5f3a91ff https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.0.9-pyhd8ed1ab_0.tar.bz2#e8fbc1b54b25f4b08281467bc13b70cc https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2#2a7de29fb590ca14b5243c4c812c8025 +https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda#2590495f608a63625e165915fb4e2e34 https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3-pyhd8ed1ab_0.conda#d3076b483092a435832603243567bc31 https://conda.anaconda.org/conda-forge/noarch/setuptools-67.6.1-pyhd8ed1ab_0.conda#6c443cccff3daa3d83b2b807b0a298ce 
https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2 @@ -146,45 +142,43 @@ https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.5.0-pyha770c72 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h166bdaf_0.tar.bz2#c9b568bd804cb2903c6be6f5f68182e4 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda#82b6df12252e6f32402b96dacc656fec https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.10-h7f98852_1003.tar.bz2#f59c1242cc1dd93e72c2ee2b360979eb -https://conda.anaconda.org/conda-forge/linux-64/cairo-1.16.0-ha61ee94_1014.tar.bz2#d1a88f3ed5b52e1024b80d4bcd26a7a0 +https://conda.anaconda.org/conda-forge/linux-64/cairo-1.16.0-h35add3b_1015.conda#0c944213e40c9e4aa32292776b9c6903 https://conda.anaconda.org/conda-forge/linux-64/cffi-1.15.1-py311h409f033_3.conda#9025d0786dbbe4bc91fd8e85502decce https://conda.anaconda.org/conda-forge/linux-64/coverage-7.2.2-py311h2582759_0.conda#41f2bca794aa1b4e70c1a23f8ec2dfa5 https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.39.3-py311h2582759_0.conda#55741f37ab19d949b8e7316cfe286824 https://conda.anaconda.org/conda-forge/linux-64/glib-2.74.1-h6239696_1.tar.bz2#f3220a9e9d3abcbfca43419a219df7e4 https://conda.anaconda.org/conda-forge/noarch/joblib-1.2.0-pyhd8ed1ab_0.tar.bz2#7583652522d71ad78ba536bba06940eb -https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_had23c3d_1.conda#36c65ed73b7c92589bd9562ef8a6023d +https://conda.anaconda.org/conda-forge/linux-64/libclang-16.0.0-default_h62803fd_1.conda#e8bf260afdc0fb97f7a9586ee1a05ca7 https://conda.anaconda.org/conda-forge/linux-64/mkl-2022.2.1-h84fe81f_16997.conda#a7ce56d5757f5b57e7daabe703ade5bb https://conda.anaconda.org/conda-forge/linux-64/pillow-9.4.0-py311h573f0d3_2.conda#7321881b545202cf9ab8bd24b4151dcb https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-h5195f5e_3.conda#caeb3302ef1dc8b342b20c710a86f8a9 https://conda.anaconda.org/conda-forge/noarch/pytest-7.2.2-pyhd8ed1ab_0.conda#60958b19354e0ec295b43f6ab5cfab86 https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984 -https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.7-py311hcafe171_0.conda#cf1adb3a0138cca4bbab415ddf2f57f1 +https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.7-py311hcafe171_1.conda#981cd2df8da43935d615bc48be964698 https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.5.0-hd8ed1ab_0.conda#b3c594fde1a80a1fc3eb9cc4a5dfe392 https://conda.anaconda.org/conda-forge/linux-64/blas-1.0-mkl.tar.bz2#349aef876b1d8c9dccae01de20d5b385 https://conda.anaconda.org/conda-forge/linux-64/brotlipy-0.7.0-py311hd4cff14_1005.tar.bz2#9bdac7084ecfc08338bae1b976535724 https://conda.anaconda.org/conda-forge/linux-64/cryptography-40.0.1-py311h9b4c7bb_0.conda#7638d1e31ce0a029211e12fc719e52e6 https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.0-h25f0c4b_2.conda#461541cb1b387c2a28ab6217f3d38502 -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-6.0.0-h8e241bc_0.conda#448fe40d2fed88ccf4d9ded37cbb2b38 +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-6.0.0-h3ff4399_1.conda#73d2c2d25fdcec40c24929bab9f44831 https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-16_linux64_mkl.tar.bz2#85f61af03fd291dae33150ffe89dc09a https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.2.0-pyhd8ed1ab_0.conda#f10c2cf447ca96f12a326b83c75b8e33 
-https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-daemon-16.1-ha8d29e2_3.conda#34d9d75ca896f5919c372a34e25f23ea https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.11.0-py311hcafe171_3.conda#0d79df2a96f6572fed2883374400b235 https://conda.anaconda.org/conda-forge/noarch/pytest-cov-4.0.0-pyhd8ed1ab_0.tar.bz2#c9e3f8bfdb9bfc34aa1836a6ed4b25d7 https://conda.anaconda.org/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_0.conda#a46947638b6e005b63d2d6271da529b0 https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.0-h4243ec0_2.conda#0d0c6604c8ac4ad5e51efa7bb58da05c https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-16_linux64_mkl.tar.bz2#361bf757b95488de76c4f123805742d3 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-16_linux64_mkl.tar.bz2#a2f166748917d6d6e4707841ca1f519e -https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-16.1-hcb278e6_3.conda#8b452ab959166d91949af4c2d28f81db https://conda.anaconda.org/conda-forge/noarch/pyopenssl-23.1.1-pyhd8ed1ab_0.conda#0b34aa3ab7e7ccb1765a03dd9ed29938 https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-2.5.0-pyhd8ed1ab_0.tar.bz2#1fdd1f3baccf0deb647385c677a1a48e https://conda.anaconda.org/conda-forge/linux-64/numpy-1.24.2-py311h8e6699e_0.conda#90db8cc0dfa20853329bfc6642f887aa -https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h67dfc38_7.conda#f140acdc15a0196eef664f120c396266 +https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5c52f38_9.conda#a80700497d2f343a8448778308aa5dfa https://conda.anaconda.org/conda-forge/noarch/urllib3-1.26.15-pyhd8ed1ab_0.conda#27db656619a55d727eaf5a6ece3d2fd6 https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.0.7-py311ha3edf6b_0.conda#e7548e7f58965a2fe97a95950a5fedc6 -https://conda.anaconda.org/conda-forge/linux-64/pandas-1.5.3-py311h2872171_1.conda#6bb03bf6d4fab68174eae8b06c3b6934 +https://conda.anaconda.org/conda-forge/linux-64/pandas-2.0.0-py311h2872171_0.conda#f987f61faa256eace0d74ca491ab88c7 https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.7-py311ha74522f_3.conda#ad6dd0bed0cdf5f2d4eb2b989d6253b3 https://conda.anaconda.org/conda-forge/linux-64/pytorch-1.13.1-cpu_py311h410fd25_1.conda#ddd2fadddf89e3dc3d541a2537fce010 -https://conda.anaconda.org/conda-forge/noarch/requests-2.28.2-pyhd8ed1ab_0.conda#11d178fc55199482ee48d6812ea83983 +https://conda.anaconda.org/conda-forge/noarch/requests-2.28.2-pyhd8ed1ab_1.conda#3bfbd6ead1d7299ed46dab3a7bf0bc8c https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.7.1-py311h8597a09_0.conda#70c3b734ffe82c16b6d121aaa11929a8 https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyha770c72_3.conda#5936894aade8240c867d292aa0d980c6 https://conda.anaconda.org/conda-forge/linux-64/pytorch-cpu-1.13.1-cpu_py311hdb170b5_1.conda#a805d5f103e493f207613283d8acbbe1 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml index 823c8d1734d1e..7c59154d4cac1 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml @@ -22,3 +22,4 @@ dependencies: - ccache - pytorch=1.13 - pytorch-cpu + - array-api-compat diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index 8a229383d9638..f63041407cb80 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -89,7 +89,7 @@ def 
remove_from(alist, to_remove): "platform": "linux-64", "channel": "conda-forge", "conda_dependencies": common_dependencies - + ["ccache", "pytorch", "pytorch-cpu"], + + ["ccache", "pytorch", "pytorch-cpu", "array-api-compat"], "package_constraints": { "blas": "[build=mkl]", "pytorch": "1.13", From 14400abc3b15884dfde500671d1795fccffcd43e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 5 Apr 2023 09:41:26 -0400 Subject: [PATCH 36/42] TST Ignore coverage for GPU --- sklearn/utils/_array_api.py | 7 ++----- sklearn/utils/tests/test_array_api.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index d7bf8448f7020..9724cf413f4d1 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -383,11 +383,8 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None): return xp.asarray(array, dtype=dtype, copy=copy) -def _convert_to_numpy(array, xp=None): +def _convert_to_numpy(array, xp): """Convert X into a NumPy ndarray on the CPU.""" - if xp is None: - xp, _ = get_namespace(array) - xp_name = xp.__name__ if xp_name in {"array_api_compat.torch", "torch"}: @@ -395,7 +392,7 @@ def _convert_to_numpy(array, xp=None): elif xp_name == "cupy.array_api": return array._array.get() elif xp_name in {"array_api_compat.cupy", "cupy"}: - return array.get() + return array.get() # pragma: nocover return numpy.asarray(array) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index cb3d880d29ada..77fa20e6d0b58 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -160,7 +160,7 @@ def test_asarray_with_order_ignored(): @skip_if_array_api_compat_not_configured @pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"]) -def test_convert_to_numpy_gpu(library): +def test_convert_to_numpy_gpu(library): # pragma: nocover """Check convert_to_numpy for GPU backed libraries.""" xp = pytest.importorskip(library) From 1ee18bf0321c0c02673a87c7f14f7c2320342b3c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 5 Apr 2023 09:51:29 -0400 Subject: [PATCH 37/42] TST Add no cover for cupy branch --- sklearn/utils/_array_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 9724cf413f4d1..29fb14dd802da 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -391,8 +391,8 @@ def _convert_to_numpy(array, xp): return array.cpu().numpy() elif xp_name == "cupy.array_api": return array._array.get() - elif xp_name in {"array_api_compat.cupy", "cupy"}: - return array.get() # pragma: nocover + elif xp_name in {"array_api_compat.cupy", "cupy"}: # pragma: nocover + return array.get() return numpy.asarray(array) From e320a327eaa06a77b9fcbfc499d2bb8a14e3a058 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 6 Apr 2023 12:44:50 -0400 Subject: [PATCH 38/42] STY Update black version --- build_tools/update_environments_and_lock_files.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index 600f48d45b11c..132e0ee431dfc 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -88,8 +88,12 @@ def remove_from(alist, to_remove): "folder": "build_tools/azure", "platform": "linux-64", "channel": "conda-forge", - "conda_dependencies": common_dependencies - + ["ccache", "pytorch", "pytorch-cpu", "array-api-compat"], + "conda_dependencies": common_dependencies + [ + "ccache", + "pytorch", + "pytorch-cpu", + "array-api-compat", + ], "package_constraints": { "blas": "[build=mkl]", "pytorch": "1.13", From ce1191af06070e97b6eece0f55523d00aed66880 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 11 Apr 2023 11:12:58 -0400 Subject: [PATCH 39/42] ENH Use DLPack instead --- sklearn/utils/_array_api.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 29fb14dd802da..ba3c443da239f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -400,6 +400,9 @@ def _convert_to_numpy(array, xp): def _estimator_with_converted_arrays(estimator, converter): """Create new estimator which converting all attributes that are arrays. + The converted is called on all NumPy arrays and arrays that support the + `DLPack interface `__. + Parameters ---------- estimator : Estimator @@ -417,13 +420,7 @@ def _estimator_with_converted_arrays(estimator, converter): new_estimator = clone(estimator) for key, attribute in vars(estimator).items(): - with config_context(array_api_dispatch=True): - try: - _, is_array_api_compliant = get_namespace(attribute) - except TypeError: - is_array_api_compliant = False - - if is_array_api_compliant or isinstance(attribute, numpy.ndarray): + if hasattr(attribute, "__dlpack__") or isinstance(attribute, numpy.ndarray): attribute = converter(attribute) setattr(new_estimator, key, attribute) return new_estimator From 128411b27fe291ce9f2b8cb939af7b5c73a6aa3d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 12 Apr 2023 09:21:48 -0400 Subject: [PATCH 40/42] STY Remove unneeded import --- sklearn/utils/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index ba3c443da239f..17eec91a80a04 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -5,7 +5,7 @@ import numpy import scipy.special as special -from .._config import get_config, config_context +from .._config import get_config from .fixes import parse_version From cb24ffd4cb1624375d97fc798c0a29905bc9451d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 20 Apr 2023 09:07:35 -0400 Subject: [PATCH 41/42] Apply suggestions from code review Co-authored-by: Tim Head --- sklearn/utils/_array_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 17eec91a80a04..27d342b8c1ab2 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -326,7 +326,7 @@ def get_namespace(*arrays): the namespace defaults to NumPy. is_array_api_compliant : bool - True of the arrays are containers that implement the Array API spec. 
+ True if the arrays are containers that implement the Array API spec. Always False when array_api_dispatch=False. """ array_api_dispatch = get_config()["array_api_dispatch"] @@ -400,7 +400,7 @@ def _convert_to_numpy(array, xp): def _estimator_with_converted_arrays(estimator, converter): """Create new estimator which converting all attributes that are arrays. - The converted is called on all NumPy arrays and arrays that support the + The converter is called on all NumPy arrays and arrays that support the `DLPack interface `__. Parameters From 878da33c1676a50a55d3a26fe3b204cd526a6966 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 20 Apr 2023 09:11:30 -0400 Subject: [PATCH 42/42] DOC Comment about xp.asarray --- sklearn/utils/_array_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 27d342b8c1ab2..13ab96b866fc6 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -378,6 +378,9 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None): array = numpy.array(array, order=order, dtype=dtype) else: array = numpy.asarray(array, order=order, dtype=dtype) + + # At this point array is a NumPy ndarray. We convert it to an array + # container that is consistent with the input's namespace. return xp.asarray(array) else: return xp.asarray(array, dtype=dtype, copy=copy)
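The net effect of patches 36 through 42 on the private array API helpers is easiest to see end to end. The following is a minimal sketch, not part of the patch series itself: it assumes a CPU build of PyTorch (the CI lock files above pin 1.13) and a scikit-learn checkout with these patches applied, and it only uses helpers that appear in the diffs above (`_convert_to_numpy`, `get_namespace`, and the DLPack-based `_estimator_with_converted_arrays`)::

    import numpy
    import torch

    from sklearn import config_context
    from sklearn.datasets import make_classification
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.utils._array_api import (
        _convert_to_numpy,
        _estimator_with_converted_arrays,
        get_namespace,
    )

    X_np, y_np = make_classification(random_state=0)
    X_torch = torch.asarray(X_np, dtype=torch.float32)
    y_torch = torch.asarray(y_np, dtype=torch.float32)

    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis().fit(X_torch, y_torch)

        # get_namespace resolves to the array-api-compat wrapper for torch,
        # so _convert_to_numpy (which now requires xp, per patch 36) takes
        # the array.cpu().numpy() branch.
        xp, _ = get_namespace(X_torch)
        X_back = _convert_to_numpy(X_torch, xp)
    assert isinstance(X_back, numpy.ndarray)

    # Fitted attributes live in the input namespace; torch.Tensor implements
    # __dlpack__, which is the check patch 39 uses to select attributes for
    # conversion.
    assert hasattr(lda.scalings_, "__dlpack__")

    def to_ndarray(array):
        # Per the docstring, the converter is also called on plain NumPy
        # arrays, so handle both cases.
        if isinstance(array, torch.Tensor):
            return array.cpu().numpy()
        return numpy.asarray(array)

    lda_np = _estimator_with_converted_arrays(lda, to_ndarray)
    assert isinstance(lda_np.scalings_, numpy.ndarray)

Detecting convertible attributes via `__dlpack__` (patch 39) avoids the `config_context`/`get_namespace` round trip of the earlier revision, which is also why patch 40 can drop the `config_context` import, and it keeps the helper usable even when `array_api_dispatch` is disabled.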