8000 Merge pull request #4826 from rvraghav93/exceptions · scikit-learn/scikit-learn@d4e9d79 · GitHub
[go: up one dir, main page]

Skip to content

Commit d4e9d79

Browse files
committed
Merge pull request #4826 from rvraghav93/exceptions
[MRG + 1] move custom error/warning classes into sklearn.exceptions (and move `deprecated` away from `utils.__init__.py`)
2 parents d271708 + 857cb09 commit d4e9d79

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+810
-615
lines changed

doc/developers/performance.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ silently dispatched to ``numpy.dot``. If you want to be sure when the original
105105
activate the related warning::
106106

107107
>>> import warnings
108-
>>> from sklearn.utils.validation import NonBLASDotWarning
108+
>>> from sklearn.exceptions import NonBLASDotWarning
109109
>>> warnings.simplefilter('always', NonBLASDotWarning) # doctest: +SKIP
110110

111111
.. _profiling-python-code:

doc/developers/utilities.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,5 +293,5 @@ Warnings and Exceptions
293293

294294
- :class:`deprecated`: Decorator to mark a function or class as deprecated.
295295

296-
- :class:`ConvergenceWarning`: Custom warning to catch convergence problems.
297-
Used in ``sklearn.covariance.graph_lasso``.
296+
- :class:`sklearn.exceptions.ConvergenceWarning`: Custom warning to catch
297+
convergence problems. Used in ``sklearn.covariance.graph_lasso``.

doc/modules/classes.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,31 @@ partial dependence
378378
ensemble.partial_dependence.plot_partial_dependence
379379

380380

381+
.. _exceptions_ref:
382+
383+
:mod:`sklea F42D rn.exceptions`: Exceptions and warnings
384+
==================================================
385+
386+
.. automodule:: sklearn.exceptions
387+
:no-members:
388+
:no-inherited-members:
389+
390+
.. currentmodule:: sklearn
391+
392+
.. autosummary::
393+
:toctree: generated/
394+
:template: class_without_init.rst
395+
396+
exceptions.NotFittedError
397+
exceptions.ChangedBehaviorWarning
398+
exceptions.ConvergenceWarning
399+
exceptions.DataConversionWarning
400+
exceptions.DataDimensionalityWarning
401+
exceptions.EfficiencyWarning
402+
exceptions.FitFailedWarning
403+
exceptions.NonBLASDotWarning
404+
exceptions.UndefinedMetricWarning
405+
381406
.. _feature_extraction_ref:
382407

383408
:mod:`sklearn.feature_extraction`: Feature Extraction

doc/templates/class_without_init.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
:mod:`{{module}}`.{{objname}}
2+
{{ underline }}==============
3+
4+
.. currentmodule:: {{ module }}
5+
6+
.. autoclass:: {{ objname }}
7+
8+
.. include:: {{module}}.{{objname}}.examples
9+
10+
.. raw:: html
11+
12+
<div class="clearer"></div>

examples/linear_model/plot_sparse_recovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from sklearn.metrics import auc, precision_recall_curve
5757
from sklearn.ensemble import ExtraTreesRegressor
5858
from sklearn.utils.extmath import pinvh
59-
from sklearn.utils import ConvergenceWarning
59+
from sklearn.exceptions import ConvergenceWarning
6060

6161

6262
def mutual_incoherence(X_relevant, X_irelevant):

