2
2
"""
3
3
import os
4
4
from contextlib import contextmanager as contextmanager
5
+ import threading
5
6
6
7
_global_config = {
7
8
'assume_finite' : bool (os .environ .get ('SKLEARN_ASSUME_FINITE' , False )),
8
9
'working_memory' : int (os .environ .get ('SKLEARN_WORKING_MEMORY' , 1024 )),
9
10
'print_changed_only' : True ,
10
11
'display' : 'text' ,
11
12
}
13
+ _threadlocal = threading .local ()
14
+
15
+
16
+ def _get_threadlocal_config ():
17
+ """Get a threadlocal **mutable** configuration. If the configuration
18
+ does not exist, copy the default global configuration."""
19
+ if not hasattr (_threadlocal , 'global_config' ):
20
+ _threadlocal .global_config = _global_config .copy ()
21
+ return _threadlocal .global_config
12
22
13
23
14
24
def get_config ():
@@ -24,7 +34,9 @@ def get_config():
24
34
config_context : Context manager for global scikit-learn configuration.
25
35
set_config : Set global scikit-learn configuration.
26
36
"""
27
- return _global_config .copy ()
37
+ # Return a copy of the threadlocal configuration so that users will
38
+ # not be able to modify the configuration with the returned dict.
39
+ return _get_threadlocal_config ().copy ()
28
40
29
41
30
42
def set_config (assume_finite = None , working_memory = None ,
@@ -72,14 +84,16 @@ def set_config(assume_finite=None, working_memory=None,
72
84
config_context : Context manager for global scikit-learn configuration.
73
85
get_config : Retrieve current values of the global configuration.
74
86
"""
87
+ local_config = _get_threadlocal_config ()
88
+
75
89
if assume_finite is not None :
76
- _global_config ['assume_finite' ] = assume_finite
90
+ local_config ['assume_finite' ] = assume_finite
77
91
if working_memory is not None :
78
- _global_config ['working_memory' ] = working_memory
92
+ local_config ['working_memory' ] = working_memory
79
93
if print_changed_only is not None :
80
- _global_config ['print_changed_only' ] = print_changed_only
94
+ local_config ['print_changed_only' ] = print_changed_only
81
95
if display is not None :
82
- _global_config ['display' ] = display
96
+ local_config ['display' ] = display
83
97
84
98
85
99
@contextmanager
@@ -120,8 +134,7 @@ def config_context(**new_config):
120
134
Notes
121
135
-----
122
136
All settings, not just those presently modified, will be returned to
123
- their previous values when the context manager is exited. This is not
124
- thread-safe.
137
+ their previous values when the context manager is exited.
125
138
126
139
Examples
127
140
--------
@@ -141,7 +154,7 @@ def config_context(**new_config):
141
154
set_config : Set global scikit-learn configuration.
142
155
get_config : Retrieve current values of the global configuration.
143
156
"""
144
- old_config = get_config (). copy ()
157
+ old_config = get_config ()
145
158
set_config (** new_config )
146
159
147
160
try :
0 commit comments