[Array API] `stable_cumsum` uses `np.float64` rather than `xp.float64` · Issue #27427 · scikit-learn/scikit-learn
@fcharras

Description

Describe the bug

stable_cumsum has been adapted for Array API support, but it fails when given a PyTorch tensor:

Steps/Code to Reproduce

from sklearn.utils.extmath import stable_cumsum
import torch

arr = torch.asarray([1,2,3], dtype=torch.float32)
stable_cumsum(arr)

Expected Results

The call should return:

tensor([1., 3., 6.])

Actual Results

Instead, it raises an exception:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 stable_cumsum(arr)

File ~/mambaforge/envs/torch_igpu/lib/python3.10/site-packages/sklearn/utils/extmath.py:1214, in stable_cumsum(arr, axis, rtol, atol)
   1211 xp, _ = get_namespace(arr)
   1213 out = xp.cumsum(arr, axis=axis, dtype=np.float64)
-> 1214 expected = xp.sum(arr, axis=axis, dtype=np.float64)
   1215 if not xp.all(
   1216     xp.isclose(
   1217         out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
   1218     )
   1219 ):
   1220     warnings.warn(
   1221         (
   1222             "cumsum was found to be unstable: "
   (...)
   1225         RuntimeWarning,
   1226     )

File <__array_function__ internals>:200, in sum(*args, **kwargs)

File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2324, in sum(a, axis, dtype, out, keepdims, initial, where)
   2321         return out
   2322     return res
-> 2324 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
   2325                       initial=initial, where=where)

File /opt/venv/lib/python3.10/site-packages/numpy/core/fromnumeric.py:82, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
     78 else:
     79     # This branch is needed for reductions like any which don't
     80     # support a dtype.
     81     if dtype is not None:
---> 82         return reduction(axis=axis, dtype=dtype, out=out, **passkwargs)
     83     else:
     84         return reduction(axis=axis, out=out, **passkwargs)

TypeError: sum() received an invalid combination of arguments - got (dtype=type, out=NoneType, axis=NoneType, ), but expected one of:
 * (*, torch.dtype dtype)
 * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
 * (tuple of names dim, bool keepdim, *, torch.dtype dtype)

The issue is that torch tensors do not accept NumPy dtypes: np.float64 is passed where a torch.dtype is expected.
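The traceback shows the failing call going through NumPy's `np.sum`. Its internal `_wrapreduction` helper forwards `dtype=np.float64` to the input object's own `.sum()` method instead of converting the input. The sketch below reproduces that path without torch. It uses a hypothetical `StrictDtypeArray` class, an assumption standing in for a tensor, whose `.sum()` accepts only its own library's dtype identifiers (here, the made-up string `"mylib.float64"`).

```python
import math

import numpy as np


class StrictDtypeArray:
    """Hypothetical stand-in for a torch tensor: its .sum() only accepts
    dtype identifiers from its own (made-up) library."""

    ALLOWED_DTYPES = {"mylib.float64"}

    def __init__(self, values):
        self.values = list(values)

    def sum(self, axis=None, dtype=None, out=None):
        # Reject foreign dtype objects, such as the NumPy type np.float64.
        if dtype is not None and dtype not in self.ALLOWED_DTYPES:
            raise TypeError(f"sum() received an invalid dtype: {dtype!r}")
        return math.fsum(self.values)


arr = StrictDtypeArray([1, 2, 3])

# np.sum does not convert this non-ndarray input: it calls
# arr.sum(axis=None, dtype=np.float64, out=None), which rejects the dtype.
try:
    np.sum(arr, dtype=np.float64)
except TypeError as exc:
    print("TypeError:", exc)

# Passing the library's own dtype identifier works.
print(arr.sum(dtype="mylib.float64"))  # 6.0
```

The `dtype=type` in the real error message matches this path: `np.float64` is a Python type, not a `torch.dtype`.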

As a quick fix, the two calls shown at lines 1213 and 1214 of extmath.py (xp.cumsum and xp.sum) should use xp.float64 instead of np.float64.
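A minimal sketch of the patched logic follows, mirroring lines 1211 to 1226 from the traceback. It makes some assumptions to stay self-contained. The hypothetical `stable_cumsum_sketch` takes the namespace `xp` as an explicit argument instead of calling scikit-learn's `get_namespace`. The `rtol`/`atol` defaults are assumed values. The warning text is shortened because the traceback elides most of it. It is exercised here only with NumPy as `xp`.

```python
import warnings

import numpy as np


def stable_cumsum_sketch(arr, xp, axis=None, rtol=1e-05, atol=1e-08):
    """Hypothetical sketch of stable_cumsum with the proposed fix.

    Assumption: `xp` is passed in explicitly, standing in for the
    namespace returned by get_namespace(arr) in scikit-learn.
    """
    # The fix: take float64 from the array's own namespace rather than
    # from NumPy, so each library receives a dtype object it understands.
    out = xp.cumsum(arr, axis=axis, dtype=xp.float64)
    expected = xp.sum(arr, axis=axis, dtype=xp.float64)
    # The last cumulative sum should match the plain sum.
    if not xp.all(
        xp.isclose(
            out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
        )
    ):
        # Shortened message: the full text is elided in the traceback above.
        warnings.warn("cumsum was found to be unstable", RuntimeWarning)
    return out


arr = np.asarray([1, 2, 3], dtype=np.float32)
print(stable_cumsum_sketch(arr, np))  # [1. 3. 6.]
```

With NumPy, `xp.float64` and `np.float64` are the same object, so the NumPy path is unchanged. Only namespaces with their own dtype objects, like torch, are affected by the switch.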

Besides applying this patch, @ogrisel suggested investigating why the common Array API tests, which are supposed to run with several array libraries and dtypes, did not catch this bug.
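As a rough illustration of the kind of check the common tests are meant to perform, the sketch below runs a function on float32 and float64 inputs from every installed library. It then asserts that the result uses that library's own float64 dtype. The helper names `iter_namespaces` and `check_float64_output` are hypothetical and are not scikit-learn's test utilities. Using each library's top-level module (e.g. `torch`) as the namespace is also an assumption. That works for `asarray`, `sum` and the dtype attributes used here, but not for the full Array API.

```python
import importlib

import numpy as np


def iter_namespaces(names):
    """Yield (name, module) for each array library in `names` that is installed."""
    for name in names:
        try:
            yield name, importlib.import_module(name)
        except ImportError:
            # Library not available in this environment: skip it.
            continue


def check_float64_output(func, names=("numpy", "torch")):
    """Call func(arr, xp) for several input dtypes in each installed library
    and assert the result uses that library's own float64 dtype."""
    checked = []
    for name, xp in iter_namespaces(names):
        for in_dtype in (xp.float32, xp.float64):
            arr = xp.asarray([1, 2, 3], dtype=in_dtype)
            out = func(arr, xp)
            assert out.dtype == xp.float64, (name, in_dtype, out.dtype)
        checked.append(name)
    return checked


def sum_as_float64(arr, xp):
    # Uses the namespace's own float64, as in the proposed fix.
    return xp.sum(arr, dtype=xp.float64)


print(check_float64_output(sum_as_float64))
```

If torch is installed, a variant passing `dtype=np.float64` fails on the torch inputs, just like the traceback above. With only NumPy installed, both variants pass. Such a check therefore only catches this bug when a non-NumPy library is actually present in the test environment.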

Versions

Version: main