@@ -134,7 +134,7 @@ def test_determine_key_type_array_api(array_namespace, device, dtype_name):
134
134
135
135
136
136
@pytest .mark .parametrize (
137
- "array_type" , ["list" , "array" , "sparse" , "dataframe" , "polars" ]
137
+ "array_type" , ["list" , "array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
138
138
)
139
139
@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" , "slice" ])
140
140
def test_safe_indexing_2d_container_axis_0 (array_type , indices_type ):
@@ -149,7 +149,9 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices_type):
149
149
)
150
150
151
151
152
- @pytest .mark .parametrize ("array_type" , ["list" , "array" , "series" , "polars_series" ])
152
+ @pytest .mark .parametrize (
153
+ "array_type" , ["list" , "array" , "series" , "polars_series" , "pyarrow_array" ]
154
+ )
153
155
@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" , "slice" ])
154
156
def test_safe_indexing_1d_container (array_type , indices_type ):
155
157
indices = [1 , 2 ]
@@ -161,7 +163,9 @@ def test_safe_indexing_1d_container(array_type, indices_type):
161
163
assert_allclose_dense_sparse (subset , _convert_container ([2 , 3 ], array_type ))
162
164
163
165
164
- @pytest .mark .parametrize ("array_type" , ["array" , "sparse" , "dataframe" , "polars" ])
166
+ @pytest .mark .parametrize (
167
+ "array_type" , ["array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
168
+ )
165
169
@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" , "slice" ])
166
170
@pytest .mark .parametrize ("indices" , [[1 , 2 ], ["col_1" , "col_2" ]])
167
171
def test_safe_indexing_2d_container_axis_1 (array_type , indices_type , indices ):
@@ -177,7 +181,7 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
177
181
)
178
182
indices_converted = _convert_container (indices_converted , indices_type )
179
183
180
- if isinstance (indices [0 ], str ) and array_type not in ("dataframe " , "polars " ):
184
+ if isinstance (indices [0 ], str ) and array_type in ("array " , "sparse " ):
181
185
err_msg = (
182
186
"Specifying the columns using strings is only supported for dataframes"
183
187
)
@@ -192,7 +196,9 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
192
196
193
197
@pytest .mark .parametrize ("array_read_only" , [True , False ])
194
198
@pytest .mark .parametrize ("indices_read_only" , [True , False ])
195
- @pytest .mark .parametrize ("array_type" , ["array" , "sparse" , "dataframe" , "polars" ])
199
+ @pytest .mark .parametrize (
200
+ "array_type" , ["array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
201
+ )
196
202
@pytest .mark .parametrize ("indices_type" , ["array" , "series" ])
197
203
@pytest .mark .parametrize (
198
204
"axis, expected_array" , [(0 , [[4 , 5 , 6 ], [7 , 8 , 9 ]]), (1 , [[2 , 3 ], [5 , 6 ], [8 , 9 ]])]
@@ -212,7 +218,9 @@ def test_safe_indexing_2d_read_only_axis_1(
212
218
assert_allclose_dense_sparse (subset , _convert_container (expected_array , array_type ))
213
219
214
220
215
- @pytest .mark .parametrize ("array_type" , ["list" , "array" , "series" , "polars_series" ])
221
+ @pytest .mark .parametrize (
222
+ "array_type" , ["list" , "array" , "series" , "polars_series" , "pyarrow_array" ]
223
+ )
216
224
@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" ])
217
225
def test_safe_indexing_1d_container_mask (array_type , indices_type ):
218
226
indices = [False ] + [True ] * 2 + [False ] * 6
@@ -222,7 +230,9 @@ def test_safe_indexing_1d_container_mask(array_type, indices_type):
222
230
assert_allclose_dense_sparse (subset , _convert_container ([2 , 3 ], array_type ))
223
231
224
232
225
- @pytest .mark .parametrize ("array_type" , ["array" , "sparse" , "dataframe" , "polars" ])
233
+ @pytest .mark .parametrize (
234
+ "array_type" , ["array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
235
+ )
226
236
@pytest .mark .parametrize ("indices_type" , ["list" , "tuple" , "array" , "series" ])
227
237
@pytest .mark .parametrize (
228
238
"axis, expected_subset" ,
@@ -250,6 +260,7 @@ def test_safe_indexing_2d_mask(array_type, indices_type, axis, expected_subset):
250
260
("sparse" , "sparse" ),
251
261
("dataframe" , "series" ),
252
262
("polars" , "polars_series" ),
263
+ ("pyarrow" , "pyarrow_array" ),
253
264
],
254
265
)
255
266
def test_safe_indexing_2d_scalar_axis_0 (array_type , expected_output_type ):
@@ -260,7 +271,9 @@ def test_safe_indexing_2d_scalar_axis_0(array_type, expected_output_type):
260
271
assert_allclose_dense_sparse (subset , expected_array )
261
272
262
273
263
- @pytest .mark .parametrize ("array_type" , ["list" , "array" , "series" , "polars_series" ])
274
+ @pytest .mark .parametrize (
275
+ "array_type" , ["list" , "array" , "series" , "polars_series" , "pyarrow_array" ]
276
+ )
264
277
def test_safe_indexing_1d_scalar (array_type ):
265
278
array = _convert_container ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], array_type )
266
279
indices = 2
@@ -275,6 +288,7 @@ def test_safe_indexing_1d_scalar(array_type):
275
288
("sparse" , "sparse" ),
276
289
("dataframe" , "series" ),
277
290
("polars" , "polars_series" ),
291
+ ("pyarrow" , "pyarrow_array" ),
278
292
],
279
293
)
280
294
@pytest .mark .parametrize ("indices" , [2 , "col_2" ])
@@ -284,7 +298,7 @@ def test_safe_indexing_2d_scalar_axis_1(array_type, expected_output_type, indice
284
298
[[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], array_type , columns_name
285
299
)
286
300
287
- if isinstance (indices , str ) and array_type not in ("dataframe " , "polars " ):
301
+ if isinstance (indices , str ) and array_type in ("array " , "sparse " ):
288
302
err_msg = (
289
303
"Specifying the columns using strings is only supported for dataframes"
290
304
)
@@ -321,7 +335,9 @@ def test_safe_indexing_error_axis(axis):
321
335
_safe_indexing (X_toy , [0 , 1 ], axis = axis )
322
336
323
337
324
- @pytest .mark .parametrize ("X_constructor" , ["array" , "series" , "polars_series" ])
338
+ @pytest .mark .parametrize (
339
+ "X_constructor" , ["array" , "series" , "polars_series" , "pyarrow_array" ]
340
+ )
325
341
def test_safe_indexing_1d_array_error (X_constructor ):
326
342
# check that we are raising an error if the array-like passed is 1D and
327
343
# we try to index on the 2nd dimension
@@ -334,6 +350,9 @@ def test_safe_indexing_1d_array_error(X_constructor):
334
350
elif X_constructor == "polars_series" :
335
351
pl = pytest .importorskip ("polars" )
336
352
X_constructor = pl .Series (values = X )
353
+ elif X_constructor == "pyarrow_array" :
354
+ pa = pytest .importorskip ("pyarrow" )
355
+ X_constructor = pa .array (X )
337
356
338
357
err_msg = "'X' should be a 2D NumPy array, 2D sparse matrix or dataframe"
339
358
with pytest .raises (ValueError , match = err_msg ):
0 commit comments