@@ -249,8 +249,13 @@ class RBFSampler(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimato
249
249
250
250
Parameters
251
251
----------
252
- gamma : float, default=1.0
252
+ gamma : 'scale' or float, default=1.0
253
253
Parameter of RBF kernel: exp(-gamma * x^2).
254
+ If ``gamma='scale'`` is passed then it uses
255
+ 1 / (n_features * X.var()) as the value of gamma.
256
+
257
+ .. versionadded:: 1.2
258
+ The option `"scale"` was added in 1.2.
254
259
255
260
n_components : int, default=100
256
261
Number of Monte Carlo samples per original feature.
@@ -319,7 +324,10 @@ class RBFSampler(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimato
319
324
"""
320
325
321
326
_parameter_constraints : dict = {
322
- "gamma" : [Interval (Real , 0 , None , closed = "left" )],
327
+ "gamma" : [
328
+ StrOptions ({"scale" }),
329
+ Interval (Real , 0.0 , None , closed = "left" ),
330
+ ],
323
331
"n_components" : [Interval (Integral , 1 , None , closed = "left" )],
324
332
"random_state" : ["random_state" ],
325
333
}
@@ -354,8 +362,14 @@ def fit(self, X, y=None):
354
362
X = self ._validate_data (X , accept_sparse = "csr" )
355
363
random_state = check_random_state (self .random_state )
356
364
n_features = X .shape [1 ]
357
-
358
- self .random_weights_ = np .sqrt (2 * self .gamma ) * random_state .normal (
365
+ sparse = sp .isspmatrix (X )
366
+ if self .gamma == "scale" :
367
+ # for sparse X, compute the variance as E[X^2] - E[X]^2
368
+ X_var = (X .multiply (X )).mean () - (X .mean ()) ** 2 if sparse else X .var ()
369
+ self ._gamma = 1.0 / (n_features * X_var ) if X_var != 0 else 1.0
370
+ else :
371
+ self ._gamma = self .gamma
372
+ self .random_weights_ = (2.0 * self ._gamma ) ** 0.5 * random_state .normal (
359
373
size = (n_features , self .n_components )
360
374
)
361
375
@@ -390,7 +404,7 @@ def transform(self, X):
390
404
projection = safe_sparse_dot (X , self .random_weights_ )
391
405
projection += self .random_offset_
392
406
np .cos (projection , projection )
393
- projection *= np . sqrt (2.0 ) / np . sqrt ( self .n_components )
407
+ projection *= (2.0 / self .n_components ) ** 0.5
394
408
return projection
395
409
396
410
def _more_tags (self ):
0 commit comments