
Add array API support for Nystroem approximation #29661

Draft · wants to merge 5 commits into base: main
Conversation

EmilyXinyi (Contributor)

What does this implement/fix? Explain your changes.

Make Nystroem approximation array API compatible
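
For context, array API dispatch is enabled through scikit-learn's documented config API. The snippet below is a user-level sketch of what this PR aims to make work (a torch input is assumed, and array-api-compat must be installed); it is not code from the PR:

    import sklearn
    import torch
    from sklearn.kernel_approximation import Nystroem

    X = torch.rand(100, 10)
    with sklearn.config_context(array_api_dispatch=True):
        # Once this PR lands, fit_transform should stay in the torch namespace.
        Xt = Nystroem(n_components=20, random_state=0).fit_transform(X)
    print(type(Xt))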

To Do

  • Make Nystroem array API compatible
  • Test for failures (and fix them, if any) with both array_api_dispatch=True and array_api_dispatch=False
  • Test for performance differences on GPU with array API dispatch turned on vs. off (see the timing sketch below)
  • Add changelog
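
As referenced in the third item, a rough timing sketch for the CPU case (data sizes and hyper-parameters are assumptions; on GPU one would pass a torch tensor on a CUDA device instead of a NumPy array):

    import time
    import numpy as np
    import sklearn
    from sklearn.kernel_approximation import Nystroem

    X = np.random.default_rng(0).standard_normal((20_000, 100), dtype=np.float32)

    for dispatch in (False, True):
        with sklearn.config_context(array_api_dispatch=dispatch):
            tic = time.perf_counter()
            Nystroem(n_components=500, random_state=0).fit_transform(X)
        print(f"array_api_dispatch={dispatch}: {time.perf_counter() - tic:.2f}s")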

github-actions bot commented Aug 12, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 8668a06.

@EmilyXinyi (Contributor, Author)

While working on this I kept running into the error ValueError: Expected 2D array, got 1D array instead with rbf_kernel. I ran the corresponding tests in sklearn/metrics/tests/test_common.py and got the same error; I have attached the error messages below. From my investigation, I believe the cause is that xp.asarray implicitly flattens the array when xp is the torch namespace.

@OmarManzoor do you happen to have any pointers on how you got around this when you worked on #29433? Thanks!
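
For what it's worth, a minimal probe of the suspected behavior (assuming array-api-compat is installed); if the shapes below differ, the conversion is indeed the culprit:

    import numpy as np
    import array_api_compat.torch as xp

    X_np = np.asarray([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)
    X_xp = xp.asarray(X_np)
    print(X_np.shape, X_xp.shape)  # expected (2, 3) for both; the report suggests (6,) for torch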

metric = <function rbf_kernel at 0x125b99b20>, array_namespace = 'torch', device = 'cpu', dtype_name = 'float32', check_func = <function check_array_api_metric_pairwise at 0x125c4ad40>

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
    def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
>       check_func(metric, array_namespace, device, dtype_name)

array_namespace = 'torch'
check_func = <function check_array_api_metric_pairwise at 0x125c4ad40>
device     = 'cpu'
dtype_name = 'float32'
metric     = <function rbf_kernel at 0x125b99b20>

sklearn/metrics/tests/test_common.py:2048: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
sklearn/metrics/tests/test_common.py:1979: in check_array_api_metric_pairwise
    check_array_api_metric(
        X_np       = array([[0.1, 0.2, 0.3],
       [0.4, 0.5, 0.6]], dtype=float32)
        Y_np       = array([[0.2, 0.3, 0.4],
       [0.5, 0.6, 0.7]], dtype=float32)
        array_namespace = 'torch'
        device     = 'cpu'
        dtype_name = 'float32'
        metric     = <function rbf_kernel at 0x125b99b20>
        metric_kwargs = {}
sklearn/metrics/tests/test_common.py:1773: in check_array_api_metric
    metric_xp = metric(a_xp, b_xp, **metric_kwargs)
        a_np       = array([[0.1, 0.2, 0.3],
       [0.4, 0.5, 0.6]], dtype=float32)
        a_xp       = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000])
        array_namespace = 'torch'
        b_np       = array([[0.2, 0.3, 0.4],
       [0.5, 0.6, 0.7]], dtype=float32)
        b_xp       = tensor([0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000])
        device     = 'cpu'
        dtype_name = 'float32'
        metric     = <function rbf_kernel at 0x125b99b20>
        metric_kwargs = {}
        metric_np  = array([[0.9900499, 0.8521437],
       [0.9607895, 0.9900499]], dtype=float32)
        multioutput = None
        xp         = <module 'array_api_compat.torch' from '/Users/emilychen/miniforge3/envs/sklearn-dev/lib/python3.12/site-packages/array_api_compat/torch/__init__.py'>
sklearn/utils/_param_validation.py:216: in wrapper
    return func(*args, **kwargs)
        args       = (tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000]), tensor([0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000]))
        func       = <function rbf_kernel at 0x125b99a80>
        func_sig   = <Signature (X, Y=None, gamma=None)>
        global_skip_validation = False
        kwargs     = {}
        parameter_constraints = {'X': ['array-like', 'sparse matrix'], 'Y': ['array-like', 'sparse matrix', None], 'gamma': [<sklearn.utils._param_validation.Interval object at 0x125b42330>, None, <sklearn.utils._param_validation.Hidden object at 0x125b42360>]}
        params     = {'X': tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000]), 'Y': tensor([0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000]), 'gamma': None}
        prefer_skip_nested_validation = True
        to_ignore  = ['self', 'cls']
sklearn/metrics/pairwise.py:1598: in rbf_kernel
    X, Y = check_pairwise_arrays(X, Y)
        X          = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000])
        Y          = tensor([0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000])
        _          = True
        gamma      = None
        xp         = <module 'array_api_compat.torch' from '/Users/emilychen/miniforge3/envs/sklearn-dev/lib/python3.12/site-packages/array_api_compat/torch/__init__.py'>
sklearn/metrics/pairwise.py:205: in check_pairwise_arrays
    X = check_array(
        X          = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000])
        Y          = tensor([0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000])
        _          = True
        accept_sparse = 'csr'
        copy       = False
        dtype      = torch.float32
        dtype_float = torch.float32
        ensure_2d  = True
        ensure_all_finite = True
        estimator  = 'check_pairwise_arrays'
        force_all_finite = 'deprecated'
        precomputed = False
        xp         = <module 'array_api_compat.torch' from '/Users/emilychen/miniforge3/envs/sklearn-dev/lib/python3.12/site-packages/array_api_compat/torch/__init__.py'>
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

array = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000]), accept_sparse = 'csr'

    def check_array(
        array,
        accept_sparse=False,
        *,
        accept_large_sparse=True,
        dtype="numeric",
        order=None,
        copy=False,
        force_writeable=False,
        force_all_finite="deprecated",
        ensure_all_finite=None,
        ensure_non_negative=False,
        ensure_2d=True,
        allow_nd=False,
        ensure_min_samples=1,
        ensure_min_features=1,
        estimator=None,
        input_name="",
    ):
        """Input validation on an array, list, sparse matrix or similar.
    
        By default, the input is checked to be a non-empty 2D array containing
        only finite values. If the dtype of the array is object, attempt
        converting to float, raising on failure.
    
        Parameters
        ----------
        array : object
            Input object to check / convert.
    
        accept_sparse : str, bool or list/tuple of str, default=False
            String[s] representing allowed sparse matrix formats, such as 'csc',
            'csr', etc. If the input is sparse but not in the allowed format,
            it will be converted to the first listed format. True allows the input
            to be any format. False means that a sparse matrix input will
            raise an error.
    
        accept_large_sparse : bool, default=True
            If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by
            accept_sparse, accept_large_sparse=False will cause it to be accepted
            only if its indices are stored with a 32-bit dtype.
    
            .. versionadded:: 0.20
    
        dtype : 'numeric', type, list of type or None, default='numeric'
            Data type of result. If None, the dtype of the input is preserved.
            If "numeric", dtype is preserved unless array.dtype is object.
            If dtype is a list of types, conversion on the first type is only
            performed if the dtype of the input is not in the list.
    
        order : {'F', 'C'} or None, default=None
            Whether an array will be forced to be fortran or c-style.
            When order is None (default), then if copy=False, nothing is ensured
            about the memory layout of the output array; otherwise (copy=True)
            the memory layout of the returned array is kept as close as possible
            to the original array.
    
        copy : bool, default=False
            Whether a forced copy will be triggered. If copy=False, a copy might
            be triggered by a conversion.
    
        force_writeable : bool, default=False
            Whether to force the output array to be writeable. If True, the returned array
            is guaranteed to be writeable, which may require a copy. Otherwise the
            writeability of the input array is preserved.
    
            .. versionadded:: 1.6
    
        force_all_finite : bool or 'allow-nan', default=True
            Whether to raise an error on np.inf, np.nan, pd.NA in array. The
            possibilities are:
    
            - True: Force all values of array to be finite.
            - False: accepts np.inf, np.nan, pd.NA in array.
            - 'allow-nan': accepts only np.nan and pd.NA values in array. Values
              cannot be infinite.
    
            .. versionadded:: 0.20
               ``force_all_finite`` accepts the string ``'allow-nan'``.
    
            .. versionchanged:: 0.23
               Accepts `pd.NA` and converts it into `np.nan`
    
            .. deprecated:: 1.6
               `force_all_finite` was renamed to `ensure_all_finite` and will be removed
               in 1.8.
    
        ensure_all_finite : bool or 'allow-nan', default=True
            Whether to raise an error on np.inf, np.nan, pd.NA in array. The
            possibilities are:
    
            - True: Force all values of array to be finite.
            - False: accepts np.inf, np.nan, pd.NA in array.
            - 'allow-nan': accepts only np.nan and pd.NA values in array. Values
              cannot be infinite.
    
            .. versionadded:: 1.6
               `force_all_finite` was renamed to `ensure_all_finite`.
    
        ensure_non_negative : bool, default=False
            Make sure the array has only non-negative values. If True, an array that
            contains negative values will raise a ValueError.
    
            .. versionadded:: 1.6
    
        ensure_2d : bool, default=True
            Whether to raise a value error if array is not 2D.
    
        allow_nd : bool, default=False
            Whether to allow array.ndim > 2.
    
        ensure_min_samples : int, default=1
            Make sure that the array has a minimum number of samples in its first
            axis (rows for a 2D array). Setting to 0 disables this check.
    
        ensure_min_features : int, default=1
            Make sure that the 2D array has some minimum number of features
            (columns). The default value of 1 rejects empty datasets.
            This check is only enforced when the input data has effectively 2
            dimensions or is originally 1D and ``ensure_2d`` is True. Setting to 0
            disables this check.
    
        estimator : str or estimator instance, default=None
            If passed, include the name of the estimator in warning messages.
    
        input_name : str, default=""
            The data name used to construct the error message. In particular
            if `input_name` is "X" and the data has NaN values and
            allow_nan is False, the error message will link to the imputer
            documentation.
    
            .. versionadded:: 1.1.0
    
        Returns
        -------
        array_converted : object
            The converted and validated array.
    
        Examples
        --------
        >>> from sklearn.utils.validation import check_array
        >>> X = [[1, 2, 3], [4, 5, 6]]
        >>> X_checked = check_array(X)
        >>> X_checked
        array([[1, 2, 3], [4, 5, 6]])
        """
        ensure_all_finite = _deprecate_force_all_finite(force_all_finite, ensure_all_finite)
    
        if isinstance(array, np.matrix):
            raise TypeError(
                "np.matrix is not supported. Please convert to a numpy array with "
                "np.asarray. For more information see: "
                "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html"
            )
    
        xp, is_array_api_compliant = get_namespace(array)
    
        # store reference to original array to check if copy is needed when
        # function returns
        array_orig = array
    
        # store whether originally we wanted numeric dtype
        dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
    
        dtype_orig = getattr(array, "dtype", None)
        if not is_array_api_compliant and not hasattr(dtype_orig, "kind"):
            # not a data type (e.g. a column named dtype in a pandas DataFrame)
            dtype_orig = None
    
        # check if the object contains several dtypes (typically a pandas
        # DataFrame), and store them. If not, store None.
        dtypes_orig = None
        pandas_requires_conversion = False
        # track if we have a Series-like object to raise a better error message
        type_if_series = None
        if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"):
            # throw warning if columns are sparse. If all columns are sparse, then
            # array.sparse exists and sparsity will be preserved (later).
            with suppress(ImportError):
                from pandas import SparseDtype
    
                def is_sparse(dtype):
                    return isinstance(dtype, SparseDtype)
    
                if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
                    warnings.warn(
                        "pandas.DataFrame with sparse columns found."
                        "It will be converted to a dense numpy array."
                    )
    
            dtypes_orig = list(array.dtypes)
            pandas_requires_conversion = any(
                _pandas_dtype_needs_early_conversion(i) for i in dtypes_orig
            )
            if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
                dtype_orig = np.result_type(*dtypes_orig)
            elif pandas_requires_conversion and any(d == object for d in dtypes_orig):
                # Force object if any of the dtypes is an object
                dtype_orig = object
    
        elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
            array, "dtype"
        ):
            # array is a pandas series
            type_if_series = type(array)
            pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
            if isinstance(array.dtype, np.dtype):
                dtype_orig = array.dtype
            else:
                # Set to None to let array.astype work out the best dtype
                dtype_orig = None
    
        if dtype_numeric:
            if (
                dtype_orig is not None
                and hasattr(dtype_orig, "kind")
                and dtype_orig.kind == "O"
            ):
                # if input is object, convert to float.
                dtype = xp.float64
            else:
                dtype = None
    
        if isinstance(dtype, (list, tuple)):
            if dtype_orig is not None and dtype_orig in dtype:
                # no dtype conversion required
                dtype = None
            else:
                # dtype conversion required. Let's select the first element of the
                # list of accepted types.
                dtype = dtype[0]
    
        if pandas_requires_conversion:
            # pandas dataframe requires conversion earlier to handle extension dtypes with
            # nans
            # Use the original dtype for conversion if dtype is None
            new_dtype = dtype_orig if dtype is None else dtype
            array = array.astype(new_dtype)
            # Since we converted here, we do not need to convert again later
            dtype = None
    
        if ensure_all_finite not in (True, False, "allow-nan"):
            raise ValueError(
                "ensure_all_finite should be a bool or 'allow-nan'. Got "
                f"{ensure_all_finite!r} instead."
            )
    
        if dtype is not None and _is_numpy_namespace(xp):
            # convert to dtype object to conform to Array API to be use `xp.isdtype` later
            dtype = np.dtype(dtype)
    
        estimator_name = _check_estimator_name(estimator)
        context = " by %s" % estimator_name if estimator is not None else ""
    
        # When all dataframe columns are sparse, convert to a sparse array
        if hasattr(array, "sparse") and array.ndim > 1:
            with suppress(ImportError):
                from pandas import SparseDtype  # noqa: F811
    
                def is_sparse(dtype):
                    return isinstance(dtype, SparseDtype)
    
                if array.dtypes.apply(is_sparse).all():
                    # DataFrame.sparse only supports `to_coo`
                    array = array.sparse.to_coo()
                    if array.dtype == np.dtype("object"):
                        unique_dtypes = set([dt.subtype.name for dt in array_orig.dtypes])
                        if len(unique_dtypes) > 1:
                            raise ValueError(
                                "Pandas DataFrame with mixed sparse extension arrays "
                                "generated a sparse matrix with object dtype which "
                                "can not be converted to a scipy sparse matrix."
                                "Sparse extension arrays should all have the same "
                                "numeric type."
                            )
    
        if sp.issparse(array):
            _ensure_no_complex_data(array)
            array = _ensure_sparse_format(
                array,
                accept_sparse=accept_sparse,
                dtype=dtype,
                copy=copy,
                ensure_all_finite=ensure_all_finite,
                accept_large_sparse=accept_large_sparse,
                estimator_name=estimator_name,
                input_name=input_name,
            )
            if ensure_2d and array.ndim < 2:
                raise ValueError(
                    f"Expected 2D input, got input with shape {array.shape}.\n"
                    "Reshape your data either using array.reshape(-1, 1) if "
                    "your data has a single feature or array.reshape(1, -1) "
                    "if it contains a single sample."
                )
        else:
            # If np.array(..) gives ComplexWarning, then we convert the warning
            # to an error. This is needed because specifying a non complex
            # dtype to the function converts complex to real dtype,
            # thereby passing the test made in the lines following the scope
            # of warnings context manager.
            with warnings.catch_warnings():
                try:
                    warnings.simplefilter("error", ComplexWarning)
                    if dtype is not None and xp.isdtype(dtype, "integral"):
                        # Conversion float -> int should not contain NaN or
                        # inf (numpy#14412). We cannot use casting='safe' because
                        # then conversion float -> int would be disallowed.
                        array = _asarray_with_order(array, order=order, xp=xp)
                        if xp.isdtype(array.dtype, ("real floating", "complex floating")):
                            _assert_all_finite(
                                array,
                                allow_nan=False,
                                msg_dtype=dtype,
                                estimator_name=estimator_name,
                                input_name=input_name,
                            )
                        array = xp.astype(array, dtype, copy=False)
                    else:
                        array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp)
                except ComplexWarning as complex_warning:
                    raise ValueError(
                        "Complex data not supported\n{}\n".format(array)
                    ) from complex_warning
    
            # It is possible that the np.array(..) gave no warning. This happens
            # when no dtype conversion happened, for example dtype = None. The
            # result is that np.array(..) produces an array of complex dtype
            # and we need to catch and raise exception for such cases.
            _ensure_no_complex_data(array)
    
            if ensure_2d:
                # If input is scalar raise error
                if array.ndim == 0:
                    raise ValueError(
                        "Expected 2D array, got scalar array instead:\narray={}.\n"
                        "Reshape your data either using array.reshape(-1, 1) if "
                        "your data has a single feature or array.reshape(1, -1) "
                        "if it contains a single sample.".format(array)
                    )
                # If input is 1D raise error
                if array.ndim == 1:
                    # If input is a Series-like object (eg. pandas Series or polars Series)
                    if type_if_series is not None:
                        msg = (
                            f"Expected a 2-dimensional container but got {type_if_series} "
                            "instead. Pass a DataFrame containing a single row (i.e. "
                            "single sample) or a single column (i.e. single feature) "
                            "instead."
                        )
                    else:
                        msg = (
                            f"Expected 2D array, got 1D array instead:\narray={array}.\n"
                            "Reshape your data either using array.reshape(-1, 1) if "
                            "your data has a single feature or array.reshape(1, -1) "
                            "if it contains a single sample."
                        )
>                   raise ValueError(msg)
E                   ValueError: Expected 2D array, got 1D array instead:
E                   array=tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000]).
E                   Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

accept_large_sparse = True
accept_sparse = 'csr'
allow_nd   = False
array      = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000])
array_orig = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000])
context    = ' by check_pairwise_arrays'
copy       = False
dtype      = torch.float32
dtype_numeric = False
dtype_orig = torch.float32
dtypes_orig = None
ensure_2d  = True
ensure_all_finite = True
ensure_min_features = 1
ensure_min_samples = 1
ensure_non_negative = False
estimator  = 'check_pairwise_arrays'
estimator_name = 'check_pairwise_arrays'
force_all_finite = 'deprecated'
force_writeable = False
input_name = ''
is_array_api_compliant = True
msg        = 'Expected 2D array, got 1D array instead:\narray=tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000]).\nReshape yo...r using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.'
order      = None
pandas_requires_conversion = False
type_if_series = None
xp         = <module 'array_api_compat.torch' from '/Users/emilychen/miniforge3/envs/sklearn-dev/lib/python3.12/site-packages/array_api_compat/torch/__init__.py'>

sklearn/utils/validation.py:1090: ValueError

@ogrisel (Member) commented Aug 14, 2024

There is something fishy that happens right after check_array_api_metric_pairwise calls check_array_api_metric:

sklearn/metrics/tests/test_common.py:1979: in check_array_api_metric_pairwise
    check_array_api_metric(
        X_np       = array([[0.1, 0.2, 0.3],
       [0.4, 0.5, 0.6]], dtype=float32)
        Y_np       = array([[0.2, 0.3, 0.4],
       [0.5, 0.6, 0.7]], dtype=float32)
        array_namespace = 'torch'
        device     = 'cpu'
        dtype_name = 'float32'
        metric     = <function rbf_kernel at 0x125b99b20>
        metric_kwargs = {}
sklearn/metrics/tests/test_common.py:1773: in check_array_api_metric
    metric_xp = metric(a_xp, b_xp, **metric_kwargs)
        a_np       = array([[0.1, 0.2, 0.3],
       [0.4, 0.5, 0.6]], dtype=float32)
        a_xp       = tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000])
        array_namespace = 'torch'
        b_np       = array([[0.2, 0.3, 0.4],
       [0.5, 0.6, 0.7]], dtype=float32)
        b_xp       = tensor([0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000])
  • the original X_np and Y_np are 2D arrays as expected;
  • they are then renamed, but unchanged, as a_np and b_np in check_array_api_metric;
  • but their namespace-specific counterparts a_xp and b_xp are flattened, which is not expected.

@OmarManzoor (Contributor)

Hi @EmilyXinyi

I checked out your branch locally and ran the sklearn.metrics.tests.test_common.test_array_api_compliance tests, and they all seem to pass. What version of PyTorch are you using?

Also @ogrisel, I think it might be useful to merge PR #29475, as the current PR makes use of the other kernels. What do you think?

@ogrisel (Member) commented Aug 19, 2024

Let me update the branch in this PR to check if the errors in the CI can be reproduced or if they were a transient event.

@ogrisel (Member) left a review:

Some feedback. In general, I think you should have a look at a merged PR that adds array API support to another unsupervised transformer (e.g. PCA):

#26315

This should give guidance on how to upgrade this estimator and test it.

if X.ndim < 2:
    X = xp.reshape(X, (1, -1))
if y is not None:
    y = xp.asarray(y, device=device(X))
Review comment (Member):

Why change y? It is not used by this transformer (this is an unsupervised estimator), so there is no need to change it at all.

@@ -991,6 +991,11 @@ def fit(self, X, y=None):
        self : object
            Returns the instance itself.
        """
        xp, _ = get_namespace(X)
        if X.ndim < 2:
Review comment (Member):

We should not access X.ndim before the call to self._validate_data(X, accept_sparse="csr") a few lines below.

Furthermore, self._validate_data itself is in charge of checking the shape of X by passing appropriate values to **check_params if needed.

However, all scikit-learn estimators are expected to raise a standardized error message when input data is 1D. self._validate_data does this automatically by default, so there is no need to do it manually.
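
For illustration, a minimal sketch of the suggested shape-handling (the method body here is an assumption, not the PR's diff):

    def fit(self, X, y=None):
        # _validate_data runs the standardized checks, including raising the
        # canonical "Expected 2D array, got 1D array instead" error for 1D input.
        X = self._validate_data(X, accept_sparse="csr")
        ...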

S = np.maximum(S, 1e-12)
self.normalization_ = np.dot(U / np.sqrt(S), V)
dtype = _find_matching_floating_dtype(basis_kernel, xp=xp)
basis_kernel = xp.asarray(basis_kernel, dtype=dtype)
Review comment (Member):

This should not be needed. If basis is an xp array, then basis_kernel should also be one. If not there is a problem. This might be fixed by #29475.

Follow-up comment (Member):

Actually, the main problem is that the _parallel_pairwise function called by pairwise_kernels should be upgraded to support array API.

Maybe we could have a PR dedicated to adding array API support to pairwise_kernels first and pause this PR in the meantime to incrementally review PRs one after the other.
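
A quick way to probe that dependency (assumed environment with torch and array-api-compat installed); pairwise_kernels returning a NumPy array here would confirm the dispatch gap:

    import sklearn
    import torch
    from sklearn.metrics.pairwise import pairwise_kernels

    X = torch.rand(5, 3)
    with sklearn.config_context(array_api_dispatch=True):
        K = pairwise_kernels(X, metric="rbf")
    # A torch.Tensor result requires pairwise_kernels itself to support array API.
    print(type(K))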

@pytest.mark.parametrize(
    "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
)
def test_nystroem_approximation(array_namespace, device, dtype_name):
Review comment (Member):

I would rather not change the existing tests that check the behavior with regular numpy array when array API dispatch is disabled.

Instead add a new test that compares the results of using the estimator on array API inputs with array API dispatch enabled (and also check the type and device of fitted attributes) as we do for other array API estimators (e.g. test_pca_array_api_compliance) with non-default hyper-parameter values.

Note that you should also add "array_api_support": True to the _more_tags method of this estimator so that our common tests pick it up and run standard array API compliance tests with the default parameters.
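
A minimal sketch of that tag (method name as cited in the review; sklearn's tags API has since been reworked):

    def _more_tags(self):
        return {"array_api_support": True}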

Also use array_api somewhere in the test name so that it's easy to run all array API related tests of scikit-learn by doing pytest -k array_api sklearn.
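
A hypothetical test skeleton along these lines; the helper names (_array_api_for_tests, _convert_to_numpy, yield_namespace_device_dtype_combinations) are existing sklearn test utilities, but the exact body is an assumption, not code from the PR:

    import numpy as np
    import pytest
    import sklearn
    from numpy.testing import assert_allclose
    from sklearn.kernel_approximation import Nystroem
    from sklearn.utils._array_api import (
        _convert_to_numpy,
        yield_namespace_device_dtype_combinations,
    )
    from sklearn.utils._testing import _array_api_for_tests


    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    def test_nystroem_array_api_compliance(array_namespace, device, dtype_name):
        xp = _array_api_for_tests(array_namespace, device)
        X_np = np.random.RandomState(0).uniform(size=(60, 10)).astype(dtype_name)
        X_xp = xp.asarray(X_np, device=device)

        with sklearn.config_context(array_api_dispatch=True):
            Xt_xp = Nystroem(n_components=5, random_state=0).fit_transform(X_xp)

        # Results should match the plain NumPy path within float tolerance.
        Xt_np = Nystroem(n_components=5, random_state=0).fit_transform(X_np)
        assert_allclose(_convert_to_numpy(Xt_xp, xp=xp), Xt_np, atol=1e-5)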
