10000 Added scipy.misc.pilutil functions. · scikit-learn/scikit-learn@93b49c4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 93b49c4

Browse files
author
Jonathan Tammo Siebert
committed
Added scipy.misc.pilutil functions.
1 parent 90a0584 commit 93b49c4

File tree

4 files changed

+295
-35
lines changed

4 files changed

+295
-35
lines changed

sklearn/datasets/base.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from os.path import dirname, exists, expanduser, isdir, join, splitext
1818
import hashlib
1919

20+
from ..externals._pilutil import _imread
2021
from ..utils import Bunch
2122
from ..utils import check_random_state
2223

@@ -766,24 +767,14 @@ def load_sample_images():
766767
>>> first_img_data.dtype #doctest: +SKIP
767768
dtype('uint8')
768769
"""
769-
# Try to import imread from scipy. We do this lazily here to prevent
770-
# this module from depending on PIL.
771-
try:
772-
try:
773-
from scipy.misc import imread
774-
except ImportError:
775-
from scipy.misc.pilutil import imread
776-
except ImportError:
777-
raise ImportError("The Python Imaging Library (PIL) "
778-
"is required to load data from jpeg files")
779770
module_path = join(dirname(__file__), "images")
780771
with open(join(module_path, 'README.txt')) as f:
781772
descr = f.read()
782773
filenames = [join(module_path, filename)
783774
for filename in os.listdir(module_path)
784775
if filename.endswith(".jpg")]
785776
# Load image data for each image in the source folder.
786-
images = [imread(filename) for filename in filenames]
777+
images = [_imread(filename) for filename in filenames]
787778

788779
return Bunch(images=images,
789780
filenames=filenames,

sklearn/datasets/lfw.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131

3232
from .base import get_data_home, _fetch_remote, RemoteFileMetadata
3333
from ..utils import Bunch
34+
from ..externals._pilutil import _imread, _imresize
3435
from ..externals.joblib import Memory
35-
3636
from ..externals.six import b
3737

3838
logger = logging.getLogger(__name__)
@@ -137,18 +137,6 @@ def check_fetch_lfw(data_home=None, funneled=True, download_if_missing=True):
137137
def _load_imgs(file_paths, slice_, color, resize):
138138
"""Internally used to load images"""
139139

140-
# Try to import imread and imresize from PIL. We do this here to prevent
141-
# the whole sklearn.datasets module from depending on PIL.
142-
try:
143-
try:
144-
from scipy.misc import imread
145-
except ImportError:
146-
from scipy.misc.pilutil import imread
147-
from scipy.misc import imresize
148-
except ImportError:
149-
raise ImportError("The Python Imaging Library (PIL)"
150-
" is required to load data from jpeg files")
151-
152140
# compute the portion of the images to load to respect the slice_ parameter
153141
# given by the caller
154142
default_slice = (slice(0, 250), slice(0, 250))
@@ -181,7 +169,7 @@ def _load_imgs(file_paths, slice_, color, resize):
181169

182170
# Checks if jpeg reading worked. Refer to issue #3594 for more
183171
# details.
184-
img = imread(file_path)
172+
img = _imread(file_path)
185173
if img.ndim is 0:
186174
raise RuntimeError("Failed to read the image file %s, "
187175
"Please make sure that libjpeg is installed"
@@ -190,7 +178,7 @@ def _load_imgs(file_paths, slice_, color, resize):
190178
face = np.asarray(img[slice_], dtype=np.float32)
191179
face /= 255.0 # scale uint8 coded colors to the [0.0, 1.0] floats
192180
if resize is not None:
193-
face = imresize(face, resize)
181+
face = _imresize(face, resize)
194182
if not color:
195183
# average the color channels to compute a gray levels
196184
# representation

sklearn/datasets/tests/test_base.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sklearn.datasets.base import Bunch
2222

2323
from sklearn.externals.six import b, u
24+
from sklearn.externals._pilutil import _have_image
2425

2526
from sklearn.utils.testing import assert_false
2627
from sklearn.utils.testing import assert_true
@@ -161,15 +162,7 @@ def test_load_sample_image():
161162

162163

163164
def test_load_missing_sample_image_error():
164-
have_PIL = True
165-
try:
166-
try:
167-
from scipy.misc import imread
168-
except ImportError:
169-
from scipy.misc.pilutil import imread # noqa
170-
except ImportError:
171-
have_PIL = False
172-
if have_PIL:
165+
if _have_image:
173166
assert_raises(AttributeError, load_sample_image,
174167
'blop.jpg')
175168
else:

sklearn/externals/_pilutil.py

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
"""
2+
Utility functions wrapping PIL functions
3+
4+
This is a local version of utility functions from scipy that are wrapping PIL
5+
functionality. These functions are deprecated in scipy 1.0.0 and will be
6+
removed in scipy 1.2.0. Therefore, the functionality used in sklearn is
7+
copied here.
8+
9+
Copyright (c) 2001, 2002 Enthought, Inc.
10+
All rights reserved.
11+
12+
Copyright (c) 2003-2017 SciPy Developers.
13+
All rights reserved.
14+
15+
Redistribution and use in source and binary forms, with or without
16+
modification, are permitted provided that the following conditions are met:
17+
18+
a. Redistributions of source code must retain the above copyright notice,
19+
this list of conditions and the following disclaimer.
20+
b. Redistributions in binary form must reproduce the above copyright
21+
notice, this list of conditions and the following disclaimer in the
22+
documentation and/or other materials provided with the distribution.
23+
c. Neither the name of Enthought nor the names of the SciPy Developers
24+
may be used to endorse or promote products derived from this software
25+
without specific prior written permission.
26+
27+
28+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS
32+
BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
33+
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37+
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
38+
THE POSSIBILITY OF SUCH DAMAGE.
39+
"""
40+
from __future__ import division, print_function, absolute_import
41+
42+
import numpy as np
43+
44+
_have_image = True
45+
try:
46+
try:
47+
from PIL import Image
48+
except ImportError:
49+
import Image
50+
if not hasattr(Image, 'frombytes'):
51+
Image.frombytes = Image.fromstring
52+
except ImportError:
53+
_have_image = False
54+
55+
56+
def _bytescale(data):
57+
"""
58+
Byte scales an array (image).
59+
60+
Byte scaling means converting the input image to uint8 dtype and scaling
61+
the range to ``(0, 255)``.
62+
63+
If the input image already has dtype uint8, no scaling is done.
64+
This function is only available if Python Imaging Library (PIL) is
65+
installed.
66+
67+
Parameters
68+
----------
69+
data : ndarray
70+
PIL image data array.
71+
72+
Returns
73+
-------
74+
img_array : uint8 ndarray
75+
The byte-scaled array.
76+
77+
Examples
78+
--------
79+
>>> from sklearn.externals._pilutil import _bytescale
80+
>>> img = np.array([[ 91.06794177, 3.39058326, 84.4221549 ],
81+
... [ 73.88003259, 80.91433048, 4.88878881],
82+
... [ 51.53875334, 34.45808177, 27.5873488 ]])
83+
>>> _bytescale(img)
84+
array([[255, 0, 236],
85+
[205, 225, 4],
86+
[140, 90, 70]], dtype=uint8)
87+
"""
88+
if data.dtype == np.uint8:
89+
return data
90+
91+
cmin = data.min()
92+
cmax = data.max()
93+
94+
cscale = cmax - cmin
95+
if cscale == 0:
96+
cscale = 1
97+
98+
scale = 255. / cscale
99+
bytedata = (data - cmin) * scale
100+
return (bytedata.clip(0, 255) + 0.5).astype(np.uint8)
101+
102+
103+
def _imread(name):
104+
"""
105+
Read an image from a file as an array.
106+
107+
This function is only available if Python Imaging Library (PIL) is
108+
installed.
109+
110+
Parameters
111+
----------
112+
name : str or file object
113+
The file name or file object to be read.
114+
115+
Returns
116+
-------
117+
imread : ndarray
118+
The array obtained by reading the image.
119+
120+
Notes
121+
-----
122+
This is a simplified combination of scipy's scipy.misc.pilutil.imread and
123+
scipy.misc.pilutil.fromimage, which are deprecated in scipy 1.0.0 and will
124+
be removed from scipy in version 1.2.0.
125+
"""
126+
if not _have_image:
127+
raise ImportError("The Python Imaging Library (PIL) "
128+
"is required to load data from jpeg files")
129+
pil_image = Image.open(name)
130+
131+
return _fromimage(pil_image)
132+
133+
134+
def _fromimage(pil_image):
135+
"""
136+
Return a copy of a PIL image as a numpy array.
137+
138+
This function is only available if Python Imaging Library (PIL) is
139+
installed.
140+
141+
Parameters
142+
----------
143+
im : PIL image
144+
Input image.
145+
146+
Returns
147+
-------
148+
fromimage : ndarray
149+
The different colour bands/channels are stored in the
150+
third dimension, such that a grey-image is MxN, an
151+
RGB-image MxNx3 and an RGBA-image MxNx4.
152+
"""
153+
if not _have_image:
154+
raise ImportError("The Python Imaging Library (PIL) "
155+
"is required to load data from jpeg files")
156+
if not Image.isImageType(pil_image):
157+
raise TypeError("Input is not a PIL image.")
158+
159+
if pil_image.mode == 'P':
160+
# Mode 'P' means there is an indexed "palette". If we leave the mode
161+
# as 'P', then when we do `a = array(pil_image)` below, `a` will be a
162+
# 2-D containing the indices into the palette, and not a 3-D array
163+
# containing the RGB or RGBA values.
164+
if 'transparency' in pil_image.info:
165+
pil_image = pil_image.convert('RGBA')
166+
else:
167+
pil_image = pil_image.convert('RGB')
168+
169+
if pil_image.mode == '1':
170+
# Workaround for crash in PIL. When pil_image is 1-bit, the cal
171+
# array(pil_image) can cause a seg. fault, or generate garbage. See
172+
# https://github.com/scipy/scipy/issues/2138 and
173+
# https://github.com/python-pillow/Pillow/issues/350.
174+
#
175+
# This converts im from a 1-bit image to an 8-bit image.
176+
pil_image = pil_image.convert('L')
177+
178+
return np.array(pil_image)
179+
180+
181+
def _toimage(arr):
182+
"""
183+
Takes a numpy array and returns a PIL image.
184+
185+
This function is only available if Python Imaging Library (PIL) is
186+
installed.
187+
.. warning::
188+
This function uses `_bytescale` under the hood to rescale images to
189+
use the full (0, 255) range. It will also cast data for 2-D images to
190+
``uint32``.
191+
192+
Notes
193+
-----
194+
For 3-D arrays if one of the dimensions is 3, the mode is 'RGB'
195+
by default or 'YCbCr' if selected.
196+
The numpy array must be either 2 dimensional or 3 dimensional.
197+
"""
198+
if not _have_image:
199+
raise ImportError("The Python Imaging Library (PIL) "
200+
"is required to load data from jpeg files")
201+
data = np.asarray(arr)
202+
if np.iscomplexobj(data):
203+
raise ValueError("Cannot convert a complex-valued array.")
204+
shape = list(data.shape)
205+
valid = len(shape) == 2 or ((len(shape) == 3) and
206+
((3 in shape) or (4 in shape)))
207+
if not valid:
208+
raise ValueError("'arr' does not have a suitable array shape for "
209+
"any mode.")
210+
if len(shape) == 2:
211+
shape = (shape[1], shape[0])
212+
bytedata = _bytescale(data)
213+
image = Image.frombytes('L', shape, bytedata.tostring())
214+
return image
215+
216+
# if here then 3-d array with a 3 or a 4 in the shape length.
217+
# Check for 3 in datacube shape --- 'RGB' or 'YCbCr'
218+
if 3 in shape:
219+
ca = np.flatnonzero(np.asarray(shape) == 3)[0]
220+
else:
221+
ca = np.flatnonzero(np.asarray(shape) == 4)
222+
if not ca:
223+
ca = ca[0]
224+
else:
225+
raise ValueError("Could not find channel dimension.")
226+
227+
numch = shape[ca]
228+
if numch not in [3, 4]:
229+
raise ValueError("Channel axis dimension is not valid.")
230+
231+
bytedata = _bytescale(data)
232+
if ca == 2:
233+
strdata = bytedata.tostring()
234+
shape = (shape[1], shape[0])
235+
elif ca == 1:
236+
strdata = np.transpose(bytedata, (0, 2, 1)).tostring()
237+
shape = (shape[2], shape[0])
238+
elif ca == 0:
239+
strdata = np.transpose(bytedata, (1, 2, 0)).tostring()
240+
shape = (shape[2], shape[1])
241+
else:
242+
raise ValueError("Invalid channel dimension.")
243+
244+
if numch == 3:
245+
mode = 'RGB'
246+
else:
247+
mode = 'RGBA'
248+
249+
# Here we know data and mode is correct
250+
return Image.frombytes(mode, shape, strdata)
251+
252+
253+
def _imresize(arr, size):
254+
"""
255+
Resize an image.
256+
257+
This function is only available if Python Imaging Library (PIL) is
258+
installed.
259+
.. warning::
260+
This function uses `_bytescale` under the hood to rescale images to
261+
use the full (0, 255) range.
262+
It will also cast data for 2-D images to ``uint32``.
263+
264+
Parameters
265+
----------
266+
arr : ndarray
267+
The array of image to be resized.
268+
size : int, float or tuple
269+
* int - Percentage of current size.
270+
* float - Fraction of current size.
271+
* tuple - Size of the output image (height, width).
272+
273+
Returns
274+
-------
275+
imresize : ndarray
276+
The resized array of image.
277+
"""
278+
im = _toimage(arr)
279+
ts = type(size)
280+
if np.issubdtype(ts, np.signedinteger):
281+
percent = size / 100.0
282+
size = tuple((np.array(im.size) * percent).astype(int))
283+
elif np.issubdtype(type(size), np.floating):
284+
size = tuple((np.array(im.size) * size).astype(int))
285+
else:
286+
size = (size[1], size[0])
287+
imnew = im.resize(size, resample=2)
288+
return _fromimage(imnew)

0 commit comments

Comments
 (0)
0