return bytes instead of stream, read once · scikit-learn/scikit-learn@831d78b · GitHub
[go: up one dir, main page]

Skip to content

Commit 831d78b

Browse files
return bytes instead of stream, read once
1 parent f4ca32b commit 831d78b

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

sklearn/datasets/openml.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import gzip
22
import json
33
import os
4-
from io import BytesIO
54
import hashlib
6-
import shutil
75
from os.path import join
86
from warnings import warn
97
from contextlib import closing
@@ -63,7 +61,7 @@ def wrapper():
6361
return decorator
6462

6563

66-
def _open_openml_url(openml_path, data_home, expected_md5_checksum=None):
64+
def _openml_url_bytes(openml_path, data_home, expected_md5_checksum=None):
6765
"""
6866
Returns a resource from OpenML.org. Caches it to data_home if required.
6967
@@ -79,49 +77,47 @@ def _open_openml_url(openml_path, data_home, expected_md5_checksum=None):
7977
8078
Returns
8179
-------
82-
result : stream
83-
A stream to the OpenML resource
80+
result : bytes
81+
Byte content of resource
8482
"""
8583
def is_gzip(_fsrc):
8684
return _fsrc.info().get('Content-Encoding', '') == 'gzip'
8785

8886
req = Request(_OPENML_PREFIX + openml_path)
8987
req.add_header('Accept-encoding', 'gzip')
9088

91-
def _md5_validated_stream(input_stream, md5_checksum):
89+
def _md5_validated_bytes(bytes_content, md5_checksum):
9290
"""
9391
Consume binary stream to validate checksum,
9492
return a new stream with same content
9593
9694
Parameters
9795
----------
98-
input_stream : io.BufferedIOBase
99-
Input stream with a read() method to get content in bytes
96+
bytes_content : bytes
10097
10198
md5_checksum: str
102-
Expected md5 checksum
99+
Expected md5 checksum of bytes
103100
104101
Returns
105102
-------
106-
BytesIO stream with the same content as input_stream for consumption
103+
bytes
107104
"""
108-
with closing(input_stream):
109-
bytes_content = input_stream.read()
110-
actual_md5_checksum = hashlib.md5(bytes_content).hexdigest()
111-
if md5_checksum != actual_md5_checksum:
112-
raise ValueError("md5checksum: {} does not match expected: "
113-
"{}".format(actual_md5_checksum,
114-
md5_checksum))
115-
return BytesIO(bytes_content)
105+
actual_md5_checksum = hashlib.md5(bytes_content).hexdigest()
106+
if md5_checksum != actual_md5_checksum:
107+
raise ValueError("md5checksum: {} does not match expected: "
108+
"{}".format(actual_md5_checksum,
109+
md5_checksum))
110+
return bytes_content
116111

117112
if data_home is None:
118113
fsrc = urlopen(req)
119114
if is_gzip(fsrc):
120115
fsrc = gzip.GzipFile(fileobj=fsrc, mode='rb')
116+
bytes_content = fsrc.read()
121117
if expected_md5_checksum:
122118
# validating checksum reads and consumes the stream
123-
return _md5_validated_stream(fsrc, expected_md5_checksum)
124-
return fsrc
119+
return _md5_validated_bytes(bytes_content, expected_md5_checksum)
120+
return bytes_content
125121

126122
local_path = _get_local_path(openml_path, data_home)
127123
if not os.path.exists(local_path):
@@ -135,18 +131,23 @@ def _md5_validated_stream(input_stream, md5_checksum):
135131
with closing(urlopen(req)) as fsrc:
136132
if is_gzip(fsrc): # unzip it for checksum validation
137133
fsrc = gzip.GzipFile(fileobj=fsrc, mode='rb')
134+
bytes_content = fsrc.read()
138135
if expected_md5_checksum:
139-
fsrc = _md5_validated_stream(fsrc, expected_md5_checksum)
136+
bytes_content = _md5_validated_bytes(bytes_content,
137+
expected_md5_checksum)
140138
with gzip.GzipFile(local_path, 'wb') as fdst:
141-
shutil.copyfileobj(fsrc, fdst)
139+
fdst.write(bytes_content)
142140
except Exception:
143141
if os.path.exists(local_path):
144142
os.unlink(local_path)
145143
raise
144+
else:
145+
with gzip.GzipFile(local_path, "rb") as gzip_file:
146+
bytes_content = gzip_file.read()
146147

