7
7
8
8
try :
9
9
# Python 3+
10
- from urllib .request import urlopen
10
+ from urllib .request import urlopen , Request
11
11
except ImportError :
12
12
# Python 2
13
- from urllib2 import urlopen
13
+ from urllib2 import urlopen , Request
14
14
15
15
16
16
import numpy as np
17
17
import scipy .sparse
18
18
19
19
from sklearn .externals import _arff
20
20
from .base import get_data_home
21
- from ..externals .six import string_types , PY2
21
+ from ..externals .six import string_types , PY2 , BytesIO
22
22
from ..externals .six .moves .urllib .error import HTTPError
23
23
from ..utils import Bunch
24
24
@@ -50,8 +50,18 @@ def _open_openml_url(openml_path, data_home):
50
50
result : stream
51
51
A stream to the OpenML resource
52
52
"""
53
+ req = Request (_OPENML_PREFIX + openml_path )
54
+ req .add_header ('Accept-encoding' , 'gzip' )
55
+ fsrc = urlopen (req )
56
+ is_gzip = fsrc .info ().get ('Content-Encoding' , '' ) == 'gzip'
57
+
53
58
if data_home is None :
54
- return urlopen (_OPENML_PREFIX + openml_path )
59
+ if is_gzip :
60
+ if PY2 :
61
+ fsrc = BytesIO (fsrc .read ())
62
+ return gzip .GzipFile (fileobj = fsrc , mode = 'rb' )
63
+ return fsrc
64
+
55
65
local_path = os .path .join (data_home , 'openml.org' , openml_path + ".gz" )
56
66
if not os .path .exists (local_path ):
57
67
try :
@@ -61,15 +71,16 @@ def _open_openml_url(openml_path, data_home):
61
71
pass
62
72
63
73
try :
64
- with gzip .GzipFile (local_path , 'wb' ) as fdst :
65
- fsrc = urlopen (_OPENML_PREFIX + openml_path )
74 + with open (local_path , 'wb' ) as fdst :
66
75
shutil .copyfileobj (fsrc , fdst )
67
76
fsrc .close ()
68
77
except Exception :
69
78
os .unlink (local_path )
70
79
raise
71
80
# XXX: unnecessary decompression on first access
72
- return gzip .GzipFile (local_path , 'rb' )
81
+ if is_gzip :
82
+ return gzip .GzipFile (local_path , 'rb' )
83
+ return fsrc
73
84
74
85
75
86
def _get_json_content_from_openml_api (url , error_message , raise_if_error ,
@@ -308,7 +319,7 @@ def _download_data_arff(file_id, sparse, data_home, encode_nominal=True):
308
319
return_type = _arff .DENSE
309
320
310
321
if PY2 :
311
- arff_file = _arff .load (response , encode_nominal = encode_nominal ,
322
+ arff_file = _arff .load (response . read () , encode_nominal = encode_nominal ,
312
323
return_type = return_type , )
313
324
else :
314
325
arff_file = _arff .loads (response .read ().decode ('utf-8' ),
0 commit comments