@@ -202,17 +202,32 @@ it supports the Array API. This will enable dedicated checks as part of the
202
202
common tests to verify that the estimators' results are the same when using
203
203
vanilla NumPy and Array API inputs.
204
204
205
- To run the full set of checks you need to install both
206
- `PyTorch <https://pytorch.org/ >`_ and `CuPy <https://cupy.dev/ >`_ and have
205
+ To run these checks you need to install
206
+ `array-api-strict <https://data-apis.org/array-api-strict/ >`_ in your
207
+ test environment. This allows you to run checks without having a
208
+ GPU. To run the full set of checks you also need to install
209
+ `PyTorch <https://pytorch.org/ >`_, `CuPy <https://cupy.dev/ >`_ and have
207
210
a GPU. Checks that can not be executed or have missing dependencies will be
208
211
automatically skipped. Therefore it's important to run the tests with the
209
212
`-v ` flag to see which checks are skipped:
210
213
211
214
.. prompt :: bash $
212
215
213
- pip install ... # selected libraries as needed
216
+ pip install array-api-strict # and other libraries as needed
214
217
pytest -k "array_api" -v
215
218
219
+ Running the scikit-learn tests against `array-api-strict ` should help reveal
220
+ most code problems related to handling multiple device inputs via the use of
221
+ simulated non-CPU devices. This allows for fast iterative development and debugging of
222
+ array API related code.
223
+
224
+ However, to ensure full handling of PyTorch or CuPy inputs allocated on actual GPU
225
+ devices, it is necessary to run the tests against those libraries and hardware.
226
+ This can either be achieved by using
227
+ `Google Colab <https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c >`_
228
+ or leveraging our CI infrastructure on pull requests (manually triggered by maintainers
229
+ for cost reasons).
230
+
216
231
.. _mps_support :
217
232
218
233
Note on MPS device support
0 commit comments