14
14
15
15
# avoid division truncation
16
16
import warnings
17
+ from numbers import Real
17
18
import numpy as np
18
19
19
20
from . import empirical_covariance , EmpiricalCovariance
20
21
from .._config import config_context
21
22
from ..utils import check_array
23
+ from ..utils ._param_validation import Interval
22
24
23
25
24
26
# ShrunkCovariance estimator
@@ -145,6 +147,11 @@ class ShrunkCovariance(EmpiricalCovariance):
145
147
array([0.0622..., 0.0193...])
146
148
"""
147
149
150
+ _parameter_constraints = {
151
+ ** EmpiricalCovariance ._parameter_constraints ,
152
+ "shrinkage" : [Interval (Real , 0 , 1 , closed = "both" )],
153
+ }
154
+
148
155
def __init__ (self , * , store_precision = True , assume_centered = False , shrinkage = 0.1 ):
149
156
super ().__init__ (
150
157
store_precision = store_precision , assume_centered = assume_centered
@@ -168,6 +175,7 @@ def fit(self, X, y=None):
168
175
self : object
169
176
Returns the instance itself.
170
177
"""
178
+ self ._validate_params ()
171
179
X = self ._validate_data (X )
172
180
# Not calling the parent object to fit, to avoid a potential
173
181
# matrix inversion when setting the precision
0 commit comments