@@ -134,7 +134,7 @@ def test_determine_key_type_array_api(array_namespace, device, dtype_name):
134134
135135
136136@pytest .mark .parametrize (
137- "array_type" , ["list" , "array" , "sparse" , "dataframe" , "polars" ]
137+ "array_type" , ["list" , "array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
138138)
139139@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" , "slice" ])
140140def test_safe_indexing_2d_container_axis_0 (array_type , indices_type ):
@@ -149,7 +149,9 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices_type):
149149 )
150150
151151
152- @pytest .mark .parametrize ("array_type" , ["list" , "array" , "series" , "polars_series" ])
152+ @pytest .mark .parametrize (
153+ "array_type" , ["list" , "array" , "series" , "polars_series" , "pyarrow_array" ]
154+ )
153155@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" , "slice" ])
154156def test_safe_indexing_1d_container (array_type , indices_type ):
155157 indices = [1 , 2 ]
@@ -161,7 +163,9 @@ def test_safe_indexing_1d_container(array_type, indices_type):
161163 assert_allclose_dense_sparse (subset , _convert_container ([2 , 3 ], array_type ))
162164
163165
164- @pytest .mark .parametrize ("array_type" , ["array" , "sparse" , "dataframe" , "polars" ])
166+ @pytest .mark .parametrize (
167+ "array_type" , ["array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
168+ )
165169@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" , "slice" ])
166170@pytest .mark .parametrize ("indices" , [[1 , 2 ], ["col_1" , "col_2" ]])
167171def test_safe_indexing_2d_container_axis_1 (array_type , indices_type , indices ):
@@ -177,7 +181,7 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
177181 )
178182 indices_converted = _convert_container (indices_converted , indices_type )
179183
180- if isinstance (indices [0 ], str ) and array_type not in ("dataframe " , "polars " ):
184+ if isinstance (indices [0 ], str ) and array_type in ("array " , "sparse " ):
181185 err_msg = (
182186 "Specifying the columns using strings is only supported for dataframes"
183187 )
@@ -192,7 +196,9 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
192196
193197@pytest .mark .parametrize ("array_read_only" , [True , False ])
194198@pytest .mark .parametrize ("indices_read_only" , [True , False ])
195- @pytest .mark .parametrize ("array_type" , ["array" , "sparse" , "dataframe" , "polars" ])
199+ @pytest .mark .parametrize (
200+ "array_type" , ["array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
201+ )
196202@pytest .mark .parametrize ("indices_type" , ["array" , "series" ])
197203@pytest .mark .parametrize (
198204 "axis, expected_array" , [(0 , [[4 , 5 , 6 ], [7 , 8 , 9 ]]), (1 , [[2 , 3 ], [5 , 6 ], [8 , 9 ]])]
@@ -212,7 +218,9 @@ def test_safe_indexing_2d_read_only_axis_1(
212218 assert_allclose_dense_sparse (subset , _convert_container (expected_array , array_type ))
213219
214220
215- @pytest .mark .parametrize ("array_type" , ["list" , "array" , "series" , "polars_series" ])
221+ @pytest .mark .parametrize (
222+ "array_type" , ["list" , "array" , "series" , "polars_series" , "pyarrow_array" ]
223+ )
216224@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" ])
217225def test_safe_indexing_1d_container_mask (array_type , indices_type ):
218226 indices = [False ] + [True ] * 2 + [False ] * 6
@@ -222,7 +230,9 @@ def test_safe_indexing_1d_container_mask(array_type, indices_type):
222230 assert_allclose_dense_sparse (subset , _convert_container ([2 , 3 ], array_type ))
223231
224232
225- @pytest .mark .parametrize ("array_type" , ["array" , "sparse" , "dataframe" , "polars" ])
233+ @pytest .mark .parametrize (
234+ "array_type" , ["array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
235+ )
226236@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" ])
227237@pytest .mark .parametrize (
228238 "axis, expected_subset" ,
@@ -250,6 +260,7 @@ def test_safe_indexing_2d_mask(array_type, indices_type, axis, expected_subset):
250260 ("sparse" , "sparse" ),
251261 ("dataframe" , "series" ),
252262 ("polars" , "polars_series" ),
263+ ("pyarrow" , "pyarrow_array" ),
253264 ],
254265)
255266def test_safe_indexing_2d_scalar_axis_0 (array_type , expected_output_type ):
@@ -260,7 +271,9 @@ def test_safe_indexing_2d_scalar_axis_0(array_type, expected_output_type):
260271 assert_allclose_dense_sparse (subset , expected_array )
261272
262273
263- @pytest .mark .parametrize ("array_type" , ["list" , "array" , "series" , "polars_series" ])
274+ @pytest .mark .parametrize (
275+ "array_type" , ["list" , "array" , "series" , "polars_series" , "pyarrow_array" ]
276+ )
264277def test_safe_indexing_1d_scalar (array_type ):
265278 array = _convert_container ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], array_type )
266279 indices = 2
@@ -275,6 +288,7 @@ def test_safe_indexing_1d_scalar(array_type):
275288 ("sparse" , "sparse" ),
276289 ("dataframe" , "series" ),
277290 ("polars" , "polars_series" ),
291+ ("pyarrow" , "pyarrow_array" ),
278292 ],
279293)
280294@pytest .mark .parametrize ("indices" , [2 , "col_2" ])
@@ -284,7 +298,7 @@ def test_safe_indexing_2d_scalar_axis_1(array_type, expected_output_type, indice
284298 [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], array_type , columns_name
285299 )
286300
287- if isinstance (indices , str ) and array_type not in ("dataframe " , "polars " ):
301+ if isinstance (indices , str ) and array_type in ("array " , "sparse " ):
288302 err_msg = (
289303 "Specifying the columns using strings is only supported for dataframes"
290304 )
@@ -321,7 +335,9 @@ def test_safe_indexing_error_axis(axis):
321335 _safe_indexing (X_toy , [0 , 1 ], axis = axis )
322336
323337
324- @pytest .mark .parametrize ("X_constructor" , ["array" , "series" , "polars_series" ])
338+ @pytest .mark .parametrize (
339+ "X_constructor" , ["array" , "series" , "polars_series" , "pyarrow_array" ]
340+ )
325341def test_safe_indexing_1d_array_error (X_constructor ):
326342 # check that we are raising an error if the array-like passed is 1D and
327343 # we try to index on the 2nd dimension
@@ -334,6 +350,9 @@ def test_safe_indexing_1d_array_error(X_constructor):
334350 elif X_constructor == "polars_series" :
335351 pl = pytest .importorskip ("polars" )
336352 X_constructor = pl .Series (values = X )
353+ elif X_constructor == "pyarrow_array" :
354+ pa = pytest .importorskip ("pyarrow" )
355+ X_constructor = pa .array (X )
337356
338357 err_msg = "'X' should be a 2D NumPy array, 2D sparse matrix or dataframe"
339358 with pytest .raises (ValueError , match = err_msg ):
0 commit comments