|
12 | 12 | from sklearn.datasets import fetch_openml
|
13 | 13 | from sklearn.datasets.openml import (_open_openml_url,
|
14 | 14 | _get_data_description_by_id,
|
15 |
| - _download_data_arff) |
| 15 | + _download_data_arff, |
| 16 | + _get_local_path) |
16 | 17 | from sklearn.utils.testing import (assert_warns_message,
|
17 | 18 | assert_raise_message)
|
18 | 19 | from sklearn.externals.six import string_types
|
@@ -77,6 +78,8 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
|
77 | 78 | cache=False)
|
78 | 79 | assert int(data_by_name_id.details['id']) == data_id
|
79 | 80 |
|
| 81 | + # Note that cache=False is crucial here, as the monkey-patched files are |
| 82 | + # not consistent with the real OpenML data |
80 | 83 | fetch_openml(name=data_name, cache=False)
|
81 | 84 | # without specifying the version, there is no guarantee that the data id
|
82 | 85 | # will be the same
|
@@ -138,6 +141,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
|
138 | 141 | def _monkey_patch_webbased_functions(context,
|
139 | 142 | data_id,
|
140 | 143 | gzip_response):
|
| 144 | + # Monkey-patches the urlopen function. Important note: do NOT use this |
| 145 | + # in combination with a regular cache directory, as the files that are |
| 146 | + # stored as cache must not be mixed up with real OpenML datasets |
141 | 147 | url_prefix_data_description = "https://openml.org/api/v1/json/data/"
|
142 | 148 | url_prefix_data_features = "https://openml.org/api/v1/json/data/features/"
|
143 | 149 | url_prefix_download_data = "https://openml.org/data/v1/"
|
@@ -453,23 +459,47 @@ def test_decode_emotions(monkeypatch):
|
453 | 459 |
|
454 | 460 |
|
455 | 461 | @pytest.mark.parametrize('gzip_response', [True, False])
|
456 |
| -def test_open_openml_url_cache(monkeypatch, gzip_response): |
| 462 | +def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir): |
457 | 463 | data_id = 61
|
458 | 464 |
|
459 | 465 | _monkey_patch_webbased_functions(
|
460 | 466 | monkeypatch, data_id, gzip_response)
|
461 | 467 | openml_path = sklearn.datasets.openml._DATA_FILE.format(data_id)
|
462 |
| - test_directory = os.path.join(os.path.expanduser('~'), 'scikit_learn_data') |
| 468 | + cache_directory = str(tmpdir.mkdir('scikit_learn_data')) |
463 | 469 | # first fill the cache
|
464 |
| - response1 = _open_openml_url(openml_path, test_directory) |
| 470 | + response1 = _open_openml_url(openml_path, cache_directory) |
465 | 471 | # assert file exists
|
466 |
| - location = os.path.join(test_directory, 'openml.org', openml_path + '.gz') |
| 472 | + location = _get_local_path(openml_path, cache_directory) |
467 | 473 | assert os.path.isfile(location)
|
468 | 474 | # redownload, to utilize cache
|
469 |
| - response2 = _open_openml_url(openml_path, test_directory) |
| 475 | + response2 = _open_openml_url(openml_path, cache_directory) |
470 | 476 | assert response1.read() == response2.read()
|
471 | 477 |
|
472 | 478 |
|
| 479 | +@pytest.mark.parametrize('gzip_response', [True, False]) |
| 480 | +def test_fetch_openml_cache(monkeypatch, gzip_response, tmpdir): |
| 481 | + def _mock_urlopen_raise(request): |
| 482 | + raise ValueError('This mechanism intends to test correct cache ' |
| 483 | + 'handling. As such, urlopen should never be ' |
| 484 | + 'accessed. URL: %s' % request.get_full_url()) |
| 485 | + data_id = 2 |
| 486 | + cache_directory = str(tmpdir.mkdir('scikit_learn_data')) |
| 487 | + _monkey_patch_webbased_functions( |
| 488 | + monkeypatch, data_id, gzip_response) |
| 489 | + X_fetched, y_fetched = fetch_openml(data_id=data_id, cache=True, |
| 490 | + data_home=cache_directory, |
| 491 | + return_X_y=True) |
| 492 | + |
| 493 | + monkeypatch.setattr(sklearn.datasets.openml, 'urlopen', |
| 494 | + _mock_urlopen_raise) |
| 495 | + |
| 496 | + X_cached, y_cached = fetch_openml(data_id=data_id, cache=True, |
| 497 | + data_home=cache_directory, |
| 498 | + return_X_y=True) |
| 499 | + np.testing.assert_array_equal(X_fetched, X_cached) |
| 500 | + np.testing.assert_array_equal(y_fetched, y_cached) |
| 501 | + |
| 502 | + |
473 | 503 | @pytest.mark.parametrize('gzip_response', [True, False])
|
474 | 504 | def test_fetch_openml_notarget(monkeypatch, gzip_response):
|
475 | 505 | data_id = 61
|
|
0 commit comments