8000 ENH add suppress validation option · scikit-learn/scikit-learn@12dd724 · GitHub
[go: up one dir, main page]

Skip to content

Commit 12dd724

Browse files
committed
ENH add suppress validation option
1 parent b6045e6 commit 12dd724

File tree

4 files changed

+30
-1
lines changed

4 files changed

+30
-1
lines changed

doc/modules/computational_performance.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ To benchmark different estimators for your case you can simply change the
6868
:ref:`sphx_glr_auto_examples_applications_plot_prediction_latency.py`. This should give
6969
you an estimate of the order of magnitude of the prediction latency.
7070

71+
.. topic:: Configuring Scikit-learn for reduced validation overhead
72+
73+
Scikit-learn does some validation on data that increases the overhead per
74+
call to ``predict`` and similar functions. In particular, checking that
75+
features are finite (not NaN or infinite) involves a full pass over the
76+
data. If you have already ensured that your data is valid, you may suppress some
77+
validation by setting the environment variable
78+
``SKLEARN_SUPPRESS_VALIDATION`` to a non-empty string before importing
79+
scikit-learn, or configure it in Python with::
80+
81+
>>> import sklearn
82+
>>> sklearn.SUPPRESS_VALIDATION = True
83+
7184
Influence of the Number of Features
7285
-----------------------------------
7386

sklearn/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import sys
1616
import re
1717
import warnings
18+
import os
1819

20+
# Package-level switch read by the validation helpers to skip some input
# checks. It is initialised from the SKLEARN_SUPPRESS_VALIDATION environment
# variable when this module is imported. Any non-empty value enables it
# (including "0" or "false"), because bool() of a non-empty string is True.
SUPPRESS_VALIDATION = bool(os.environ.get('SKLEARN_SUPPRESS_VALIDATION',
21+
False))
1922

2023
# Make sure that DeprecationWarning within this package always gets printed
2124
warnings.filterwarnings('always', category=DeprecationWarning,

sklearn/utils/tests/test_validation.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@
2929
has_fit_parameter,
3030
check_is_fitted,
3131
check_consistent_length,
32+
assert_all_finite,
3233
)
34+
import sklearn
3335

3436
from sklearn.exceptions import NotFittedError
3537
from sklearn.exceptions import DataConversionWarning
3638

3739
from sklearn.utils.testing import assert_raise_message
3840

39-
4041
def test_as_float_array():
4142
# Test function for as_float_array
4243
X = np.ones((3, 10), dtype=np.int32)
@@ -469,3 +470,12 @@ def test_check_consistent_length():
469470
assert_raises_regexp(TypeError, 'estimator', check_consistent_length,
470471
[1, 2], RandomForestRegressor())
471472
# XXX: We should have a test with a string, but what is correct behaviour?
473+
474+
475+
def check_suppress_validation():
    """Check that ``sklearn.SUPPRESS_VALIDATION`` toggles the finiteness check.

    ``assert_all_finite`` must raise ``ValueError`` on non-finite input while
    the flag is false, and must accept the same input while the flag is true.
    The flag's previous value is restored on exit, whether or not an
    assertion fails.
    """
    X = np.array([0, np.inf])
    # Remember the ambient setting (it may have been enabled through the
    # SKLEARN_SUPPRESS_VALIDATION environment variable) so it can be restored.
    previous = sklearn.SUPPRESS_VALIDATION
    try:
        # Start from a known state instead of assuming the flag is off.
        sklearn.SUPPRESS_VALIDATION = False
        assert_raises(ValueError, assert_all_finite, X)
        sklearn.SUPPRESS_VALIDATION = True
        assert_all_finite(X)
        sklearn.SUPPRESS_VALIDATION = False
        assert_raises(ValueError, assert_all_finite, X)
    finally:
        # Never leak a modified global flag into other tests, even on failure.
        sklearn.SUPPRESS_VALIDATION = previous


def test_suppress_validation():
    # Test runners only collect ``test_*`` callables, so expose the check
    # under a collectable name; otherwise it would never be executed.
    check_suppress_validation()

sklearn/utils/validation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class NotFittedError(_NotFittedError):
4848

4949
def _assert_all_finite(X):
5050
"""Like assert_all_finite, but only for ndarray."""
51+
from .. import SUPPRESS_VALIDATION
52+
if SUPPRESS_VALIDATION:
53+
return
5154
X = np.asanyarray(X)
5255
# First try an O(n) time, O(1) space solution for the common case that
5356
# everything is finite; fall back to O(n) space np.isfinite to prevent

0 commit comments

Comments
 (0)
0