[MRG+1] Option to suppress validation for finiteness (#7548) · Sundrique/scikit-learn@95461f1 · GitHub

Commit 95461f1

jnothman authored and Sundrique committed
[MRG+1] Option to suppress validation for finiteness (scikit-learn#7548)
* ENH add suppress validation option
* TST skip problematic doctest
* Rename SUPPRESS_VALIDATION to PRESUME_FINITE
* Change PRESUME_ to ASSUME_ for convention's sake
* DOC add note regarding assert_all_finite
* ENH add set_config context manager for ASSUME_FINITE
* Make ASSUME_FINITE private and provide get_config
* Fix ImportError due to incomplete change in last commit
* TST/DOC tests and more cautious documentation for set_config
* DOC what's new entry for validation suppression
* context manager is now config_context; set_config affects global config
* Rename missed set_config to config_context
* Fix mis-named test
* Mention set_config in narrative docs
* More explicit about limited restoration of context
* Handle case where error raised in config_context
* Reset all settings after exiting context manager
1 parent a8ff1db commit 95461f1

7 files changed, +181 -1 lines changed

doc/modules/classes.rst

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ Functions
    :template: function.rst

    base.clone
+   config_context
+   set_config
+   get_config


 .. _cluster_ref:

doc/modules/computational_performance.rst

Lines changed: 19 additions & 0 deletions
@@ -68,6 +68,25 @@ To benchmark different estimators for your case you can simply change the
 :ref:`sphx_glr_auto_examples_applications_plot_prediction_latency.py`. This should give
 you an estimate of the order of magnitude of the prediction latency.

+.. topic:: Configuring Scikit-learn for reduced validation overhead
+
+  Scikit-learn does some validation on data that increases the overhead per
+  call to ``predict`` and similar functions. In particular, checking that
+  features are finite (not NaN or infinite) involves a full pass over the
+  data. If you ensure that your data is acceptable, you may suppress
+  checking for finiteness by setting the environment variable
+  ``SKLEARN_ASSUME_FINITE`` to a non-empty string before importing
+  scikit-learn, or configure it in Python with :func:`sklearn.set_config`.
+  For more control than these global settings, a :func:`config_context`
+  allows you to set this configuration within a specified context::
+
+    >>> import sklearn
+    >>> with sklearn.config_context(assume_finite=True):
+    ...     pass  # do learning/prediction here with reduced validation
+
+  Note that this will affect all uses of
+  :func:`sklearn.utils.assert_all_finite` within the context.
+
 Influence of the Number of Features
 -----------------------------------
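The topic above says that any non-empty value of ``SKLEARN_ASSUME_FINITE`` enables the setting. The following is a minimal sketch of what that implies, assuming the flag is read with the same `bool(os.environ.get(...))` expression that this commit adds to `sklearn/__init__.py`. The helper name `read_assume_finite` is hypothetical, and plain dicts stand in for `os.environ` so that nothing in the real environment is touched.

```python
import os


def read_assume_finite(environ=os.environ):
    """Hypothetical helper mirroring the expression used in this commit."""
    return bool(environ.get('SKLEARN_ASSUME_FINITE', False))


# Plain dicts stand in for os.environ (an assumption made for illustration).
results = {
    'unset': read_assume_finite({}),
    'empty': read_assume_finite({'SKLEARN_ASSUME_FINITE': ''}),
    'one': read_assume_finite({'SKLEARN_ASSUME_FINITE': '1'}),
    # Any non-empty string is truthy, so these also switch validation off.
    'zero': read_assume_finite({'SKLEARN_ASSUME_FINITE': '0'}),
    'false': read_assume_finite({'SKLEARN_ASSUME_FINITE': 'false'}),
}
print(results)
```

Under this reading, values such as `0` or `false` do not disable the option; only leaving the variable unset or empty does.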

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,11 @@ Changelog
 New features
 ............

+   - Validation that input data contains no NaN or inf can now be suppressed
+     using :func:`config_context`, at your own risk. This will save on runtime,
+     and may be particularly useful for prediction time. :issue:`7548` by
+     `Joel Nothman`_.
+
    - Added the :class:`neighbors.LocalOutlierFactor` class for anomaly
      detection based on nearest neighbors.
      :issue:`5279` by `Nicolas Goix`_ and `Alexandre Gramfort`_.

sklearn/__init__.py

Lines changed: 72 additions & 0 deletions
@@ -15,6 +15,78 @@
 import sys
 import re
 import warnings
+import os
+from contextlib import contextmanager as _contextmanager
+
+_ASSUME_FINITE = bool(os.environ.get('SKLEARN_ASSUME_FINITE', False))
+
+
+def get_config():
+    """Retrieve current values for configuration set by :func:`set_config`
+
+    Returns
+    -------
+    config : dict
+        Keys are parameter names that can be passed to :func:`set_config`.
+    """
+    return {'assume_finite': _ASSUME_FINITE}
+
+
+def set_config(assume_finite=None):
+    """Set global scikit-learn configuration
+
+    Parameters
+    ----------
+    assume_finite : bool, optional
+        If True, validation for finiteness will be skipped,
+        saving time, but leading to potential crashes. If
+        False, validation for finiteness will be performed,
+        avoiding error.
+    """
+    global _ASSUME_FINITE
+    if assume_finite is not None:
+        _ASSUME_FINITE = assume_finite
+
+
+@_contextmanager
+def config_context(**new_config):
+    """Context manager for global scikit-learn configuration
+
+    Parameters
+    ----------
+    assume_finite : bool, optional
+        If True, validation for finiteness will be skipped,
+        saving time, but leading to potential crashes. If
+        False, validation for finiteness will be performed,
+        avoiding error.
+
+    Notes
+    -----
+    All settings, not just those presently modified, will be returned to
+    their previous values when the context manager is exited. This is not
+    thread-safe.
+
+    Examples
+    --------
+    >>> import sklearn
+    >>> from sklearn.utils.validation import assert_all_finite
+    >>> with sklearn.config_context(assume_finite=True):
+    ...     assert_all_finite([float('nan')])
+    >>> with sklearn.config_context(assume_finite=True):
+    ...     with sklearn.config_context(assume_finite=False):
+    ...         assert_all_finite([float('nan')])
+    ... # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    ...
+    ValueError: Input contains NaN, ...
+    """
+    old_config = get_config().copy()
+    set_config(**new_config)
+
+    try:
+        yield
+    finally:
+        set_config(**old_config)


 # Make sure that DeprecationWarning within this package always gets printed
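The Notes section of `config_context` states that all settings are restored on exit, not only the ones the context changed. The sketch below is a standalone replica of the three functions from this hunk, with no scikit-learn import, so the restore behaviour can be tried in isolation. The module-level flag here is a stand-in for the private `_ASSUME_FINITE`, and the variables `inside`, `after` and `after_error` are illustrative names.

```python
from contextlib import contextmanager

_ASSUME_FINITE = False  # stand-in for the private flag in sklearn/__init__.py


def get_config():
    return {'assume_finite': _ASSUME_FINITE}


def set_config(assume_finite=None):
    global _ASSUME_FINITE
    if assume_finite is not None:  # None means "leave unchanged"
        _ASSUME_FINITE = assume_finite


@contextmanager
def config_context(**new_config):
    old_config = get_config().copy()  # snapshot every setting
    set_config(**new_config)
    try:
        yield
    finally:
        set_config(**old_config)  # restore the full snapshot, even on error


# A context that modifies nothing still rolls back global changes made inside.
with config_context(assume_finite=None):
    set_config(assume_finite=True)
    inside = get_config()
after = get_config()

# The try/finally restores the snapshot when the body raises.
try:
    with config_context(assume_finite=True):
        raise ValueError("simulated failure")
except ValueError:
    pass
after_error = get_config()

print(inside, after, after_error)
```

Because the state lives in a single module-level variable, two threads entering contexts concurrently would overwrite each other's snapshots, which is consistent with the docstring's warning that this is not thread-safe.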

sklearn/tests/test_config.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+from sklearn import get_config, set_config, config_context
+from sklearn.utils.testing import assert_equal, assert_raises
+
+
+def test_config_context():
+    assert_equal(get_config(), {'assume_finite': False})
+
+    # Not using as a context manager affects nothing
+    config_context(assume_finite=True)
+    assert_equal(get_config(), {'assume_finite': False})
+
+    with config_context(assume_finite=True):
+        assert_equal(get_config(), {'assume_finite': True})
+    assert_equal(get_config(), {'assume_finite': False})
+
+    with config_context(assume_finite=True):
+        with config_context(assume_finite=None):
+            assert_equal(get_config(), {'assume_finite': True})
+
+        assert_equal(get_config(), {'assume_finite': True})
+
+        with config_context(assume_finite=False):
+            assert_equal(get_config(), {'assume_finite': False})
+
+            with config_context(assume_finite=None):
+                assert_equal(get_config(), {'assume_finite': False})
+
+                # global setting will not be retained outside of context that
+                # did not modify this setting
+                set_config(assume_finite=True)
+                assert_equal(get_config(), {'assume_finite': True})
+
+            assert_equal(get_config(), {'assume_finite': False})
+
+        assert_equal(get_config(), {'assume_finite': True})
+
+    assert_equal(get_config(), {'assume_finite': False})
+
+    # No positional arguments
+    assert_raises(TypeError, config_context, True)
+    # No unknown arguments
+    assert_raises(TypeError, config_context(do_something_else=True).__enter__)
+
+
+def test_config_context_exception():
+    assert_equal(get_config(), {'assume_finite': False})
+    try:
+        with config_context(assume_finite=True):
+            assert_equal(get_config(), {'assume_finite': True})
+            raise ValueError()
+    except ValueError:
+        pass
+    assert_equal(get_config(), {'assume_finite': False})
+
+
+def test_set_config():
+    assert_equal(get_config(), {'assume_finite': False})
+    set_config(assume_finite=None)
+    assert_equal(get_config(), {'assume_finite': False})
+    set_config(assume_finite=True)
+    assert_equal(get_config(), {'assume_finite': True})
+    set_config(assume_finite=None)
+    assert_equal(get_config(), {'assume_finite': True})
+    set_config(assume_finite=False)
+    assert_equal(get_config(), {'assume_finite': False})
+
+    # No unknown arguments
+    assert_raises(TypeError, set_config, do_something_else=True)
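Two details in these tests follow from how `contextlib.contextmanager` wraps a generator function: building the manager runs none of the function body, so a bare `config_context(...)` call changes nothing, and an unknown keyword only fails once `__enter__` starts the generator. A positional argument, by contrast, is rejected at call time because the signature accepts only `**kwargs`. The sketch below illustrates this with hypothetical stand-ins (`apply_settings` for `set_config`, `settings_context` for `config_context`); it is not the library code.

```python
from contextlib import contextmanager

log = []


def apply_settings(assume_finite=None):  # hypothetical stand-in for set_config
    log.append(assume_finite)


@contextmanager
def settings_context(**new_settings):  # hypothetical stand-in for config_context
    apply_settings(**new_settings)
    yield


# Creating the manager without `with` never starts the generator body.
settings_context(assume_finite=True)
untouched = (log == [])

# An unknown keyword is accepted by **new_settings and only fails on entry.
manager = settings_context(do_something_else=True)
try:
    manager.__enter__()
    raised_on_enter = False
except TypeError:
    raised_on_enter = True

# A positional argument does not match the signature and fails immediately.
try:
    settings_context(True)
    raised_on_call = False
except TypeError:
    raised_on_call = True

print(untouched, raised_on_enter, raised_on_call)
```

This is why the test passes `config_context` and `True` separately to `assert_raises` in one case, but passes the bound `__enter__` method in the other.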

sklearn/utils/tests/test_validation.py

Lines changed: 11 additions & 1 deletion
@@ -30,14 +30,15 @@
     has_fit_parameter,
     check_is_fitted,
     check_consistent_length,
+    assert_all_finite,
 )
+import sklearn

 from sklearn.exceptions import NotFittedError
 from sklearn.exceptions import DataConversionWarning

 from sklearn.utils.testing import assert_raise_message

-
 def test_as_float_array():
     # Test function for as_float_array
     X = np.ones((3, 10), dtype=np.int32)
@@ -526,3 +527,12 @@ def test_check_dataframe_fit_attribute():
         check_consistent_length(X_df)
     except ImportError:
         raise SkipTest("Pandas not found")
+
+
+def test_suppress_validation():
+    X = np.array([0, np.inf])
+    assert_raises(ValueError, assert_all_finite, X)
+    sklearn.set_config(assume_finite=True)
+    assert_all_finite(X)
+    sklearn.set_config(assume_finite=False)
+    assert_raises(ValueError, assert_all_finite, X)

sklearn/utils/validation.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@

 from ..externals import six
 from ..utils.fixes import signature
+from .. import get_config as _get_config
 from ..exceptions import NonBLASDotWarning
 from ..exceptions import NotFittedError
 from ..exceptions import DataConversionWarning
@@ -30,6 +31,8 @@

 def _assert_all_finite(X):
     """Like assert_all_finite, but only for ndarray."""
+    if _get_config()['assume_finite']:
+        return
     X = np.asanyarray(X)
     # First try an O(n) time, O(1) space solution for the common case that
     # everything is finite; fall back to O(n) space np.isfinite to prevent
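The hunk is cut off before the check itself, but the comment describes a cheap first pass with a fallback, placed behind the new early return. The following pure-Python sketch shows one way that idea could look, under stated assumptions: it is not the library's code, `check_all_finite` and its `assume_finite` parameter are hypothetical (the real function reads the flag from `get_config()`), the error wording is made up, and `values` is assumed to be a list of floats that can be iterated twice.

```python
import math


def check_all_finite(values, assume_finite=False):
    """Hypothetical sketch: raise ValueError if any value is NaN or infinite."""
    if assume_finite:
        return  # caller vouches for the data, so skip the pass entirely

    # Fast path: one pass, constant extra space. NaN and infinity propagate
    # through addition, so a finite total proves every element is finite.
    if math.isfinite(sum(values)):
        return

    # Slow path: the total can overflow even when each element is finite,
    # so a non-finite total must be confirmed element by element.
    if not all(math.isfinite(v) for v in values):
        raise ValueError("input contains NaN or infinity")  # illustrative wording


check_all_finite([0.5, 1.5])                               # passes on the fast path
check_all_finite([1e308, 1e308])                           # overflowing sum, still passes
check_all_finite([float('nan')], assume_finite=True)       # check skipped
```

The early return mirrors the two added lines in `_assert_all_finite`: when the flag is set, the full pass over the data that the narrative docs describe never happens.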
