8000 Enable array API testing for jax.experimental.array_api by ogrisel · Pull Request #29647 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Enable array API testing for jax.experimental.array_api #29647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

ogrisel
Copy link
Member
@ogrisel ogrisel commented Aug 9, 2024

This is a draft PR to run the existing array API tests on jax inputs. This is not expected to work as jax does not support inplace updates via __setitem__ in particular as discussed in:

Note that scipy started to run its array API tests against jax but maintains a list of tests to skip because of that design decision:

TODO before considering a review for merge

  • compile a list of all root causes of jax-specific test failures and suggest solutions;
  • find a solution to the inplace assignment problem (will require at least a change in the array API spec or in jax or both);
  • update the documentation.

Copy link
github-actions bot commented Aug 9, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: ed1b088. Link to the linter CI: here

@ogrisel
Copy link
Member Author
ogrisel commented Aug 9, 2024

I forgot to add jax to one of the CI builds with array-api-compat support. Doing that in the next commit.

@ogrisel ogrisel added the CUDA CI label Aug 9, 2024
@github-actions github-actions bot removed the CUDA CI label Aug 9, 2024
@ogrisel ogrisel added the CUDA CI label Aug 9, 2024
@github-actions github-actions bot removed the CUDA CI label Aug 9, 2024
@ogrisel
Copy link
Member Author
ogrisel commented Aug 9, 2024

We get 41 failed array API tests. Most (all?) of them seem caused by the inplace assignment problem.

Also note that 140 array API tests pass with jax though.

EDIT: here is an archive of the 42 failures observed on the CPU-only run (vs 41 for the CUDA run):

=================================== FAILURES ===================================
_ test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] _
[gw1] linux -- Python 3.12.5 /usr/share/miniconda/envs/testvenv/bin/python

FAILED decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='covariance_eigh')-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='covariance_eigh')-check_array_api_get_precision-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='covariance_eigh',whiten=True)-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='covariance_eigh',whiten=True)-check_array_api_get_precision-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED linear_model/tests/test_ridge.py::test_ridge_array_api_compliance[Ridge(solver='svd')-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED linear_model/tests/test_ridge.py::test_ridge_array_api_compliance[Ridge(solver='svd')-check_array_api_attributes-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED linear_model/tests/test_ridge.py::test_array_api_error_and_warnings_for_solver_parameter[jax.experimental.array_api] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[r2_score-check_array_api_regression_metric-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[r2_score-check_array_api_regression_metric_multioutput-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[cosine_similarity-check_array_api_metric_pairwise-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[paired_cosine_distances-check_array_api_metric_pairwise-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[cosine_distances-check_array_api_metric_pairwise-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[euclidean_distances-check_array_api_metric_pairwise-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED metrics/tests/test_common.py::test_array_api_compliance[rbf_kernel-check_array_api_metric_pairwise-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED model_selection/tests/test_search.py::test_array_api_search_cv_classifier[GridSearchCV-jax.experimental.array_api-device9-float32] - ValueError: 
FAILED model_selection/tests/test_search.py::test_array_api_search_cv_classifier[RandomizedSearchCV-jax.experimental.array_api-device9-float32] - ValueError: 
FAILED preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MaxAbsScaler()-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MinMaxScaler()-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='l1')-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer()-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='max')-check_array_api_input_and_values-jax.experimental.array_api-device9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_estimators[LinearDiscriminantAnalysis()-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_estimators[Normalizer()-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_estimators[Ridge()-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_search_cv[GridSearchCV(cv=2,error_score='raise',estimator=Ridge(),param_grid={'alpha':[0.1,1.0]})-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_search_cv[HalvingGridSearchCV(cv=2,error_score='raise',estimator=Ridge(),min_resources='smallest',param_grid={'alpha':[0.1,1.0]},random_state=0)-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)0] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_search_cv[RandomizedSearchCV(cv=2,error_score='raise',estimator=Ridge(),param_distributions={'alpha':[0.1,1.0]},random_state=0)-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED tests/test_common.py::test_search_cv[HalvingGridSearchCV(cv=2,error_score='raise',estimator=Ridge(),min_resources='smallest',param_grid={'alpha':[0.1,1.0]},random_state=0)-check_array_api_input(array_namespace=jax.experimental.array_api,dtype_name=float32,device=TFRT_CPU_0)1] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED utils/tests/test_array_api.py::test_indexing_dtype[jax.experimental.array_api-_device9-float32] - AssertionError
FAILED utils/tests/test_array_api.py::test_fill_or_add_to_diagonal[True-jax.experimental.array_api-device_9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED utils/tests/test_array_api.py::test_fill_or_add_to_diagonal[False-jax.experimental.array_api-device_9-float32] - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
FAILED utils/tests/test_estimator_checks.py::test_check_estimator_clones - TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not suppo...
= 42 failed, 32853 passed, 3083 skipped, 92 xfailed, 52 xpassed, 6819 warnings in 1460.23s (0:24:20) =

I checked and the two ValueError exceptions in the search cv test are caused by the assignment problem. I will open a dedicated PR to improve this test.

The only other failure is an assertion error in test_indexing_dtype because jax uses the int32 dtype as default indexing dtype even though this is running on 64 bit runtime which is quite unexpected:

  • The following works:
>>> a = xp.arange(2 ** 32 + 1, dtype=xp.uint8)
  • but the following fails:
>>> xp.asarray(2 ** 32 + 1).dtype
Traceback (most recent call last):
  Cell In[18], line 1
    xp.asarray(2 ** 32 + 1).dtype
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3568 in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3387 in array
    _ = dtypes.coerce_to_array(object, dtype)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/dtypes.py:322 in coerce_to_array
    dtype = _scalar_type_to_dtype(type(x), x)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/dtypes.py:311 in _scalar_type_to_dtype
    raise OverflowError(f"Python int {value} too large to convert to {dtype}")
OverflowError: Python int 4294967297 too large to convert to int32
  • and the following does not fail but silently overflows because of the inadequate default choice for the integer dtype:
>>> xp.asarray([2 ** 32 + 1])
Array([1], dtype=int32)

@lucascolley
Copy link
Contributor

some progress at scipy/scipy#22070

@ogrisel
Copy link
Member Author
ogrisel commented Mar 25, 2025

For the record, #30340 was recently merged to main so this might unblock this PR because we can now leverage JAX compat features implemented in array-api-extra.

EDIT: updating this draft PR to explore usage of array-api-extra is on my low to medium-priority list. If you are interested in exploring this, feel free to open a concurrent draft PR, I won't be offended ;)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants
0