Describe the bug
stable_cumsum has been adapted for Array API support, but fails when passed a PyTorch tensor:
Steps/Code to Reproduce
from sklearn.utils.extmath import stable_cumsum
import torch
arr = torch.asarray([1, 2, 3], dtype=torch.float32)
stable_cumsum(arr)
Expected Results
should output
tensor([1., 3., 6.])
Actual Results
Instead, it raises an exception:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 stable_cumsum(arr)
File ~/mambaforge/envs/torch_igpu/lib/python3.10/site-packages/sklearn/utils/extmath.py:1214, in stable_cumsum(arr, axis, rtol, atol)
1211 xp, _ = get_namespace(arr)
1213 out = xp.cumsum(arr, axis=axis, dtype=np.float64)
-> 1214 expected = xp.sum(arr, axis=axis, dtype=np.float64)
1215 if not xp.all(
1216 xp.isclose(
1217 out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
1218 )
1219 ):
1220 warnings.warn(
1221 (
1222 "cumsum was found to be unstable: "
(...)
1225 RuntimeWarning,
1226 )
File <__array_function__ internals>:200, in sum(*args, **kwargs)
File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2324, in sum(a, axis, dtype, out, keepdims, initial, where)
2321 return out
2322 return res
-> 2324 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
2325 initial=initial, where=where)
File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:82, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
78 else:
79 # This branch is needed for reductions like any which don't
80 # support a dtype.
81 if dtype is not None:
---> 82 return reduction(axis=axis, dtype=dtype, out=out, **passkwargs)
83 else:
84 return reduction(axis=axis, out=out, **passkwargs)
TypeError: sum() received an invalid combination of arguments - got (dtype=type, out=NoneType, axis=NoneType, ), but expected one of:
* (*, torch.dtype dtype)
* (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
* (tuple of names dim, bool keepdim, *, torch.dtype dtype)
The issue is that torch tensors do not accept NumPy dtypes. As a quick fix, those lines should replace np.float64 with xp.float64.
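For reference, here is a minimal sketch of the fixed logic. The helper stable_cumsum_sketch is hypothetical (it is not scikit-learn's actual implementation and takes the namespace xp explicitly rather than via get_namespace): the point is that the reduction dtype comes from xp instead of the hard-coded np, so with xp=torch the dtype would be torch.float64 and the TypeError above would not occur.

```python
import warnings
import numpy as np

def stable_cumsum_sketch(arr, xp, axis=None, rtol=1e-05, atol=1e-08):
    # Hypothetical sketch of the proposed fix: the dtype is taken from the
    # array's own namespace ``xp`` instead of the hard-coded ``np``.
    out = xp.cumsum(arr, axis=axis, dtype=xp.float64)    # was: dtype=np.float64
    expected = xp.sum(arr, axis=axis, dtype=xp.float64)  # was: dtype=np.float64
    if not xp.all(xp.isclose(out[..., -1], expected, rtol=rtol, atol=atol)):
        warnings.warn("cumsum was found to be unstable", RuntimeWarning)
    return out

# With NumPy as the namespace, the sketch behaves like np.cumsum in float64:
result = stable_cumsum_sketch(np.asarray([1, 2, 3], dtype=np.float32), xp=np)
print(result)  # [1. 3. 6.]
```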
Besides applying this patch, @ogrisel suggested it might be worth investigating why the common Array API tests (which are supposed to run against several array libraries and dtypes) did not catch this.
Versions
Version: main