8000 ENH Add gamma='scale' option to RBFSampler (#24755) · scikit-learn/scikit-learn@61ae92a · GitHub
[go: up one dir, main page]

Skip to content

Commit 61ae92a

Browse files
glevv and glemaitre authored
ENH Add gamma='scale' option to RBFSampler (#24755)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 0b978cb commit 61ae92a

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,10 @@ Changelog
350350
- |Enhancement| :class:`kernel_approximation.SkewedChi2Sampler` now preserves
351351
dtype for `numpy.float32` inputs. :pr:`24350` by :user:`Rahil Parikh <rprkh>`.
352352

353+
- |Enhancement| :class:`kernel_approximation.RBFSampler` now accepts
354+
`'scale'` option for parameter `gamma`.
355+
:pr:`24755` by :user:`Gleb Levitski <GLevV>`
356+
353357
:mod:`sklearn.linear_model`
354358
...........................
355359

sklearn/kernel_approximation.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,13 @@ class RBFSampler(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimato
249249
250250
Parameters
251251
----------
252-
gamma : float, default=1.0
252+
gamma : 'scale' or float, default=1.0
253253
Parameter of RBF kernel: exp(-gamma * x^2).
254+
If ``gamma='scale'`` is passed then it uses
255+
1 / (n_features * X.var()) as value of gamma.
256+
257+
.. versionadded:: 1.2
258+
The option `"scale"` was added in 1.2.
254259
255260
n_components : int, default=100
256261
Number of Monte Carlo samples per original feature.
@@ -319,7 +324,10 @@ class RBFSampler(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimato
319324
"""
320325

321326
_parameter_constraints: dict = {
322-
"gamma": [Interval(Real, 0, None, closed="left")],
327+
"gamma": [
328+
StrOptions({"scale"}),
329+
Interval(Real, 0.0, None, closed="left"),
330+
],
323331
"n_components": [Interval(Integral, 1, None, closed="left")],
324332
"random_state": ["random_state"],
325333
}
@@ -354,8 +362,14 @@ def fit(self, X, y=None):
354362
X = self._validate_data(X, accept_sparse="csr")
355363
random_state = check_random_state(self.random_state)
356364
n_features = X.shape[1]
357-
358-
self.random_weights_ = np.sqrt(2 * self.gamma) * random_state.normal(
365+
sparse = sp.isspmatrix(X)
366+
if self.gamma == "scale":
367+
# var = E[X^2] - E[X]^2 if sparse
368+
X_var = (X.multiply(X)).mean() - (X.mean()) ** 2 if sparse else X.var()
369+
self._gamma = 1.0 / (n_features * X_var) if X_var != 0 else 1.0
370+
else:
371+
self._gamma = self.gamma
372+
self.random_weights_ = (2.0 * self._gamma) ** 0.5 * random_state.normal(
359373
size=(n_features, self.n_components)
360374
)
361375

@@ -390,7 +404,7 @@ def transform(self, X):
390404
projection = safe_sparse_dot(X, self.random_weights_)
391405
projection += self.random_offset_
392406
np.cos(projection, projection)
393-
projection *= np.sqrt(2.0) / np.sqrt(self.n_components)
407+
projection *= (2.0 / self.n_components) ** 0.5
394408
return projection
395409

396410
def _more_tags(self):

sklearn/tests/test_kernel_approximation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,14 @@ def test_rbf_sampler_dtype_equivalence():
242242
assert_allclose(rbf32.random_weights_, rbf64.random_weights_)
243243

244244

245+
def test_rbf_sampler_gamma_scale():
246+
"""Check the inner value computed when `gamma='scale'`."""
247+
X, y = [[0.0], [1.0]], [0, 1]
248+
rbf = RBFSampler(gamma="scale")
249+
rbf.fit(X, y)
250+
assert rbf._gamma == pytest.approx(4)
251+
252+
245253
def test_skewed_chi2_sampler_fitted_attributes_dtype(global_dtype):
246254
"""Check that the fitted attributes are stored accordingly to the
247255
data type of X."""

0 commit comments

Comments
 (0)
0