FIX cache of OpenML fetcher (#12246) · lithuak/scikit-learn@afa0694 · GitHub

Commit afa0694

janvanrijn authored and jnothman committed
FIX cache of OpenML fetcher (scikit-learn#12246)
1 parent 0bbb7d0 commit afa0694

File tree

3 files changed: +62 -17 lines changed

doc/whats_new/v0.20.rst
sklearn/datasets/openml.py
sklearn/datasets/tests/test_openml.py

doc/whats_new/v0.20.rst

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@ Changelog
 :mod:`sklearn.datasets`
 .......................

+- |Fix| :func:`datasets.fetch_openml` to correctly use the local cache.
+  :issue:`12246` by :user:`Jan N. van Rijn <janvanrijn>`.
+
 - |Fix| Fixed integer overflow in :func:`datasets.make_classification`
   for values of ``n_informative`` parameter larger than 64.
   :issue:`10811` by :user:`Roman Feldbauer <VarIr>`.

sklearn/datasets/openml.py

Lines changed: 23 additions & 11 deletions
@@ -31,6 +31,10 @@
 _DATA_FILE = "data/v1/download/{}"


+def _get_local_path(openml_path, data_home):
+    return os.path.join(data_home, 'openml.org', openml_path + ".gz")
+
+
 def _open_openml_url(openml_path, data_home):
     """
     Returns a resource from OpenML.org. Caches it to data_home if required.
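The `_get_local_path` helper introduced in the hunk above centralises where a cached OpenML resource lives on disk: under an `openml.org` subdirectory of `data_home`, with a `.gz` suffix. A small illustration of what it computes (the input values below are hypothetical examples, not taken from the commit):

import os

def _get_local_path(openml_path, data_home):
    # same logic as the helper added above:
    # <data_home>/openml.org/<openml_path>.gz
    return os.path.join(data_home, 'openml.org', openml_path + ".gz")

# hypothetical example values; output shown for a POSIX filesystem
print(_get_local_path("data/v1/download/61", "/tmp/scikit_learn_data"))
# /tmp/scikit_learn_data/openml.org/data/v1/download/61.gz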
@@ -50,37 +54,45 @@ def _open_openml_url(openml_path, data_home):
     result : stream
         A stream to the OpenML resource
     """
+    def is_gzip(_fsrc):
+        return _fsrc.info().get('Content-Encoding', '') == 'gzip'
+
     req = Request(_OPENML_PREFIX + openml_path)
     req.add_header('Accept-encoding', 'gzip')
-    fsrc = urlopen(req)
-    is_gzip = fsrc.info().get('Content-Encoding', '') == 'gzip'

     if data_home is None:
-        if is_gzip:
+        fsrc = urlopen(req)
+        if is_gzip(fsrc):
             if PY2:
                 fsrc = BytesIO(fsrc.read())
             return gzip.GzipFile(fileobj=fsrc, mode='rb')
         return fsrc

-    local_path = os.path.join(data_home, 'openml.org', openml_path + ".gz")
+    local_path = _get_local_path(openml_path, data_home)
     if not os.path.exists(local_path):
+        fsrc = urlopen(req)
         try:
             os.makedirs(os.path.dirname(local_path))
         except OSError:
             # potentially, the directory has been created already
             pass

         try:
-            with open(local_path, 'wb') as fdst:
-                shutil.copyfileobj(fsrc, fdst)
-            fsrc.close()
+            if is_gzip(fsrc):
+                with open(local_path, 'wb') as fdst:
+                    shutil.copyfileobj(fsrc, fdst)
+                fsrc.close()
+            else:
+                with gzip.GzipFile(local_path, 'wb') as fdst:
+                    shutil.copyfileobj(fsrc, fdst)
+                fsrc.close()
         except Exception:
             os.unlink(local_path)
             raise
-    # XXX: unnecessary decompression on first access
-    if is_gzip:
-        return gzip.GzipFile(local_path, 'rb')
-    return fsrc
+
+    # XXX: First time, decompression will not be necessary (by using fsrc), but
+    # it will happen nonetheless
+    return gzip.GzipFile(local_path, 'rb')


 def _get_json_content_from_openml_api(url, error_message, raise_if_error,
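Taken together, the changes above mean `urlopen` is only reached when the local file is missing, and the cached copy is always stored gzip-compressed: a response that arrives gzip-encoded is written verbatim, while a plain response is compressed on write. Every read, including the first, then goes through `gzip.GzipFile` on the cached file. Below is a minimal sketch of that pattern with hypothetical names (`cache_open`, a generic URL layout); it is not the scikit-learn implementation and it targets Python 3 only, whereas the commit also handles PY2:

import gzip
import os
import shutil
from urllib.request import Request, urlopen


def cache_open(url, cache_dir):
    """Return a read stream for `url`, downloading to a gzipped cache file once."""
    local_path = os.path.join(cache_dir, 'cache', url.split('/')[-1] + '.gz')
    if not os.path.exists(local_path):
        req = Request(url)
        req.add_header('Accept-encoding', 'gzip')
        fsrc = urlopen(req)  # network access happens only on a cache miss
        try:
            os.makedirs(os.path.dirname(local_path))
        except OSError:
            pass  # directory may already exist
        try:
            if fsrc.info().get('Content-Encoding', '') == 'gzip':
                # already gzip-compressed on the wire: store the bytes verbatim
                with open(local_path, 'wb') as fdst:
                    shutil.copyfileobj(fsrc, fdst)
            else:
                # plain response: compress while writing, so the cache layout
                # is always gzip and later reads never mis-detect the format
                with gzip.GzipFile(local_path, 'wb') as fdst:
                    shutil.copyfileobj(fsrc, fdst)
            fsrc.close()
        except Exception:
            os.unlink(local_path)  # do not leave a partial cache file behind
            raise
    # every read, including the first, goes through the gzipped cache file
    return gzip.GzipFile(local_path, 'rb')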

sklearn/datasets/tests/test_openml.py

Lines changed: 36 additions & 6 deletions
@@ -12,7 +12,8 @@
 from sklearn.datasets import fetch_openml
 from sklearn.datasets.openml import (_open_openml_url,
                                      _get_data_description_by_id,
-                                     _download_data_arff)
+                                     _download_data_arff,
+                                     _get_local_path)
 from sklearn.utils.testing import (assert_warns_message,
                                    assert_raise_message)
 from sklearn.externals.six import string_types
@@ -77,6 +78,8 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
                                    cache=False)
     assert int(data_by_name_id.details['id']) == data_id

+    # Please note that cache=False is crucial, as the monkey patched files are
+    # not consistent with reality
     fetch_openml(name=data_name, cache=False)
     # without specifying the version, there is no guarantee that the data id
     # will be the same
@@ -138,6 +141,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
 def _monkey_patch_webbased_functions(context,
                                      data_id,
                                      gzip_response):
+    # monkey patches the urlopen function. Important note: Do NOT use this
+    # in combination with a regular cache directory, as the files that are
+    # stored as cache should not be mixed up with real openml datasets
     url_prefix_data_description = "https://openml.org/api/v1/json/data/"
     url_prefix_data_features = "https://openml.org/api/v1/json/data/features/"
     url_prefix_download_data = "https://openml.org/data/v1/"
@@ -453,23 +459,47 @@ def test_decode_emotions(monkeypatch):


 @pytest.mark.parametrize('gzip_response', [True, False])
-def test_open_openml_url_cache(monkeypatch, gzip_response):
+def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
     data_id = 61

     _monkey_patch_webbased_functions(
         monkeypatch, data_id, gzip_response)
     openml_path = sklearn.datasets.openml._DATA_FILE.format(data_id)
-    test_directory = os.path.join(os.path.expanduser('~'), 'scikit_learn_data')
+    cache_directory = str(tmpdir.mkdir('scikit_learn_data'))
     # first fill the cache
-    response1 = _open_openml_url(openml_path, test_directory)
+    response1 = _open_openml_url(openml_path, cache_directory)
     # assert file exists
-    location = os.path.join(test_directory, 'openml.org', openml_path + '.gz')
+    location = _get_local_path(openml_path, cache_directory)
     assert os.path.isfile(location)
     # redownload, to utilize cache
-    response2 = _open_openml_url(openml_path, test_directory)
+    response2 = _open_openml_url(openml_path, cache_directory)
     assert response1.read() == response2.read()


+@pytest.mark.parametrize('gzip_response', [True, False])
+def test_fetch_openml_cache(monkeypatch, gzip_response, tmpdir):
+    def _mock_urlopen_raise(request):
+        raise ValueError('This mechanism intends to test correct cache'
+                         'handling. As such, urlopen should never be '
+                         'accessed. URL: %s' % request.get_full_url())
+    data_id = 2
+    cache_directory = str(tmpdir.mkdir('scikit_learn_data'))
+    _monkey_patch_webbased_functions(
+        monkeypatch, data_id, gzip_response)
+    X_fetched, y_fetched = fetch_openml(data_id=data_id, cache=True,
+                                        data_home=cache_directory,
+                                        return_X_y=True)
+
+    monkeypatch.setattr(sklearn.datasets.openml, 'urlopen',
+                        _mock_urlopen_raise)
+
+    X_cached, y_cached = fetch_openml(data_id=data_id, cache=True,
+                                      data_home=cache_directory,
+                                      return_X_y=True)
+    np.testing.assert_array_equal(X_fetched, X_cached)
+    np.testing.assert_array_equal(y_fetched, y_cached)
+
+
 @pytest.mark.parametrize('gzip_response', [True, False])
 def test_fetch_openml_notarget(monkeypatch, gzip_response):
     data_id = 61
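The new `test_fetch_openml_cache` test pins the fix down by fetching once to prime the cache and then monkeypatching `urlopen` to raise, so any further network access fails the test immediately. The same strategy works for any fetcher with an on-disk cache; here is a self-contained pytest sketch with a hypothetical `TinyFetcher` standing in for `fetch_openml` (it is not scikit-learn code):

import json
import os


class TinyFetcher(object):
    """Caches a 'downloaded' payload on disk; downloads at most once per resource."""

    def _download(self, resource):
        # stands in for urlopen(); the test below forbids a second call
        return json.dumps({'resource': resource, 'value': 42})

    def fetch(self, resource, cache_dir):
        path = os.path.join(cache_dir, resource + '.json')
        if not os.path.exists(path):          # cache miss: download and store
            with open(path, 'w') as f:
                f.write(self._download(resource))
        with open(path) as f:                 # always read back from the cache
            return json.load(f)


def test_fetch_uses_cache(monkeypatch, tmpdir):
    cache_dir = str(tmpdir.mkdir('cache'))
    fetcher = TinyFetcher()

    first = fetcher.fetch('iris', cache_dir)   # primes the cache

    def _raise_on_download(resource):
        raise AssertionError('cache miss: no download may happen after '
                             'the first fetch')

    # Mirror of the _mock_urlopen_raise trick above: any further download
    # attempt fails the test immediately.
    monkeypatch.setattr(fetcher, '_download', _raise_on_download)

    second = fetcher.fetch('iris', cache_dir)  # must be served from disk
    assert first == second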
