---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 stable_cumsum(arr)
File ~/mambaforge/envs/torch_igpu/lib/python3.10/site-packages/sklearn/utils/extmath.py:1214, in stable_cumsum(arr, axis, rtol, atol)
1211 xp, _ = get_namespace(arr)
1213 out = xp.cumsum(arr, axis=axis, dtype=np.float64)
-> 1214 expected = xp.sum(arr, axis=axis, dtype=np.float64)
1215 if not xp.all(
1216 xp.isclose(
1217 out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
1218 )
1219 ):
1220 warnings.warn(
1221 (
1222 "cumsum was found to be unstable: "
(...)
1225 RuntimeWarning,
1226 )
File <__array_function__ internals>:200, in sum(*args, **kwargs)
File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2324, in sum(a, axis, dtype, out, keepdims, initial, where)
2321 return out
2322 return res
-> 2324 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
2325 initial=initial, where=where)
File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:82, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
78 else:
79 # This branch is needed for reductions like any which don't
80 # support a dtype.
81 if dtype is not None:
---> 82 return reduction(axis=axis, dtype=dtype, out=out, **passkwargs)
83 else:
84 return reduction(axis=axis, out=out, **passkwargs)
TypeError: sum() received an invalid combination of arguments - got (dtype=type, out=NoneType, axis=NoneType, ), but expected one of:
* (*, torch.dtype dtype)
* (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
* (tuple of names dim, bool keepdim, *, torch.dtype dtype)
The issue is that torch tensors do not accept numpy dtypes.
As a quick fix, those lines should use xp.float64 in place of np.float64.
Besides applying this patch, @ogrisel suggested it might be worth investigating why the common array API tests (which are supposed to exercise several array libraries and dtypes) did not catch this one.
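A sketch of the patched logic (illustrative only, with numpy standing in for the `xp` namespace returned by `get_namespace`; the real function lives in sklearn/utils/extmath.py):

```python
import warnings

import numpy as np


def stable_cumsum_fixed(arr, xp=np, axis=None, rtol=1e-05, atol=1e-08):
    """Sketch of stable_cumsum using the namespace's own float64.

    Passing xp.float64 instead of np.float64 keeps the calls valid for
    namespaces (such as torch) that reject numpy dtype objects.
    """
    out = xp.cumsum(arr, axis=axis, dtype=xp.float64)
    expected = xp.sum(arr, axis=axis, dtype=xp.float64)
    if not xp.all(
        xp.isclose(out.take(-1, axis=axis), expected,
                   rtol=rtol, atol=atol, equal_nan=True)
    ):
        warnings.warn(
            "cumsum was found to be unstable: its last element does not "
            "correspond to sum",
            RuntimeWarning,
        )
    return out
```

With `xp=np` this behaves like the original; with a torch-like namespace, the dtype argument is now an object the namespace understands.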
Versions
Version: main
Note that cumsum is not yet part of the Array API standard, and neither is searchsorted. In #27431 I reverted the code in stable_cumsum to stop pretending that it is Array API compliant, and instead convert the needed array to a CPU-backed numpy array. This step is not performance critical for PCA, so it is a robust way to work around the problem in the short term.
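That short-term workaround might look like the sketch below (a hypothetical helper; scikit-learn's actual conversion utility may differ):

```python
import numpy as np


def cumsum_via_numpy(arr, axis=None):
    # Move any device-backed array (e.g. a torch tensor on GPU) to a
    # CPU-backed numpy array first, then defer to numpy's cumsum.
    # Acceptable here because this code path is not performance
    # critical for PCA.
    to_cpu = getattr(arr, "cpu", None)  # torch tensors expose .cpu()
    np_arr = np.asarray(to_cpu() if to_cpu is not None else arr)
    return np.cumsum(np_arr, axis=axis)
```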
For the longer term, I think we should stop using stable_cumsum in scikit-learn and instead use regular np.cumsum (and xp.cumsum, once it exists) with the data-derived dtype and rely on our tests to check that numerical accuracy is good enough instead of doing that check as part of the library usage.
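That longer-term direction could look like this sketch (names are illustrative, with numpy standing in for `xp`): a plain cumsum with the data-derived dtype, and the accuracy check moved into the test suite instead of running at call time.

```python
import numpy as np


def plain_cumsum(arr, axis=None):
    # No float64 upcast and no runtime stability warning: just the
    # namespace's cumsum with the data-derived dtype.
    return np.cumsum(arr, axis=axis)


def test_cumsum_accuracy():
    # Numerical accuracy is asserted here, in a test, rather than
    # checked on every library call.
    rng = np.random.default_rng(0)
    x = rng.random(1000).astype(np.float32)
    np.testing.assert_allclose(
        plain_cumsum(x)[-1], np.sum(x, dtype=np.float64), rtol=1e-4
    )
```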
Describe the bug
stable_cumsum has been adapted for Array API support, but fails when given a PyTorch tensor:
Steps/Code to Reproduce
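The original reproduction snippet did not survive here; a minimal snippet along these lines triggers the traceback above (assumes torch and scikit-learn's array API dispatch support are installed):

```python
# Hypothetical minimal reproduction for a scikit-learn version from
# around the time of this report.
import torch
from sklearn import config_context
from sklearn.utils.extmath import stable_cumsum

arr = torch.arange(10, dtype=torch.float32)
try:
    with config_context(array_api_dispatch=True):
        print(stable_cumsum(arr))  # TypeError on affected versions
except (TypeError, ImportError) as exc:
    # ImportError can occur if array API dispatch prerequisites
    # (e.g. array-api-compat) are not installed.
    print("reproduced:", type(exc).__name__)
```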
Expected Results
stable_cumsum should output the cumulative sum of the input tensor.
Actual Results
It instead raises the TypeError shown in the traceback at the top of this issue.