ENH: adds the ability to load datasets from OpenML containing string attributes · scikit-learn/scikit-learn@b437471

Commit b437471

ENH: adds the ability to load datasets from OpenML containing string
attributes by providing the option to ignore said attributes. Currently, an error is raised when a dataset containing string attributes (e.g., the Titanic dataset) is fetched from OpenML. This commit lets users opt into loading only the non-string subset of such a dataset. Closes #11819.
1 parent 4140657 commit b437471

8 files changed: 81 additions & 23 deletions
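For orientation, here is a minimal usage sketch of the option this commit adds. The call itself is illustrative (it assumes network access or the mocked responses used in the tests); data_id 40945 is the Titanic dataset exercised by the new test_fetch_titanic test.

from sklearn.datasets import fetch_openml

# Titanic (data_id=40945) contains STRING attributes such as passenger names.
# Before this commit the call below raised ValueError; with ignore_strings=True
# the string columns are simply dropped from the returned data.
titanic = fetch_openml(data_id=40945, target_column='survived',
                       ignore_strings=True)

print(titanic.data.shape)    # (1309, 8) per the expectations in the new test
print(titanic.target.shape)  # (1309,)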

sklearn/datasets/openml.py

Lines changed: 23 additions & 9 deletions
@@ -242,11 +242,10 @@ def _convert_arff_data(arff_data, col_slice_x, col_slice_y, shape=None):
             count = -1
         else:
             count = shape[0] * shape[1]
-        data = np.fromiter(itertools.chain.from_iterable(arff_data),
-                           dtype='float64', count=count)
+        data = np.array(list(itertools.chain.from_iterable(arff_data)))
         data = data.reshape(*shape)
-        X = data[:, col_slice_x]
-        y = data[:, col_slice_y]
+        X = np.array(data[:, col_slice_x], dtype=np.float64)
+        y = np.array(data[:, col_slice_y], dtype=np.float64)
         return X, y
     elif isinstance(arff_data, tuple):
         arff_data_X = _split_sparse_columns(arff_data, col_slice_x)
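The hunk above stops forcing float64 on the whole row stream (np.fromiter) and instead builds an intermediate array, casting only the returned column slices. The point is that rows may still carry values that cannot be parsed as float64 for columns that are later sliced away. A standalone sketch of the idea with invented row data, not the real ARFF payload:

import itertools

import numpy as np

# Toy rows as an ARFF parser might yield them; the last column is a string
# attribute, so np.fromiter(..., dtype='float64') over the flattened rows
# would fail.
arff_data = [[5.1, 3.5, 1.4, 0.2, 'foo'],
             [4.9, 3.0, 1.4, 0.2, 'bar']]
shape = (2, 5)
col_slice_x, col_slice_y = [0, 1, 2], [3]   # the string column 4 is excluded

# Build an intermediate array without forcing a dtype ...
data = np.array(list(itertools.chain.from_iterable(arff_data)))
data = data.reshape(*shape)

# ... then cast only the slices that are actually returned.
X = np.array(data[:, col_slice_x], dtype=np.float64)
y = np.array(data[:, col_slice_y], dtype=np.float64)
print(X.dtype, X.shape, y.shape)  # float64 (2, 3) (2, 1)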
@@ -287,7 +286,7 @@ def _get_data_info_by_name(name, version, data_home):
     Returns
     -------
     first_dataset : json
-        json representation of the first dataset object that adhired to the
+        json representation of the first dataset object that adhered to the
         search criteria

    """
@@ -436,7 +435,8 @@ def _valid_data_column_names(features_list, target_columns):


 def fetch_openml(name=None, version='active', data_id=None, data_home=None,
-                 target_column='default-target', cache=True, return_X_y=False):
+                 ignore_strings=False, target_column='default-target',
+                 cache=True, return_X_y=False):
     """Fetch dataset from openml by name or dataset id.

     Datasets are uniquely identified by either an integer ID or by a
@@ -450,7 +450,7 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
     .. note:: EXPERIMENTAL

         The API is experimental in version 0.20 (particularly the return value
-        structure), and might have small backward-incompatble changes in
+        structure), and might have small backward-incompatible changes in
         future releases.

     Parameters
@@ -475,6 +475,9 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
         Specify another download and cache folder for the data sets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

+    ignore_strings : boolean, default=False
+        Whether to ignore string attributes when loading a dataset.
+
     target_column : string, list or None, default 'default-target'
         Specify the column name in the data to use as target. If
         'default-target', the standard target column a stored on the server
@@ -573,11 +576,22 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
     # download data features, meta-info about column types
     features_list = _get_data_features(data_id, data_home)

+    if ignore_strings:
+        string_features = [f for f in features_list
+                           if f['data_type'] == 'string']
+        if string_features:
+            string_feature_names = [f['name'] for f in string_features]
+            features_list = [f for f in features_list if f['name'] not in
+                             string_feature_names]
+
     for feature in features_list:
         if 'true' in (feature['is_ignore'], feature['is_row_identifier']):
             continue
-        if feature['data_type'] == 'string':
-            raise ValueError('STRING attributes are not yet supported')
+        if feature['data_type'] == 'string' and not ignore_strings:
+            raise ValueError('STRING attributes are not yet supported. '
+                             'If you would like to return the data '
+                             'without STRING attributes, use '
+                             'ignore_strings=True.')

     if target_column == "default-target":
         # determines the default target based on the data feature results
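To make the filtering above concrete, here is the same pair of list comprehensions run over a hand-written features_list. The field names mirror the OpenML feature metadata used in this function; the values are invented.

# Invented metadata in the shape returned by _get_data_features.
features_list = [
    {'name': 'age',  'data_type': 'numeric', 'is_ignore': 'false',
     'is_row_identifier': 'false'},
    {'name': 'name', 'data_type': 'string',  'is_ignore': 'false',
     'is_row_identifier': 'false'},
    {'name': 'fare', 'data_type': 'numeric', 'is_ignore': 'false',
     'is_row_identifier': 'false'},
]

ignore_strings = True
if ignore_strings:
    string_features = [f for f in features_list
                       if f['data_type'] == 'string']
    if string_features:
        string_feature_names = [f['name'] for f in string_features]
        features_list = [f for f in features_list
                         if f['name'] not in string_feature_names]

print([f['name'] for f in features_list])  # ['age', 'fare']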
4 binary files not shown.

sklearn/datasets/tests/test_openml.py

Lines changed: 58 additions & 14 deletions
@@ -66,7 +66,7 @@ def decode_column(data_bunch, col_idx):


 def _fetch_dataset_from_openml(data_id, data_name, data_version,
-                               target_column,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                expected_data_dtype, expected_target_dtype,
@@ -76,17 +76,18 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
     # result. Note that this function can be mocked (by invoking
     # _monkey_patch_webbased_functions before invoking this function)
     data_by_name_id = fetch_openml(name=data_name, version=data_version,
-                                   cache=False)
+                                   ignore_strings=ignore_strings, cache=False)
     assert int(data_by_name_id.details['id']) == data_id

     # Please note that cache=False is crucial, as the monkey patched files are
     # not consistent with reality
-    fetch_openml(name=data_name, cache=False)
+    fetch_openml(name=data_name, ignore_strings=ignore_strings, cache=False)
     # without specifying the version, there is no guarantee that the data id
     # will be the same

     # fetch with dataset id
     data_by_id = fetch_openml(data_id=data_id, cache=False,
+                              ignore_strings=ignore_strings,
                               target_column=target_column)
     assert data_by_id.details['name'] == data_name
     assert data_by_id.data.shape == (expected_observations, expected_features)
@@ -112,7 +113,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,

     if compare_default_target:
         # check whether the data by id and data by id target are equal
-        data_by_id_default = fetch_openml(data_id=data_id, cache=False)
+        data_by_id_default = fetch_openml(data_id=data_id,
+                                          ignore_strings=ignore_strings,
+                                          cache=False)
         if data_by_id.data.dtype == np.float64:
             np.testing.assert_allclose(data_by_id.data,
                                        data_by_id_default.data)
@@ -133,8 +136,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
                 expected_missing)

     # test return_X_y option
-    fetch_func = partial(fetch_openml, data_id=data_id, cache=False,
-                         target_column=target_column)
+    fetch_func = partial(fetch_openml, data_id=data_id,
+                         ignore_strings=ignore_strings,
+                         cache=False, target_column=target_column)
     check_return_X_y(data_by_id, fetch_func)
     return data_by_id

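The return_X_y check above works by handing the shared helper a functools.partial of the exact fetch call, so it can be re-issued with return_X_y=True. A simplified stand-in for that helper, shown only to illustrate the pattern (check_return_X_y's real implementation lives in the shared test utilities and may differ):

from functools import partial

def check_return_X_y(bunch, fetch_func_partial):
    # Re-issue the same fetch with return_X_y=True and compare shapes
    # against the Bunch result.
    X, y = fetch_func_partial(return_X_y=True)
    assert X.shape == bunch.data.shape
    assert y.shape == bunch.target.shape

# Hypothetical usage mirroring the helper above:
# fetch_func = partial(fetch_openml, data_id=data_id,
#                      ignore_strings=ignore_strings,
#                      cache=False, target_column=target_column)
# check_return_X_y(data_by_id, fetch_func)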

@@ -261,6 +265,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
     data_id = 61
     data_name = 'iris'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     expected_observations = 150
     expected_features = 4
@@ -275,6 +280,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
         _fetch_dataset_from_openml,
         **{'data_id': data_id, 'data_name': data_name,
            'data_version': data_version,
+           'ignore_strings': ignore_strings,
            'target_column': target_column,
            'expected_observations': expected_observations,
            'expected_features': expected_features,
@@ -298,13 +304,15 @@ def test_fetch_openml_iris_multitarget(monkeypatch, gzip_response):
     data_id = 61
     data_name = 'iris'
     data_version = 1
+    ignore_strings = False
     target_column = ['sepallength', 'sepalwidth']
     expected_observations = 150
     expected_features = 3
     expected_missing = 0

     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, np.float64, expect_sparse=False,
@@ -317,13 +325,15 @@ def test_fetch_openml_anneal(monkeypatch, gzip_response):
     data_id = 2
     data_name = 'anneal'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     # Not all original instances included for space reasons
     expected_observations = 11
     expected_features = 38
     expected_missing = 267
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, object, expect_sparse=False,
@@ -342,13 +352,15 @@ def test_fetch_openml_anneal_multitarget(monkeypatch, gzip_response):
     data_id = 2
     data_name = 'anneal'
     data_version = 1
+    ignore_strings = False
     target_column = ['class', 'product-type', 'shape']
     # Not all original instances included for space reasons
     expected_observations = 11
     expected_features = 36
     expected_missing = 267
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, object, expect_sparse=False,
@@ -361,12 +373,14 @@ def test_fetch_openml_cpu(monkeypatch, gzip_response):
     data_id = 561
     data_name = 'cpu'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     expected_observations = 209
     expected_features = 7
     expected_missing = 0
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, np.float64, expect_sparse=False,
@@ -388,6 +402,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
     data_id = 292
     data_name = 'Australian'
     data_version = 1
+    ignore_strings = False
     target_column = 'Y'
     # Not all original instances included for space reasons
     expected_observations = 85
@@ -400,6 +415,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
         _fetch_dataset_from_openml,
         **{'data_id': data_id, 'data_name': data_name,
            'data_version': data_version,
+           'ignore_strings': ignore_strings,
            'target_column': target_column,
            'expected_observations': expected_observations,
            'expected_features': expected_features,
@@ -417,13 +433,15 @@ def test_fetch_openml_adultcensus(monkeypatch, gzip_response):
     data_id = 1119
     data_name = 'adult-census'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     # Not all original instances included for space reasons
     expected_observations = 10
     expected_features = 14
     expected_missing = 0
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                np.float64, object, expect_sparse=False,
@@ -439,13 +457,15 @@ def test_fetch_openml_miceprotein(monkeypatch, gzip_response):
     data_id = 40966
     data_name = 'MiceProtein'
     data_version = 4
+    ignore_strings = False
     target_column = 'class'
     # Not all original instances included for space reasons
    expected_observations = 7
     expected_features = 77
     expected_missing = 7
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                np.float64, object, expect_sparse=False,
@@ -458,14 +478,16 @@ def test_fetch_openml_emotions(monkeypatch, gzip_response):
     data_id = 40589
     data_name = 'emotions'
     data_version = 3
+    ignore_strings = False
     target_column = ['amazed.suprised', 'happy.pleased', 'relaxing.calm',
                      'quiet.still', 'sad.lonely', 'angry.aggresive']
     expected_observations = 13
     expected_features = 72
     expected_missing = 0
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)

-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                np.float64, object, expect_sparse=False,
@@ -478,6 +500,27 @@ def test_decode_emotions(monkeypatch):
     _test_features_list(data_id)


+@pytest.mark.parametrize('gzip_response', [True, False])
+def test_fetch_titanic(monkeypatch, gzip_response):
+    # dataset with string attributes; check that they can be ignored
+    data_id = 40945
+    data_name = 'Titanic'
+    data_version = 1
+    ignore_strings = True
+    target_column = 'survived'
+    # Not all original features included because five are strings
+    expected_observations = 1309
+    expected_features = 8
+    expected_missing = 1454
+    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
+                               expected_observations, expected_features,
+                               expected_missing,
+                               np.float64, object, expect_sparse=False,
+                               compare_default_target=True)
+
+
 @pytest.mark.parametrize('gzip_response', [True, False])
 def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
     data_id = 61
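A quick sanity check of the expected_features = 8 figure in the new test, assuming the standard Titanic column inventory on OpenML; the five string columns listed here are an assumption based on the test comment, not read from the diff.

# Hypothetical column inventory for OpenML's Titanic (data_id=40945).
all_columns = ['pclass', 'survived', 'name', 'sex', 'age', 'sibsp', 'parch',
               'ticket', 'fare', 'cabin', 'embarked', 'boat', 'body',
               'home.dest']
string_columns = ['name', 'ticket', 'cabin', 'boat', 'home.dest']  # assumed
target = 'survived'

remaining = [c for c in all_columns
             if c not in string_columns and c != target]
print(len(remaining))  # 8, matching expected_features in test_fetch_titanic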
@@ -667,7 +710,8 @@ def test_string_attribute(monkeypatch, gzip_response):
     # single column test
     assert_raise_message(ValueError,
                          'STRING attributes are not yet supported',
-                         fetch_openml, data_id=data_id, cache=False)
+                         fetch_openml, data_id=data_id, ignore_strings=False,
+                         cache=False)


 @pytest.mark.parametrize('gzip_response', [True, False])
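The updated assertion only pins down the ignore_strings=False path. A complementary test one could add alongside it (a sketch that reuses the module's monkey-patching helpers; the choice of data_id is an assumption, the Titanic mock added in this commit being a natural candidate):

@pytest.mark.parametrize('gzip_response', [True, False])
def test_string_attribute_ignored(monkeypatch, gzip_response):
    # Assumed to live in test_openml.py, which already imports pytest,
    # fetch_openml and _monkey_patch_webbased_functions.
    data_id = 40945  # assumption: the mocked Titanic dataset
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    # With ignore_strings=True the STRING columns are dropped instead of
    # raising ValueError, so the fetch should simply succeed.
    bunch = fetch_openml(data_id=data_id, ignore_strings=True, cache=False)
    assert bunch.data.shape[0] > 0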