147148
# XXX: First time, decompression will not be necessary (by using fsrc), but
148149
# it will happen nonetheless
149-
return gzip.GzipFile(local_path, 'rb')
150+
return bytes_content
150151

151152

152153
def _get_json_content_from_openml_api(url, error_message, raise_if_error,
@@ -183,8 +184,7 @@ def _get_json_content_from_openml_api(url, error_message, raise_if_error,
183184

184185
@_retry_with_clean_cache(url, data_home)
185186
def _load_json():
186-
with closing(_open_openml_url(url, data_home)) as response:
187-
return json.loads(response.read().decode("utf-8"))
187+
return json.loads(_openml_url_bytes(url, data_home).decode("utf-8"))
188188

189189
try:
190190
return _load_json()
@@ -489,16 +489,16 @@ def _download_data_arff(file_id, sparse, data_home, encode_nominal=True,
489489

490490
@_retry_with_clean_cache(url, data_home)
491491
def _arff_load():
492-
with closing(_open_openml_url(url, data_home, expected_md5_checksum)) \
493-
as response:
494-
if sparse is True:
495-
return_type = _arff.COO
496-
else:
497-
return_type = _arff.DENSE_GEN
498-
499-
arff_file = _arff.loads(response.read().decode('utf-8'),
500-
encode_nominal=encode_nominal,
501-
return_type=return_type)
492+
bytes_content = _openml_url_bytes(url, data_home,
493+
expected_md5_checksum)
494+
if sparse is True:
495+
return_type = _arff.COO
496+
else:
497+
return_type = _arff.DENSE_GEN
498+
499+
arff_file = _arff.loads(bytes_content.decode('utf-8'),
500+
encode_nominal=encode_nominal,
501+
return_type=return_type)
502502
return arff_file
503503

504504
return _arff_load()

sklearn/datasets/tests/test_openml.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from sklearn import config_context
1414
from sklearn.datasets import fetch_openml
15-
from sklearn.datasets.openml import (_open_openml_url,
15+
from sklearn.datasets.openml import (_openml_url_bytes,
1616
_get_data_description_by_id,
1717
_download_data_arff,
1818
_get_local_path,
@@ -922,13 +922,13 @@ def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
922922
openml_path = sklearn.datasets.openml._DATA_FILE.format(data_id)
923923
cache_directory = str(tmpdir.mkdir('scikit_learn_data'))
924924
# first fill the cache
925-
response1 = _open_openml_url(openml_path, cache_directory)
925+
response1 = _openml_url_bytes(openml_path, cache_directory)
926926
# assert file exists
927927
location = _get_local_path(openml_path, cache_directory)
928928
assert os.path.isfile(location)
929929
# redownload, to utilize cache
930-
response2 = _open_openml_url(openml_path, cache_directory)
931-
assert response1.read() == response2.read()
930+
response2 = _openml_url_bytes(openml_path, cache_directory)
931+
assert response1 == response2
932932

933933

934934
@pytest.mark.parametrize('gzip_response', [True, False])
@@ -949,7 +949,7 @@ def _mock_urlopen(request):
949949
monkeypatch.setattr(sklearn.datasets.openml, 'urlopen', _mock_urlopen)
950950

951951
with pytest.raises(ValueError, match="Invalid request"):
952-
_open_openml_url(openml_path, cache_directory)
952+
_openml_url_bytes(openml_path, cache_directory)
953953

954954
assert not os.path.exists(location)
955955

0 commit comments

Comments (0)