diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index e57bfc3ec5a9c..e221c3fea4e02 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -2,10 +2,8 @@ import pytest from sklearn.utils._cython_blas import ( - ColMajor, - NoTrans, - RowMajor, - Trans, + BLAS_Order, + BLAS_Trans, _asum_memview, _axpy_memview, _copy_memview, @@ -30,7 +28,7 @@ def _numpy_to_cython(dtype): RTOL = {np.float32: 1e-6, np.float64: 1e-12} -ORDER = {RowMajor: "C", ColMajor: "F"} +ORDER = {BLAS_Order.RowMajor: "C", BLAS_Order.ColMajor: "F"} def _no_op(x): @@ -166,9 +164,15 @@ def test_rot(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize( - "opA, transA", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"] + "opA, transA", + [(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)], + ids=["NoTrans", "Trans"], +) +@pytest.mark.parametrize( + "order", + [BLAS_Order.RowMajor, BLAS_Order.ColMajor], + ids=["RowMajor", "ColMajor"], ) -@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"]) def test_gemv(dtype, opA, transA, order): gemv = _gemv_memview[_numpy_to_cython(dtype)] @@ -187,7 +191,11 @@ def test_gemv(dtype, opA, transA, order): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"]) +@pytest.mark.parametrize( + "order", + [BLAS_Order.RowMajor, BLAS_Order.ColMajor], + ids=["BLAS_Order.RowMajor", "BLAS_Order.ColMajor"], +) def test_ger(dtype, order): ger = _ger_memview[_numpy_to_cython(dtype)] @@ -207,12 +215,20 @@ def test_ger(dtype, order): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize( - "opB, transB", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"] + "opB, transB", + [(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)], + ids=["NoTrans", "Trans"], +) +@pytest.mark.parametrize( + "opA, transA", + [(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)], + ids=["NoTrans", "Trans"], ) @pytest.mark.parametrize( - "opA, transA", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"] + "order", + [BLAS_Order.RowMajor, BLAS_Order.ColMajor], + ids=["BLAS_Order.RowMajor", "BLAS_Order.ColMajor"], ) -@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"]) def test_gemm(dtype, opA, transA, opB, transB, order): gemm = _gemm_memview[_numpy_to_cython(dtype)]