From 1cbf18bc2b47d4bea003bf186a33ee5aa53aab59 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sun, 2 Oct 2016 13:50:47 +1100 Subject: [PATCH 01/17] ENH add suppress validation option --- doc/modules/computational_performance.rst | 13 +++++++++++++ sklearn/__init__.py | 3 +++ sklearn/utils/tests/test_validation.py | 12 +++++++++++- sklearn/utils/validation.py | 3 +++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index fa7dd7f26df39..7028d257d725b 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -68,6 +68,19 @@ To benchmark different estimators for your case you can simply change the :ref:`sphx_glr_auto_examples_applications_plot_prediction_latency.py`. This should give you an estimate of the order of magnitude of the prediction latency. +.. topic:: Configuring Scikit-learn for reduced validation overhead + + Scikit-learn does some validation on data that increases the overhead per + call to ``predict`` and similar functions. In particular, checking that + features are finite (not NaN or infinite) involves a full pass over the + data. If you ensure that your data is acceptable, you may suppress some + validation by setting the environment variable + ``SKLEARN_SUPPRESS_VALIDATION`` to a non-empty string before importing + scikit-learn, or configure it in Python with:: + + >>> import sklearn + >>> sklearn.SUPPRESS_VALIDATION = True + Influence of the Number of Features ----------------------------------- diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 4e80774bc3110..a888462bb5209 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -15,7 +15,10 @@ import sys import re import warnings +import os +SUPPRESS_VALIDATION = bool(os.environ.get('SKLEARN_SUPPRESS_VALIDATION', + False)) # Make sure that DeprecationWarning within this package always gets printed warnings.filterwarnings('always', category=DeprecationWarning, diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index a6268b08d1a81..6d0f38f870637 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -29,14 +29,15 @@ has_fit_parameter, check_is_fitted, check_consistent_length, + assert_all_finite, ) +import sklearn from sklearn.exceptions import NotFittedError from sklearn.exceptions import DataConversionWarning from sklearn.utils.testing import assert_raise_message - def test_as_float_array(): # Test function for as_float_array X = np.ones((3, 10), dtype=np.int32) @@ -469,3 +470,12 @@ def test_check_consistent_length(): assert_raises_regexp(TypeError, 'estimator', check_consistent_length, [1, 2], RandomForestRegressor()) # XXX: We should have a test with a string, but what is correct behaviour? + + +def check_suppress_validation(): + X = np.array([0, np.inf]) + assert_raises(ValueError, assert_all_finite, X) + sklearn.SUPPRESS_VALIDATION = True + assert_all_finite(X) + sklearn.SUPPRESS_VALIDATION = False + assert_raises(ValueError, assert_all_finite, X) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index deb98eef85039..f059597db8a15 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -48,6 +48,9 @@ class NotFittedError(_NotFittedError): def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" + from .. import SUPPRESS_VALIDATION + if SUPPRESS_VALIDATION: + return X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent From f0878d0555adaf2357737a2bbee873c6ccb5faf3 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sun, 2 Oct 2016 16:45:18 +1100 Subject: [PATCH 02/17] TST skip problematic doctest --- doc/modules/computational_performance.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index 7028d257d725b..a3ac0da49b3ea 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -79,7 +79,7 @@ you an estimate of the order of magnitude of the prediction latency. scikit-learn, or configure it in Python with:: >>> import sklearn - >>> sklearn.SUPPRESS_VALIDATION = True + >>> sklearn.SUPPRESS_VALIDATION = True # doctest: +SKIP Influence of the Number of Features ----------------------------------- From 4d35dd83778a58dbc1143ab03525a6d637bca0d0 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 5 Oct 2016 11:30:51 +1100 Subject: [PATCH 03/17] Rename SUPPRESS_VALIDATION to PRESUME_FINITE --- doc/modules/computational_performance.rst | 8 ++++---- sklearn/__init__.py | 3 +-- sklearn/utils/tests/test_validation.py | 4 ++-- sklearn/utils/validation.py | 4 ++-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index a3ac0da49b3ea..c80809c86a85d 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -73,13 +73,13 @@ you an estimate of the order of magnitude of the prediction latency. Scikit-learn does some validation on data that increases the overhead per call to ``predict`` and similar functions. In particular, checking that features are finite (not NaN or infinite) involves a full pass over the - data. If you ensure that your data is acceptable, you may suppress some - validation by setting the environment variable - ``SKLEARN_SUPPRESS_VALIDATION`` to a non-empty string before importing + data. If you ensure that your data is acceptable, you may suppress + checking for finiteness by setting the environment variable + ``SKLEARN_PRESUME_FINITE`` to a non-empty string before importing scikit-learn, or configure it in Python with:: >>> import sklearn - >>> sklearn.SUPPRESS_VALIDATION = True # doctest: +SKIP + >>> sklearn.PRESUME_FINITE = True # doctest: +SKIP Influence of the Number of Features ----------------------------------- diff --git a/sklearn/__init__.py b/sklearn/__init__.py index a888462bb5209..56a1cd16de948 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -17,8 +17,7 @@ import warnings import os -SUPPRESS_VALIDATION = bool(os.environ.get('SKLEARN_SUPPRESS_VALIDATION', - False)) +PRESUME_FINITE = bool(os.environ.get('SKLEARN_PRESUME_FINITE', False)) # Make sure that DeprecationWarning within this package always gets printed warnings.filterwarnings('always', category=DeprecationWarning, diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 6d0f38f870637..3314e050ce109 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -475,7 +475,7 @@ def test_check_consistent_length(): def check_suppress_validation(): X = np.array([0, np.inf]) assert_raises(ValueError, assert_all_finite, X) - sklearn.SUPPRESS_VALIDATION = True + sklearn.PRESUME_FINITE = True assert_all_finite(X) - sklearn.SUPPRESS_VALIDATION = False + sklearn.PRESUME_FINITE = False assert_raises(ValueError, assert_all_finite, X) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index f059597db8a15..3f96c76340764 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -48,8 +48,8 @@ class NotFittedError(_NotFittedError): def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" - from .. import SUPPRESS_VALIDATION - if SUPPRESS_VALIDATION: + from .. import PRESUME_FINITE + if PRESUME_FINITE: return X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that From 2530a47f66ce5b6012af20b87556579913fcc6ec Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 7 Oct 2016 14:35:25 +1100 Subject: [PATCH 04/17] Change PRESUME_ to ASSUME_ for convention's sake --- doc/modules/computational_performance.rst | 4 ++-- sklearn/__init__.py | 2 +- sklearn/utils/tests/test_validation.py | 4 ++-- sklearn/utils/validation.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index c80809c86a85d..73e807425dca2 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -75,11 +75,11 @@ you an estimate of the order of magnitude of the prediction latency. features are finite (not NaN or infinite) involves a full pass over the data. If you ensure that your data is acceptable, you may suppress checking for finiteness by setting the environment variable - ``SKLEARN_PRESUME_FINITE`` to a non-empty string before importing + ``SKLEARN_ASSUME_FINITE`` to a non-empty string before importing scikit-learn, or configure it in Python with:: >>> import sklearn - >>> sklearn.PRESUME_FINITE = True # doctest: +SKIP + >>> sklearn.ASSUME_FINITE = True # doctest: +SKIP Influence of the Number of Features ----------------------------------- diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 56a1cd16de948..28521cc4915df 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -17,7 +17,7 @@ import warnings import os -PRESUME_FINITE = bool(os.environ.get('SKLEARN_PRESUME_FINITE', False)) +ASSUME_FINITE = bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)) # Make sure that DeprecationWarning within this package always gets printed warnings.filterwarnings('always', category=DeprecationWarning, diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 3314e050ce109..b5122b1762180 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -475,7 +475,7 @@ def test_check_consistent_length(): def check_suppress_validation(): X = np.array([0, np.inf]) assert_raises(ValueError, assert_all_finite, X) - sklearn.PRESUME_FINITE = True + sklearn.ASSUME_FINITE = True assert_all_finite(X) - sklearn.PRESUME_FINITE = False + sklearn.ASSUME_FINITE = False assert_raises(ValueError, assert_all_finite, X) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 3f96c76340764..f47773ae7458e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -48,8 +48,8 @@ class NotFittedError(_NotFittedError): def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" - from .. import PRESUME_FINITE - if PRESUME_FINITE: + from .. import ASSUME_FINITE + if ASSUME_FINITE: return X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that From 379302ce4b2eca0e1eecc088a41b1d7a51e48673 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sat, 15 Oct 2016 20:44:16 +1100 Subject: [PATCH 05/17] DOC add note regarding assert_all_finite --- doc/modules/computational_performance.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index 73e807425dca2..c70584efec169 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -81,6 +81,9 @@ you an estimate of the order of magnitude of the prediction latency. >>> import sklearn >>> sklearn.ASSUME_FINITE = True # doctest: +SKIP + Note that this will affect all uses of + :func:`sklearn.utils.assert_all_finite`. + Influence of the Number of Features ----------------------------------- From ac125b991e4375e8ce61961fc3472e936488b6a2 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 23 Nov 2016 22:11:43 +1100 Subject: [PATCH 06/17] ENH add set_config context manager for ASSUME_FINITE --- doc/modules/classes.rst | 1 + doc/modules/computational_performance.rst | 5 +-- sklearn/__init__.py | 40 +++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 8a077daf018df..988e3b3d24826 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -40,6 +40,7 @@ Functions :template: function.rst base.clone + set_config .. _cluster_ref: diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index c70584efec169..1c528cb8900a2 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -79,10 +79,11 @@ you an estimate of the order of magnitude of the prediction latency. scikit-learn, or configure it in Python with:: >>> import sklearn - >>> sklearn.ASSUME_FINITE = True # doctest: +SKIP + >>> with sklearn.set_config(assume_finite=True): + ... pass # do learning/prediction here with reduced validation Note that this will affect all uses of - :func:`sklearn.utils.assert_all_finite`. + :func:`sklearn.utils.assert_all_finite` within the context. Influence of the Number of Features ----------------------------------- diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 28521cc4915df..1564390563cb2 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -16,9 +16,49 @@ import re import warnings import os +from contextlib import contextmanager as _contextmanager ASSUME_FINITE = bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)) + +@_contextmanager +def set_config(assume_finite=None): + """Context manager for global scikit-learn settings + + Parameters + ---------- + assume_finite : bool, optional + If True, validation for finiteness will be skipped, saving time. + If False, validation for finiteness will be performed, avoiding error. + + Notes + ----- + Settings will be returned to previous values when the context manager + is exited. This is not thread-safe. + + Examples + -------- + >>> import sklearn + >>> from sklearn.utils.validation import assert_all_finite + >>> with sklearn.set_config(assume_finite=True): + ... assert_all_finite([float('nan')]) + >>> with sklearn.set_config(assume_finite=True): + ... with sklearn.set_config(assume_finite=False): + ... assert_all_finite([float('nan')]) + ... # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Input contains NaN, ... + """ + global ASSUME_FINITE + prev_assume_finite = ASSUME_FINITE + if assume_finite is not None: + ASSUME_FINITE = assume_finite + + yield + ASSUME_FINITE = prev_assume_finite + + # Make sure that DeprecationWarning within this package always gets printed warnings.filterwarnings('always', category=DeprecationWarning, module='^{0}\.'.format(re.escape(__name__))) From 0b800e5edda7922e3d3bf552285e8d32cdc5b2bc Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 23 Nov 2016 22:24:09 +1100 Subject: [PATCH 07/17] Make ASSUME_FINITE private and provide get_config --- doc/modules/classes.rst | 1 + sklearn/__init__.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 988e3b3d24826..e9a19e7ee6c02 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -41,6 +41,7 @@ Functions base.clone set_config + get_config .. _cluster_ref: diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 1564390563cb2..8142a41a71ecd 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -18,7 +18,18 @@ import os from contextlib import contextmanager as _contextmanager -ASSUME_FINITE = bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)) +_ASSUME_FINITE = bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)) + + +def get_config(): + """Retrieve current values for configuration set by :func:`set_config` + + Returns + ------- + config : dict + Keys are parameter names that can be passed to :func:`set_config`. + """ + return {'assume_finite': _ASSUME_FINITE} @_contextmanager @@ -50,13 +61,13 @@ def set_config(assume_finite=None): ... ValueError: Input contains NaN, ... """ - global ASSUME_FINITE - prev_assume_finite = ASSUME_FINITE + global _ASSUME_FINITE + prev_assume_finite = _ASSUME_FINITE if assume_finite is not None: - ASSUME_FINITE = assume_finite + _ASSUME_FINITE = assume_finite yield - ASSUME_FINITE = prev_assume_finite + _ASSUME_FINITE = prev_assume_finite # Make sure that DeprecationWarning within this package always gets printed From c53c8d92272f175ef727b01be4e265dbee1e0161 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 23 Nov 2016 23:28:19 +1100 Subject: [PATCH 08/17] Fix ImportError due to incomplete change in last commit --- sklearn/utils/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index f47773ae7458e..f81e57bf571cd 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -20,6 +20,7 @@ from ..exceptions import DataConversionWarning as _DataConversionWarning from ..exceptions import NonBLASDotWarning as _NonBLASDotWarning from ..exceptions import NotFittedError as _NotFittedError +from .. import get_config as _get_config @deprecated("DataConversionWarning has been moved into the sklearn.exceptions" @@ -48,8 +49,7 @@ class NotFittedError(_NotFittedError): def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" - from .. import ASSUME_FINITE - if ASSUME_FINITE: + if _get_config()['assume_finite']: return X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that From 69e58c425eee5750fa838dd35bed081708a39138 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Thu, 24 Nov 2016 18:31:28 +1100 Subject: [PATCH 09/17] TST/DOC tests and more cautious documentation for set_config --- sklearn/__init__.py | 6 ++++-- sklearn/tests/test_config.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 sklearn/tests/test_config.py diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 8142a41a71ecd..0465924e7040c 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -39,8 +39,10 @@ def set_config(assume_finite=None): Parameters ---------- assume_finite : bool, optional - If True, validation for finiteness will be skipped, saving time. - If False, validation for finiteness will be performed, avoiding error. + If True, validation for finiteness will be skipped, + saving time, but leading to potential crashes. If + False, validation for finiteness will be performed, + avoiding error. Notes ----- diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py new file mode 100644 index 0000000000000..3aef6487a4095 --- /dev/null +++ b/sklearn/tests/test_config.py @@ -0,0 +1,26 @@ +from sklearn import get_config, set_config +from sklearn.utils.testing import assert_equal + + +def test_set_config(): + assert_equal(get_config(), {'assume_finite': False}) + + # Not using as a context manager affects nothing + set_config(assume_finite=True) + assert_equal(get_config(), {'assume_finite': False}) + + with set_config(assume_finite=True): + assert_equal(get_config(), {'assume_finite': True}) + assert_equal(get_config(), {'assume_finite': False}) + + with set_config(assume_finite=True): + with set_config(assume_finite=None): + assert_equal(get_config(), {'assume_finite': True}) + + with set_config(assume_finite=False): + assert_equal(get_config(), {'assume_finite': False}) + + with set_config(assume_finite=None): + assert_equal(get_config(), {'assume_finite': False}) + + assert_equal(get_config(), {'assume_finite': False}) From bc035881d36098a5335f756b54a44d345f4ebdb6 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Mon, 28 Nov 2016 12:33:44 +1100 Subject: [PATCH 10/17] DOC what's new entry for validation suppression --- doc/whats_new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 47fd467cbc0eb..d13fa154c1b19 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -16,6 +16,11 @@ Changelog New features ............ + - Validation that input data contains no NaN or inf can now be suppressed + using :func:`set_config`, at your own risk. This will save on runtime, + and may be particularly useful for prediction time. :issue:`7548` by + `Joel Nothman`_. + Enhancements ............ From 67188f0299e2a6b43483a6214f54fe7fffa21847 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Mon, 5 Dec 2016 21:50:52 +1100 Subject: [PATCH 11/17] context manager is now config_context; set_config affects global config --- doc/modules/classes.rst | 3 +- doc/modules/computational_performance.rst | 2 +- doc/whats_new.rst | 2 +- sklearn/__init__.py | 32 +++++++++++++------ sklearn/tests/test_config.py | 38 +++++++++++++++++------ 5 files changed, 56 insertions(+), 21 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 9e1cadf856785..d381cb0c261fc 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -40,6 +40,7 @@ Functions :template: function.rst base.clone + config_context set_config get_config @@ -1420,4 +1421,4 @@ To be removed in 0.20 cross_validation.cross_val_score cross_validation.check_cv cross_validation.permutation_test_score - cross_validation.train_test_split \ No newline at end of file + cross_validation.train_test_split diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index 1c528cb8900a2..bfa02d5669b57 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -79,7 +79,7 @@ you an estimate of the order of magnitude of the prediction latency. scikit-learn, or configure it in Python with:: >>> import sklearn - >>> with sklearn.set_config(assume_finite=True): + >>> with sklearn.config_context(assume_finite=True): ... pass # do learning/prediction here with reduced validation Note that this will affect all uses of diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 54dbcf208e49d..6f33985755260 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -32,7 +32,7 @@ New features ............ - Validation that input data contains no NaN or inf can now be suppressed - using :func:`set_config`, at your own risk. This will save on runtime, + using :func:`config_context`, at your own risk. This will save on runtime, and may be particularly useful for prediction time. :issue:`7548` by `Joel Nothman`_. diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 0465924e7040c..3a9fa305315fc 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -32,9 +32,25 @@ def get_config(): return {'assume_finite': _ASSUME_FINITE} -@_contextmanager def set_config(assume_finite=None): - """Context manager for global scikit-learn settings + """Set global scikit-learn configuration + + Parameters + ---------- + assume_finite : bool, optional + If True, validation for finiteness will be skipped, + saving time, but leading to potential crashes. If + False, validation for finiteness will be performed, + avoiding error. + """ + global _ASSUME_FINITE + if assume_finite is not None: + _ASSUME_FINITE = assume_finite + + +@_contextmanager +def config_context(**kwargs): + """Context manager for global scikit-learn configuration Parameters ---------- @@ -53,9 +69,9 @@ def set_config(assume_finite=None): -------- >>> import sklearn >>> from sklearn.utils.validation import assert_all_finite - >>> with sklearn.set_config(assume_finite=True): + >>> with sklearn.config_context(assume_finite=True): ... assert_all_finite([float('nan')]) - >>> with sklearn.set_config(assume_finite=True): + >>> with sklearn.config_context(assume_finite=True): ... with sklearn.set_config(assume_finite=False): ... assert_all_finite([float('nan')]) ... # doctest: +ELLIPSIS @@ -63,13 +79,11 @@ def set_config(assume_finite=None): ... ValueError: Input contains NaN, ... """ - global _ASSUME_FINITE - prev_assume_finite = _ASSUME_FINITE - if assume_finite is not None: - _ASSUME_FINITE = assume_finite + config = get_config().copy() + set_config(**kwargs) yield - _ASSUME_FINITE = prev_assume_finite + set_config(**{k: config[k] for k in kwargs}) # Make sure that DeprecationWarning within this package always gets printed diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 3aef6487a4095..533b195a4b89a 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -1,26 +1,46 @@ -from sklearn import get_config, set_config -from sklearn.utils.testing import assert_equal +from sklearn import get_config, set_config, config_context +from sklearn.utils.testing import assert_equal, assert_raises -def test_set_config(): +def test_config_context(): assert_equal(get_config(), {'assume_finite': False}) # Not using as a context manager affects nothing - set_config(assume_finite=True) + config_context(assume_finite=True) assert_equal(get_config(), {'assume_finite': False}) - with set_config(assume_finite=True): + with config_context(assume_finite=True): assert_equal(get_config(), {'assume_finite': True}) assert_equal(get_config(), {'assume_finite': False}) - with set_config(assume_finite=True): - with set_config(assume_finite=None): + with config_context(assume_finite=True): + with config_context(assume_finite=None): assert_equal(get_config(), {'assume_finite': True}) - with set_config(assume_finite=False): + with config_context(assume_finite=False): assert_equal(get_config(), {'assume_finite': False}) - with set_config(assume_finite=None): + with config_context(assume_finite=None): assert_equal(get_config(), {'assume_finite': False}) assert_equal(get_config(), {'assume_finite': False}) + + # No positional arguments + assert_raises(TypeError, config_context, True) + # No unknown arguments + assert_raises(TypeError, config_context(do_something_else=True).__enter__) + + +def test_set_config(): + assert_equal(get_config(), {'assume_finite': False}) + set_config(assume_finite=None) + assert_equal(get_config(), {'assume_finite': False}) + set_config(assume_finite=True) + assert_equal(get_config(), {'assume_finite': True}) + set_config(assume_finite=None) + assert_equal(get_config(), {'assume_finite': True}) + set_config(assume_finite=False) + assert_equal(get_config(), {'assume_finite': False}) + + # No unknown arguments + assert_raises(TypeError, set_config, do_something_else=True) From 1ae394b2195acc80e233d6245b3ec5e0b17a654f Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Mon, 5 Dec 2016 22:57:28 +1100 Subject: [PATCH 12/17] Rename missed set_config to config_context --- sklearn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 3a9fa305315fc..233fc4cd081a4 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -72,7 +72,7 @@ def config_context(**kwargs): >>> with sklearn.config_context(assume_finite=True): ... assert_all_finite([float('nan')]) >>> with sklearn.config_context(assume_finite=True): - ... with sklearn.set_config(assume_finite=False): + ... with sklearn.config_context(assume_finite=False): ... assert_all_finite([float('nan')]) ... # doctest: +ELLIPSIS Traceback (most recent call last): From 71f6c2342cad38b91ecfbd1f4ead45e5d027459b Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Mon, 5 Dec 2016 23:01:10 +1100 Subject: [PATCH 13/17] Fix mis-named test --- sklearn/utils/tests/test_validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 4a70421b75951..6373edbdb4edc 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -472,10 +472,10 @@ def test_check_consistent_length(): # XXX: We should have a test with a string, but what is correct behaviour? -def check_suppress_validation(): +def test_suppress_validation(): X = np.array([0, np.inf]) assert_raises(ValueError, assert_all_finite, X) - sklearn.ASSUME_FINITE = True + sklearn.set_config(assume_finite=True) assert_all_finite(X) - sklearn.ASSUME_FINITE = False + sklearn.set_config(assume_finite=False) assert_raises(ValueError, assert_all_finite, X) From 40bfbdb318e26d441e9e014fd3da5b561119a044 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Thu, 8 Dec 2016 09:16:17 +1100 Subject: [PATCH 14/17] Mention set_config in narrative docs --- doc/modules/computational_performance.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index bfa02d5669b57..11272d44e6196 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -76,7 +76,9 @@ you an estimate of the order of magnitude of the prediction latency. data. If you ensure that your data is acceptable, you may suppress checking for finiteness by setting the environment variable ``SKLEARN_ASSUME_FINITE`` to a non-empty string before importing - scikit-learn, or configure it in Python with:: + scikit-learn, or configure it in Python with :func:`sklearn.set_config`. + For more control than these global settings, a :func:`config_context` + allows you to set this configuration within a specified context:: >>> import sklearn >>> with sklearn.config_context(assume_finite=True): From afbbdda70d35f5788bdc8bdd458025594105b235 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Thu, 8 Dec 2016 09:18:21 +1100 Subject: [PATCH 15/17] More explicit about limmited restoration of context --- sklearn/__init__.py | 8 +++++--- sklearn/tests/test_config.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 233fc4cd081a4..c9d04777a7649 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -62,8 +62,9 @@ def config_context(**kwargs): Notes ----- - Settings will be returned to previous values when the context manager - is exited. This is not thread-safe. + Only settings that are set by this context manager will be returned to + previous values when the context manager is exited. This is not + thread-safe. Examples -------- @@ -83,7 +84,8 @@ def config_context(**kwargs): set_config(**kwargs) yield - set_config(**{k: config[k] for k in kwargs}) + set_config(**{k: config[k] for k, v in kwargs.items() + if v is not None}) # Make sure that DeprecationWarning within this package always gets printed diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 533b195a4b89a..0469a9d053830 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -17,12 +17,23 @@ def test_config_context(): with config_context(assume_finite=None): assert_equal(get_config(), {'assume_finite': True}) + assert_equal(get_config(), {'assume_finite': True}) + with config_context(assume_finite=False): assert_equal(get_config(), {'assume_finite': False}) with config_context(assume_finite=None): assert_equal(get_config(), {'assume_finite': False}) + # global setting will be retained outside of context that did + # not modify this setting + set_config(assume_finite=True) + assert_equal(get_config(), {'assume_finite': True}) + + assert_equal(get_config(), {'assume_finite': True}) + + assert_equal(get_config(), {'assume_finite': True}) + assert_equal(get_config(), {'assume_finite': False}) # No positional arguments From 9d2eaf9416ad2bc9005c03bdbf7bbab12285966e Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 16 Dec 2016 08:10:53 +1100 Subject: [PATCH 16/17] Handle case where error raised in config_context --- sklearn/__init__.py | 8 +++++--- sklearn/tests/test_config.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 83a622fff7660..c974223770dd7 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -83,9 +83,11 @@ def config_context(**kwargs): config = get_config().copy() set_config(**kwargs) - yield - set_config(**{k: config[k] for k, v in kwargs.items() - if v is not None}) + try: + yield + finally: + set_config(**{k: config[k] for k, v in kwargs.items() + if v is not None}) # Make sure that DeprecationWarning within this package always gets printed diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 0469a9d053830..ef59454d20d8a 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -42,6 +42,17 @@ def test_config_context(): assert_raises(TypeError, config_context(do_something_else=True).__enter__) +def test_config_context_exception(): + assert_equal(get_config(), {'assume_finite': False}) + try: + with config_context(assume_finite=True): + assert_equal(get_config(), {'assume_finite': True}) + raise ValueError() + except ValueError: + pass + assert_equal(get_config(), {'assume_finite': False}) + + def test_set_config(): assert_equal(get_config(), {'assume_finite': False}) set_config(assume_finite=None) From 089339ca1b97b6cb6e2f714f16e9deac7269ab96 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 21 Dec 2016 21:32:08 +1100 Subject: [PATCH 17/17] Reset all settings after exiting context manager --- sklearn/__init__.py | 13 ++++++------- sklearn/tests/test_config.py | 6 +++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sklearn/__init__.py b/sklearn/__init__.py index c974223770dd7..b4916dd5925de 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -49,7 +49,7 @@ def set_config(assume_finite=None): @_contextmanager -def config_context(**kwargs): +def config_context(**new_config): """Context manager for global scikit-learn configuration Parameters @@ -62,8 +62,8 @@ def config_context(**kwargs): Notes ----- - Only settings that are set by this context manager will be returned to - previous values when the context manager is exited. This is not + All settings, not just those presently modified, will be returned to + their previous values when the context manager is exited. This is not thread-safe. Examples @@ -80,14 +80,13 @@ def config_context(**kwargs): ... ValueError: Input contains NaN, ... """ - config = get_config().copy() - set_config(**kwargs) + old_config = get_config().copy() + set_config(**new_config) try: yield finally: - set_config(**{k: config[k] for k, v in kwargs.items() - if v is not None}) + set_config(**old_config) # Make sure that DeprecationWarning within this package always gets printed diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index ef59454d20d8a..b968e7b7917ea 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -25,12 +25,12 @@ def test_config_context(): with config_context(assume_finite=None): assert_equal(get_config(), {'assume_finite': False}) - # global setting will be retained outside of context that did - # not modify this setting + # global setting will not be retained outside of context that + # did not modify this setting set_config(assume_finite=True) assert_equal(get_config(), {'assume_finite': True}) - assert_equal(get_config(), {'assume_finite': True}) + assert_equal(get_config(), {'assume_finite': False}) assert_equal(get_config(), {'assume_finite': True})