@@ -66,7 +66,7 @@ def decode_column(data_bunch, col_idx):
66
66
67
67
68
68
def _fetch_dataset_from_openml (data_id , data_name , data_version ,
69
- target_column ,
69
+ ignore_strings , target_column ,
70
70
expected_observations , expected_features ,
71
71
expected_missing ,
72
72
expected_data_dtype , expected_target_dtype ,
@@ -76,17 +76,18 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
76
76
# result. Note that this function can be mocked (by invoking
77
77
# _monkey_patch_webbased_functions before invoking this function)
78
78
data_by_name_id = fetch_openml (name = data_name , version = data_version ,
79
- cache = False )
79
+ ignore_strings = ignore_strings , cache = False )
80
80
assert int (data_by_name_id .details ['id' ]) == data_id
81
81
82
82
# Please note that cache=False is crucial, as the monkey patched files are
83
83
# not consistent with reality
84
- fetch_openml (name = data_name , cache = False )
84
+ fetch_openml (name = data_name , ignore_strings = ignore_strings , cache = False )
85
85
# without specifying the version, there is no guarantee that the data id
86
86
# will be the same
87
87
88
88
# fetch with dataset id
89
89
data_by_id = fetch_openml (data_id = data_id , cache = False ,
90
+ ignore_strings = ignore_strings ,
90
91
target_column = target_column )
91
92
assert data_by_id .details ['name' ] == data_name
92
93
assert data_by_id .data .shape == (expected_observations , expected_features )
@@ -112,7 +113,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
112
113
113
114
if compare_default_target :
114
115
# check whether the data by id and data by id target are equal
115
- data_by_id_default = fetch_openml (data_id = data_id , cache = False )
116
+ data_by_id_default = fetch_openml (data_id = data_id ,
117
+ ignore_strings = ignore_strings ,
118
+ cache = False )
116
119
if data_by_id .data .dtype == np .float64 :
117
120
np .testing .assert_allclose (data_by_id .data ,
118
121
data_by_id_default .data )
@@ -133,8 +136,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
133
136
expected_missing )
134
137
135
138
# test return_X_y option
136
- fetch_func = partial (fetch_openml , data_id = data_id , cache = False ,
137
- target_column = target_column )
139
+ fetch_func = partial (fetch_openml , data_id = data_id ,
140
+ ignore_strings = ignore_strings ,
141
+ cache = False , target_column = target_column )
138
142
check_return_X_y (data_by_id , fetch_func )
139
143
return data_by_id
140
144
@@ -261,6 +265,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
261
265
data_id = 61
262
266
data_name = 'iris'
263
267
data_version = 1
268
+ ignore_strings = False
264
269
target_column = 'class'
265
270
expected_observations = 150
266
271
expected_features = 4
@@ -275,6 +280,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
275
280
_fetch_dataset_from_openml ,
276
281
** {'data_id' : data_id , 'data_name' : data_name ,
277
282
'data_version' : data_version ,
283
+ 'ignore_strings' : ignore_strings ,
278
284
'target_column' : target_column ,
279
285
'expected_observations' : expected_observations ,
280
286
'expected_features' : expected_features ,
@@ -298,13 +304,15 @@ def test_fetch_openml_iris_multitarget(monkeypatch, gzip_response):
298
304
data_id = 61
299
305
data_name = 'iris'
300
306
data_version = 1
307
+ ignore_strings = False
301
308
target_column = ['sepallength' , 'sepalwidth' ]
302
309
expected_observations = 150
303
310
expected_features = 3
304
311
expected_missing = 0
305
312
306
313
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
307
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
314
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
315
+ ignore_strings , target_column ,
308
316
expected_observations , expected_features ,
309
317
expected_missing ,
310
318
object , np .float64 , expect_sparse = False ,
@@ -317,13 +325,15 @@ def test_fetch_openml_anneal(monkeypatch, gzip_response):
317
325
data_id = 2
318
326
data_name = 'anneal'
319
327
data_version = 1
328
+ ignore_strings = False
320
329
target_column = 'class'
321
330
# Not all original instances included for space reasons
322
331
expected_observations = 11
323
332
expected_features = 38
324
333
expected_missing = 267
325
334
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
326
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
335
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
336
+ ignore_strings , target_column ,
327
337
expected_observations , expected_features ,
328
338
expected_missing ,
329
339
object , object , expect_sparse = False ,
@@ -342,13 +352,15 @@ def test_fetch_openml_anneal_multitarget(monkeypatch, gzip_response):
342
352
data_id = 2
343
353
data_name = 'anneal'
344
354
data_version = 1
355
+ ignore_strings = False
345
356
target_column = ['class' , 'product-type' , 'shape' ]
346
357
# Not all original instances included for space reasons
347
358
expected_observations = 11
348
359
expected_features = 36
349
360
expected_missing = 267
350
361
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
351
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
362
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
363
+ ignore_strings , target_column ,
352
364
expected_observations , expected_features ,
353
365
expected_missing ,
354
366
object , object , expect_sparse = False ,
@@ -361,12 +373,14 @@ def test_fetch_openml_cpu(monkeypatch, gzip_response):
361
373
data_id = 561
362
374
data_name = 'cpu'
363
375
data_version = 1
376
+ ignore_strings = False
364
377
target_column = 'class'
365
378
expected_observations = 209
366
379
expected_features = 7
367
380
expected_missing = 0
368
381
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
369
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
382
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
383
+ ignore_strings , target_column ,
370
384
expected_observations , expected_features ,
371
385
expected_missing ,
372
386
object , np .float64 , expect_sparse = False ,
@@ -388,6 +402,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
388
402
data_id = 292
389
403
data_name = 'Australian'
390
404
data_version = 1
405
+ ignore_strings = False
391
406
target_column = 'Y'
392
407
# Not all original instances included for space reasons
393
408
expected_observations = 85
@@ -400,6 +415,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
400
415
_fetch_dataset_from_openml ,
401
416
** {'data_id' : data_id , 'data_name' : data_name ,
402
417
'data_version' : data_version ,
418
+ 'ignore_strings' : ignore_strings ,
403
419
'target_column' : target_column ,
404
420
'expected_observations' : expected_observations ,
405
421
'expected_features' : expected_features ,
@@ -417,13 +433,15 @@ def test_fetch_openml_adultcensus(monkeypatch, gzip_response):
417
433
data_id = 1119
418
434
data_name = 'adult-census'
419
435
data_version = 1
436
+ ignore_strings = False
420
437
target_column = 'class'
421
438
# Not all original instances included for space reasons
422
439
expected_observations = 10
423
440
expected_features = 14
424
441
expected_missing = 0
425
442
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
426
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
443
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
444
+ ignore_strings , target_column ,
427
445
expected_observations , expected_features ,
428
446
expected_missing ,
429
447
np .float64 , object , expect_sparse = False ,
@@ -439,13 +457,15 @@ def test_fetch_openml_miceprotein(monkeypatch, gzip_response):
439
457
data_id = 40966
440
458
data_name = 'MiceProtein'
441
459
data_version = 4
460
+ ignore_strings = False
442
461
target_column = 'class'
443
462
# Not all original instances included for space reasons
444
463
expected_observations = 7
445
464
expected_features = 77
446
465
expected_missing = 7
447
466
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
448
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
467
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
468
+ ignore_strings , target_column ,
449
469
expected_observations , expected_features ,
450
470
expected_missing ,
451
471
np .float64 , object , expect_sparse = False ,
@@ -458,14 +478,16 @@ def test_fetch_openml_emotions(monkeypatch, gzip_response):
458
478
data_id = 40589
459
479
data_name = 'emotions'
460
480
data_version = 3
481
+ ignore_strings = False
461
482
target_column = ['amazed.suprised' , 'happy.pleased' , 'relaxing.calm' ,
462
483
'quiet.still' , 'sad.lonely' , 'angry.aggresive' ]
463
484
expected_observations = 13
464
485
expected_features = 72
465
486
expected_missing = 0
466
487
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
467
488
468
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
489
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
490
+ ignore_strings , target_column ,
469
491
expected_observations , expected_features ,
470
492
expected_missing ,
471
493
np .float64 , object , expect_sparse = False ,
@@ -478,6 +500,27 @@ def test_decode_emotions(monkeypatch):
478
500
_test_features_list (data_id )
479
501
480
502
503
+ @pytest .mark .parametrize ('gzip_response' , [True , False ])
504
+ def test_fetch_titanic (monkeypatch , gzip_response ):
505
+ # Titanic contains string attributes: check it loads with ignore_strings=True
506
+ data_id = 40945
507
+ data_name = 'Titanic'
508
+ data_version = 1
509
+ ignore_strings = True
510
+ target_column = 'survived'
511
+ # Not all original features are included: the five string features are left out
512
+ expected_observations = 1309
513
+ expected_features = 8
514
+ expected_missing = 1454
515
+ _monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
516
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
517
+ ignore_strings , target_column ,
518
+ expected_observations , expected_features ,
519
+ expected_missing ,
520
+ np .float64 , object , expect_sparse = False ,
521
+ compare_default_target = True )
522
+
523
+
481
524
@pytest .mark .parametrize ('gzip_response' , [True , False ])
482
525
def test_open_openml_url_cache (monkeypatch , gzip_response , tmpdir ):
483
526
data_id = 61
@@ -667,7 +710,8 @@ def test_string_attribute(monkeypatch, gzip_response):
667
710
# single column test
668
711
assert_raise_message (ValueError ,
669
712
'STRING attributes are not yet supported' ,
670
- fetch_openml , data_id = data_id , cache = False )
713
+ fetch_openml , data_id = data_id , ignore_strings = False ,
714
+ cache = False )
671
715
672
716
673
717
@pytest .mark .parametrize ('gzip_response' , [True , False ])
0 commit comments