8000 BUG: update joblib to 0.7.0d · deepatdotnet/scikit-learn@437644e · GitHub
[go: up one dir, main page]

Skip to content

Commit 437644e

Browse files
GaelVaroquauxlarsmans
authored andcommitted
BUG: update joblib to 0.7.0d
Fixes hashing bug that could lead to collisions in memory
1 parent bff549f commit 437644e

File tree

3 files changed

+50
-20
lines changed

3 files changed

+50
-20
lines changed

sklearn/externals/joblib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
103103
"""
104104

105-
__version__ = '0.7.0b'
105+
__version__ = '0.7.0d'
106106

107107

108108
from .memory import Memory

sklearn/externals/joblib/hashing.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
< 8000 /td>
77
# Copyright (c) 2009 Gael Varoquaux
88
# License: BSD Style, 3 clauses.
99

10+
import warnings
1011
import pickle
1112
import hashlib
1213
import sys
@@ -29,6 +30,13 @@ def __init__(self, set_sequence):
2930
self._sequence = sorted(set_sequence)
3031

3132

33+
class _MyHash(object):
34+
""" Class used to hash objects that won't normaly pickle """
35+
36+
def __init__(self, *args):
37+
self.args = args
38+
39+
3240
class Hasher(Pickler):
3341
""" A subclass of pickler, to do cryptographic hashing, rather than
3442
pickling.
@@ -43,9 +51,8 @@ def __init__(self, hash_name='md5'):
4351
def hash(self, obj, return_digest=True):
4452
try:
4553
self.dump(obj)
46-
except pickle.PicklingError:
47-
pass
48-
#self.dump(
54+
except pickle.PicklingError as e:
55+
warnings.warn('PicklingError while hashing %r: %r' % (obj, e))
4956
dumps = self.stream.getvalue()
5057
self._hash.update(dumps)
5158
if return_digest:
@@ -60,8 +67,14 @@ def save(self, obj):
6067
else:
6168
func_name = obj.__name__
6269
inst = obj.__self__
63-
cls = obj.__self__.__class__
64-
obj = (func_name, inst, cls)
70+
if type(inst) == type(pickle):
71+
obj = _MyHash(func_name, inst.__name__)
72+
elif inst is None:
73+
# type(None) or type(module) do not pickle
74+
obj = _MyHash(func_name, inst)
75+
else:
76+
cls = obj.__self__.__class__
77+
obj = _MyHash(func_name, inst, cls)
6578
Pickler.save(self, obj)
6679

6780
# The dispatch table of the pickler is not accessible in Python
@@ -70,17 +83,20 @@ def save_global(self, obj, name=None, pack=struct.pack):
7083
# We have to override this method in order to deal with objects
7184
# defined interactively in IPython that are not injected in
7285
# __main__
73-
module = getattr(obj, "__module__", None)
74-
if module == '__main__':
75-
my_name = name
76-
if my_name is None:
77-
my_name = obj.__name__
78-
mod = sys.modules[module]
79-
if not hasattr(mod, my_name):
80-
# IPython doesn't inject the variables define
81-
# interactively in __main__
82-
setattr(mod, my_name, obj)
83-
Pickler.save_global(self, obj, name=name, pack=struct.pack)
86+
try:
87+
Pickler.save_global(self, obj, name=name, pack=pack)
88+
except pickle.PicklingError:
89+
Pickler.save_global(self, obj, name=name, pack=pack)
90+
module = getattr(obj, "__module__", None)
91+
if module == '__main__':
92+
my_name = name
93+
if my_name is None:
94+
my_name = obj.__name__
95+
mod = sys.modules[module]
96+
if not hasattr(mod, my_name):
97+
# IPython doesn't inject the variables define
98+
# interactively in __main__
99+
setattr(mod, my_name, obj)
84100

85101
dispatch = Pickler.dispatch.copy()
86102
# builtin

sklearn/externals/joblib/test/test_hashing.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,18 @@ def f(self, x):
7777
def test_trival_hash():
7878
""" Smoke test hash on various types.
7979
"""
80-
obj_list = [1, 1., 1 + 1j,
81-
'a',
82-
(1, ), [1, ], {1:1},
80+
obj_list = [1, 2, 1., 2., 1 + 1j, 2. + 1j,
81+
'a', 'b',
82+
(1, ), (1, 1, ), [1, ], [1, 1, ],
83+
{1: 1}, {1: 2}, {2: 1},
8384
None,
85+
gc.collect,
86+
[1, ].append,
8487
]
8588
for obj1 in obj_list:
8689
for obj2 in obj_list:
90+
# Check that 2 objects have the same hash only if they are
91+
# the same.
8792
yield nose.tools.assert_equal, hash(obj1) == hash(obj2), \
8893
obj1 is obj2
8994

@@ -223,6 +228,15 @@ def test_hash_object_dtype():
223228
hash(b))
224229

225230

231+
@with_numpy
232+
def test_numpy_scalar():
233+
# Numpy scalars are built from compiled functions, and lead to
234+
# strange pickling paths explored, that can give hash collisions
235+
a = np.float64(2.0)
236+
b = np.float64(3.0)
237+
nose.tools.assert_not_equal(hash(a), hash(b))
238+
239+
226240
def test_dict_hash():
227241
# Check that dictionaries hash consistently, eventhough the ordering
228242
# of the keys is not garanteed

0 commit comments

Comments
 (0)
0