[Array API] stable_cumsum uses np.float64 rather than xp.float64 #27427

Closed
fcharras opened this issue Sep 20, 2023 · 1 comment · Fixed by #27431
Comments

fcharras (Contributor) commented Sep 20, 2023

Describe the bug

stable_cumsum has been adapted for Array API support, but it fails when given a PyTorch tensor:

Steps/Code to Reproduce

import torch

from sklearn.utils.extmath import stable_cumsum

arr = torch.asarray([1, 2, 3], dtype=torch.float32)
stable_cumsum(arr)

Expected Results

The call should output:

tensor([1., 3., 6.])

Actual Results

Instead, it raises an exception:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 stable_cumsum(arr)

File ~/mambaforge/envs/torch_igpu/lib/python3.10/site-packages/sklearn/utils/extmath.py:1214, in stable_cumsum(arr, axis, rtol, atol)
   1211 xp, _ = get_namespace(arr)
   1213 out = xp.cumsum(arr, axis=axis, dtype=np.float64)
-> 1214 expected = xp.sum(arr, axis=axis, dtype=np.float64)
   1215 if not xp.all(
   1216     xp.isclose(
   1217         out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
   1218     )
   1219 ):
   1220     warnings.warn(
   1221         (
   1222             "cumsum was found to be unstable: "
   (...)
   1225         RuntimeWarning,
   1226     )

File <__array_function__ internals>:200, in sum(*args, **kwargs)

File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2324, in sum(a, axis, dtype, out, keepdims, initial, where)
   2321         return out
   2322     return res
-> 2324 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
   2325                       initial=initial, where=where)

File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:82, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
     78 else:
     79     # This branch is needed for reductions like any which don't
     80     # support a dtype.
     81     if dtype is not None:
---> 82         return reduction(axis=axis, dtype=dtype, out=out, **passkwargs)
     83     else:
     84         return reduction(axis=axis, out=out, **passkwargs)

TypeError: sum() received an invalid combination of arguments - got (dtype=type, out=NoneType, axis=NoneType, ), but expected one of:
 * (*, torch.dtype dtype)
 * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
 * (tuple of names dim, bool keepdim, *, torch.dtype dtype)

The issue is that torch functions do not accept NumPy dtypes as dtype arguments.
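
For illustration, torch accepts its own dtype objects but rejects NumPy's; this is exactly the TypeError in the traceback above:

import numpy as np
import torch

t = torch.asarray([1, 2, 3], dtype=torch.float32)

torch.sum(t, dtype=torch.float64)  # works: tensor(6., dtype=torch.float64)
torch.sum(t, dtype=np.float64)     # raises the TypeError shown above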

As a quick fix, those lines should use xp.float64 instead of np.float64.
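
Concretely, the patched lines in sklearn/utils/extmath.py would read roughly as follows. This is a sketch of the quick fix; it assumes the namespace returned by get_namespace exposes cumsum, sum and float64, which holds for numpy and torch, although cumsum is not part of the Array API standard itself:

xp, _ = get_namespace(arr)

# Pass the namespace's own dtype object, so that e.g. torch receives
# torch.float64 rather than the numpy scalar type it cannot handle.
out = xp.cumsum(arr, axis=axis, dtype=xp.float64)
expected = xp.sum(arr, axis=axis, dtype=xp.float64)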

Besides applying this patch, @ogrisel suggested investigating why the common Array API tests (which are supposed to run with several array libraries and dtypes) did not catch this bug.

Versions

Version: main
fcharras added the Bug and Needs Triage labels on Sep 20, 2023
ogrisel (Member) commented Sep 28, 2023

Note that cumsum is not yet part of the Array API standard, and neither is searchsorted. In #27431 I reverted the code in stable_cumsum to stop pretending that it is Array API compliant, and instead convert the needed array to a CPU-backed numpy array. This step is not performance critical for PCA, so it is a robust way to work around the problem in the short term.
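
A minimal sketch of that short-term workaround (np.asarray here stands in for scikit-learn's internal conversion helper; the actual change is in #27431):

import warnings

import numpy as np

def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
    # Short-term fix: run the computation with numpy on CPU. np.asarray
    # is a stand-in here; the real helper also moves GPU-backed arrays
    # to the host before converting.
    arr = np.asarray(arr)
    out = np.cumsum(arr, axis=axis, dtype=np.float64)
    expected = np.sum(arr, axis=axis, dtype=np.float64)
    if not np.allclose(
        out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
    ):
        warnings.warn(
            "cumsum was found to be unstable: its last element does not "
            "correspond to sum",
            RuntimeWarning,
        )
    return out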

For the longer term, I think we should stop using stable_cumsum in scikit-learn and instead use regular np.cumsum (and xp.cumsum, once it exists) with the data-derived dtype, relying on our tests to check that numerical accuracy is good enough rather than performing that check as part of library usage.
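
In code, that would reduce stable_cumsum call sites to something like this hypothetical sketch (it assumes xp.cumsum exists for the namespace at hand):

xp, _ = get_namespace(arr)

# Keep the array's own dtype and skip both the float64 upcast and the
# runtime stability check; the test suite asserts accuracy instead.
out = xp.cumsum(arr, axis=axis)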