sklearn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
__all__ = ['calibration', 'cluster', 'covariance', 'cross_decomposition',
6161
'cross_validation', 'datasets', 'decomposition', 'dummy',
62-
'ensemble', 'externals', 'feature_extraction',
62+
'ensemble', 'exceptions', 'externals', 'feature_extraction',
6363
'feature_selection', 'gaussian_process', 'grid_search',
6464
'isotonic', 'kernel_approximation', 'kernel_ridge',
6565
'lda', 'learning_curve',

sklearn/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@
99
from scipy import sparse
1010
from .externals import six
1111
from .utils.fixes import signature
12+
from .utils.deprecation import deprecated
13+
from .exceptions import ChangedBehaviorWarning as ChangedBehaviorWarning_
1214

1315

14-
class ChangedBehaviorWarning(UserWarning):
16+
class ChangedBehaviorWarning(ChangedBehaviorWarning_):
1517
pass
1618

19+
ChangedBehaviorWarning = deprecated("ChangedBehaviorWarning has been moved "
20+
"into the sklearn.exceptions module. "
21+
"It will not be available here from "
22+
"version 0.19")(ChangedBehaviorWarning)
23+
1724

1825
##############################################################################
1926
def clone(estimator, safe=True):

sklearn/cluster/birch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from ..externals.six.moves import xrange
1515
from ..utils import check_array
1616
from ..utils.extmath import row_norms, safe_sparse_dot
17-
from ..utils.validation import NotFittedError, check_is_fitted
17+
from ..utils.validation import check_is_fitted
18+
from ..exceptions import NotFittedError
1819
from .hierarchical import AgglomerativeClustering
1920

2021

sklearn/cluster/tests/test_k_means.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from sklearn.utils.testing import assert_raise_message
2121

2222

23-
from sklearn.utils.validation import DataConversionWarning
2423
from sklearn.utils.extmath import row_norms
2524
from sklearn.metrics.cluster import v_measure_score
2625
from sklearn.cluster import KMeans, k_means
@@ -29,6 +28,7 @@
2928
from sklearn.cluster.k_means_ import _mini_batch_step
3029
from sklearn.datasets.samples_generator import make_blobs
3130
from sklearn.externals.six.moves import cStringIO as StringIO
31+
from sklearn.exceptions import DataConversionWarning
3232

3333

3434
# non centered, sparse centers to check the

sklearn/covariance/graph_lasso_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .empirical_covariance_ import (empirical_covariance, EmpiricalCovariance,
1717
log_likelihood)
1818

19-
from ..utils import ConvergenceWarning
19+
from ..exceptions import ConvergenceWarning
2020
from ..utils.extmath import pinvh
2121
from ..utils.validation import check_random_state, check_array
2222
from ..linear_model import lars_path

sklearn/covariance/tests/test_robust_covariance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from sklearn.utils.testing import assert_almost_equal
1010
from sklearn.utils.testing import assert_array_almost_equal
11-
from sklearn.utils.testing import assert_raises, assert_warns
11+
from sklearn.utils.testing import assert_raises
1212
from sklearn.utils.testing import assert_raise_message
13-
f 10000 rom sklearn.utils.validation import NotFittedError
13+
from sklearn.exceptions import NotFittedError
1414

1515
from sklearn import datasets
1616
from sklearn.covariance import empirical_covariance, MinCovDet, \

sklearn/cross_validation.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .metrics.scorer import check_scoring
3333
from .utils.fixes import bincount
3434
from .gaussian_process.kernels import Kernel as GPKernel
35+
from .exceptions import FitFailedWarning
3536

3637
__all__ = ['KFold',
3738
'LabelKFold',
@@ -1428,10 +1429,6 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
14281429
return np.array(scores)[:, 0]
14291430

14301431

1431-
class FitFailedWarning(RuntimeWarning):
1432-
pass
1433-
1434-
14351432
def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
14361433
parameters, fit_params, return_train_score=False,
14371434
return_parameters=False, error_score='raise'):

sklearn/decomposition/factor_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ..utils import check_array, check_random_state
3131
from ..utils.extmath import fast_logdet, fast_dot, randomized_svd, squared_norm
3232
from ..utils.validation import check_is_fitted
33-
from ..utils import ConvergenceWarning
33+
from ..exceptions import ConvergenceWarning
3434

3535

3636
class FactorAnalysis(BaseEstimator, TransformerMixin):

sklearn/decomposition/kernel_pca.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from scipy import linalg
88

99
from ..utils.arpack import eigsh
10-
from ..utils.validation import check_is_fitted, NotFittedError
10+
from ..utils.validation import check_is_fitted
11+
from ..exceptions import NotFittedError
1112
from ..base import BaseEstimator, TransformerMixin
1213
from ..preprocessing import KernelCenterer
1314
from ..metrics.pairwise import pairwise_kernels

sklearn/decomposition/nmf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..utils.extmath import fast_dot
2727
from ..utils.validation import check_is_fitted, check_non_negative
2828
from ..utils import deprecated
29-
from ..utils import ConvergenceWarning
29+
from ..exceptions import ConvergenceWarning
3030
from .cdnmf_fast import _update_cdnmf_fast
3131

3232

sklearn/decomposition/online_lda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
from ..base import BaseEstimator, TransformerMixin
1919
from ..utils import (check_random_state, check_array,
2020
gen_batches, gen_even_slices, _get_n_jobs)
21-
from ..utils.validation import NotFittedError, check_non_negative
21+
from ..utils.validation import check_non_negative
2222
from ..utils.extmath import logsumexp
2323
from ..externals.joblib import Parallel, delayed
2424
from ..externals.six.moves import xrange
25+
from ..exceptions import NotFittedError
2526

2627
from ._online_lda import (mean_change, _dirichlet_expectation_1d,
2728
_dirichlet_expectation_2d)

sklearn/decomposition/tests/test_factor_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.utils.testing import assert_raises
1212
from sklearn.utils.testing import assert_almost_equal
1313
from sklearn.utils.testing import assert_array_almost_equal
14-
from sklearn.utils import ConvergenceWarning
14+
from sklearn.exceptions import ConvergenceWarning
1515
from sklearn.decomposition import FactorAnalysis
1616

1717

sklearn/decomposition/tests/test_online_lda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sklearn.utils.testing import assert_raises_regexp
1616
from sklearn.utils.testing import if_safe_multiprocessing_with_blas
1717

18-
from sklearn.utils.validation import NotFittedError
18+
from sklearn.exceptions import NotFittedError
1919
from sklearn.externals.six.moves import xrange
2020

2121

sklearn/ensemble/forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
5959
ExtraTreeClassifier, ExtraTreeRegressor)
6060
from ..tree._tree import DTYPE, DOUBLE
6161
from ..utils import check_random_state, check_array, compute_sample_weight
62-
from ..utils.validation import DataConversionWarning, NotFittedError
62+
from ..exceptions import DataConversionWarning, NotFittedError
6363
from .base import BaseEnsemble, _partition_estimators
6464
from ..utils.fixes import bincount
6565
from ..utils.multiclass import check_classification_targets

sklearn/ensemble/gradient_boosting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161
from ..utils.fixes import bincount
6262
from ..utils.stats import _weighted_percentile
6363
from ..utils.validation import check_is_fitted
64-
from ..utils.validation import NotFittedError
6564
from ..utils.multiclass import check_classification_targets
65+
from ..exceptions import NotFittedError
6666

6767

6868
class QuantileEstimator(BaseEstimator):

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@
2626
from sklearn.utils.testing import assert_raises
2727
from sklearn.utils.testing import assert_true
2828
from sklearn.utils.testing import assert_warns
29-
from sklearn.utils.testing import ignore_warnings
30-
from sklearn.utils.validation import DataConversionWarning
31-
from sklearn.utils.validation import NotFittedError
29+
from sklearn.exceptions import DataConversionWarning
30+
from sklearn.exceptions import NotFittedError
3231

3332
# toy sample
3433
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]

sklearn/exceptions.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
The :mod:`sklearn.exceptions` module includes all custom warnings and error
3+
classes used across scikit-learn.
4+
"""
5+
6+
__all__ = ['NotFittedError',
7+
'ChangedBehaviorWarning',
8+
'ConvergenceWarning',
9+
'DataConversionWarning',
10+
'DataDimensionalityWarning',
11+
'EfficiencyWarning',
12+
'FitFailedWarning',
13+
'NonBLASDotWarning',
14+
'UndefinedMetricWarning']
15+
16+
17+
class NotFittedError(ValueError, AttributeError):
18+
"""Exception class to raise if estimator is used before fitting.
19+
20+
This class inherits from both ValueError and AttributeError to help with
21+
exception handling and backward compatibility.
22+
23+
Examples
24+
--------
25+
>>> from sklearn.svm import LinearSVC
26+
>>> from sklearn.exceptions import NotFittedError
27+
>>> try:
28+
... LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
29+
... except NotFittedError as e:
30+
... print(repr(e))
31+
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
32+
NotFittedError('This LinearSVC instance is not fitted yet',)
33+
"""
34+
35+
36+
class ChangedBehaviorWarning(UserWarning):
37+
"""Warning class used to notify the user of any change in the behavior."""
38+
39+
40+
class ConvergenceWarning(UserWarning):
41+
"""Custom warning to capture convergence problems"""
42+
43+
44+
class DataConversionWarning(UserWarning):
45+
"""Warning used to notify implicit data conversions happening in the code.
46+
47+
This warning occurs when some input data needs to be converted or
48+
interpreted in a way that may not match the user's expectations.
49+
50+
For example, this warning may occur when the the user
51+
- passes an integer array to a function which expects float input and
52+
will convert the input
53+
- requests a non-copying operation, but a copy is required to meet the
54+
implementation's data-type expectations;
55+
- passes an input whose shape can be interpreted ambiguously.
56+
"""
57+
58+
59+
class DataDimensionalityWarning(UserWarning):
60+
"""Custom warning to notify potential issues with data dimensionality.
61+
62+
For example, in random projection, this warning is raised when the
63+
number of components, which quantifes the dimensionality of the target
64+
projection space, is higher than the number of features, which quantifies
65+
the dimensionality of the original source space, to imply that the
66+
dimensionality of the problem will not be reduced.
67+
"""
68+
69+
70+
class EfficiencyWarning(UserWarning):
71+
"""Warning used to notify the user of inefficient computation.
72+
73+
This warning notifies the user that the efficiency may not be optimal due
74+
to some reason which may be included as a part of the warning message.
75+
This may be subclassed into a more specific Warning class.
76+
"""
77+
78+
79+
class FitFailedWarning(RuntimeWarning):
80+
"""Warning class used if there is an error while fitting the estimator.
81+
82+
This Warning is used in meta estimators GridSearchCV and RandomizedSearchCV
83+
and the cross-validation helper function cross_val_score to warn when there
84+
is an error while fitting the estimator.
85+
86+
Examples
87+
--------
88+
>>> from sklearn.grid_search import GridSearchCV
89+
>>> from sklearn.svm import LinearSVC
90+
>>> from sklearn.exceptions import FitFailedWarning
91+
>>> import warnings
92+
>>> warnings.simplefilter('always', FitFailedWarning)
93+
>>> gs = GridSearchCV(LinearSVC(), {'C': [-1, -2]}, error_score=0)
94+
>>> X, y = [[1, 2], [3, 4], [5, 6], [7, 8], [8, 9]], [0, 0, 0, 1, 1]
95+
>>> with warnings.catch_warnings(record=True) as w:
96+
... try:
97+
... gs.fit(X, y) # This will raise a ValueError since C is < 0
98+
... except ValueError:
99+
... pass
100+
... print(repr(w[-1].message))
101+
... # doctest: +NORMALIZE_WHITESPACE
102+
FitFailedWarning("Classifier fit failed. The score on this train-test
103+
partition for these parameters will be set to 0.000000. Details:
104+
\\nValueError('Penalty term must be positive; got (C=-2)',)",)
105+
"""
106+
107+
108+
class NonBLASDotWarning(EfficiencyWarning):
109+
"""Warning used when the dot operation does not use BLAS.
110+
111+
This warning is used to notify the user that BLAS was not used for dot
112+
operation and hence the efficiency may be affected.
113+
"""
114+
115+
116+
class UndefinedMetricWarning(UserWarning):
117+
pass

sklearn/feature_selection/from_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import numpy as np
55

66
from .base import SelectorMixin
7-
from ..base import (TransformerMixin, BaseEstimator, clone,
8-
MetaEstimatorMixin)
7+
from ..base import TransformerMixin, BaseEstimator, clone
98
from ..externals import six
109

1110
from ..utils import safe_mask, check_array, deprecated
12-
from ..utils.validation import NotFittedError, check_is_fitted
11+
from ..utils.validation import check_is_fitted
12+
from ..exceptions import NotFittedError
1313

1414

1515
def _get_feature_importances(estimator):

0 commit comments

Comments
 (0)
0