Merge remote-tracking branch 'upstream/master' into pr/19102 · scikit-learn/scikit-learn@a7ccf62 · GitHub

Commit a7ccf62

Merge remote-tracking branch 'upstream/master' into pr/19102
2 parents 71f295c + dfc5e16 commit a7ccf62

10 files changed: +257 −29 lines changed

doc/whats_new/v1.0.rst

Lines changed: 11 additions & 0 deletions

@@ -62,6 +62,17 @@ Changelog
 - |Fix| :meth:`ElasticNet.fit` no longer modifies `sample_weight` in place.
   :pr:`19055` by `Thomas Fan`_.

+- |Enhancement| Validate user-supplied gram matrix passed to linear models
+  via the `precompute` argument. :pr:`19004` by :user:`Adam Midvidy <amidvidy>`.
+
+:mod:`sklearn.naive_bayes`
+..........................
+
+- |API| The attribute ``sigma_`` is now deprecated in
+  :class:`naive_bayes.GaussianNB` and will be removed in 1.2.
+  Use ``var_`` instead.
+  :pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
+

 Code and Documentation Contributors
 -----------------------------------

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+"""
+==========================================================================
+Fitting an Elastic Net with a precomputed Gram Matrix and Weighted Samples
+==========================================================================
+
+The following example shows how to precompute the gram matrix
+while using weighted samples with an ElasticNet.
+
+If weighted samples are used, the design matrix must be centered and then
+rescaled by the square root of the weight vector before the gram matrix
+is computed.
+
+.. note::
+  The `sample_weight` vector is also rescaled to sum to `n_samples`; see the
+  documentation for the `sample_weight` parameter to
+  :func:`linear_model.ElasticNet.fit`.
+
+"""
+
+print(__doc__)
+
+# %%
+# Let's start by loading the dataset and creating some sample weights.
+import numpy as np
+from sklearn.datasets import make_regression
+
+rng = np.random.RandomState(0)
+
+n_samples = int(1e5)
+X, y = make_regression(n_samples=n_samples, noise=0.5, random_state=rng)
+
+sample_weight = rng.lognormal(size=n_samples)
+# normalize the sample weights
+normalized_weights = sample_weight * (n_samples / (sample_weight.sum()))
+
+# %%
+# To fit the elastic net using the `precompute` option together with the
+# sample weights, we must first center the design matrix, then rescale it by
+# the normalized weights prior to computing the gram matrix.
+X_offset = np.average(X, axis=0, weights=normalized_weights)
+X_centered = X - X_offset
+X_scaled = X_centered * np.sqrt(normalized_weights)[:, np.newaxis]
+gram = np.dot(X_scaled.T, X_scaled)
+
+# %%
+# We can now proceed with fitting. We must pass the centered design matrix to
+# `fit`, otherwise the elastic net estimator will detect that it is uncentered
+# and discard the gram matrix we passed. However, if we pass the scaled design
+# matrix, the preprocessing code will incorrectly rescale it a second time.
+from sklearn.linear_model import ElasticNet
+
+lm = ElasticNet(alpha=0.01, precompute=gram)
+lm.fit(X_centered, y, sample_weight=normalized_weights)
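As a sanity check, the fit above can be compared against an ordinary weighted fit on the raw design matrix; the two should agree. A sketch, reusing `X`, `y`, `gram`, `X_centered` and `normalized_weights` from the example and mirroring the new test added in this commit:

import numpy as np
from sklearn.linear_model import ElasticNet

# Fit once with the precomputed gram on the centered data...
lm_gram = ElasticNet(alpha=0.01, precompute=gram)
lm_gram.fit(X_centered, y, sample_weight=normalized_weights)

# ...and once letting the estimator do all preprocessing internally.
lm_plain = ElasticNet(alpha=0.01, precompute=False)
lm_plain.fit(X, y, sample_weight=normalized_weights)

# The coefficients should match to numerical precision.
np.testing.assert_allclose(lm_gram.coef_, lm_plain.coef_)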

sklearn/linear_model/_base.py

Lines changed: 72 additions & 10 deletions

@@ -37,6 +37,7 @@
 from ..utils._seq_dataset import ArrayDataset32, CSRDataset32
 from ..utils._seq_dataset import ArrayDataset64, CSRDataset64
 from ..utils.validation import check_is_fitted, _check_sample_weight
+
 from ..utils.fixes import delayed
 from ..preprocessing import normalize as f_normalize

@@ -570,6 +571,61 @@ def rmatvec(b):
         return self


+def _check_precomputed_gram_matrix(X, precompute, X_offset, X_scale,
+                                   rtol=1e-7, atol=1e-5):
+    """Compute a single element of the gram matrix and compare it to
+    the corresponding element of the user-supplied gram matrix.
+
+    If the values do not match, a ValueError is raised.
+
+    Parameters
+    ----------
+    X : ndarray of shape (n_samples, n_features)
+        Data array.
+
+    precompute : array-like of shape (n_features, n_features)
+        User-supplied gram matrix.
+
+    X_offset : ndarray of shape (n_features,)
+        Array of feature means used to center the design matrix.
+
+    X_scale : ndarray of shape (n_features,)
+        Array of feature scale factors used to normalize the design matrix.
+
+    rtol : float, default=1e-7
+        Relative tolerance; see :func:`numpy.allclose`.
+
+    atol : float, default=1e-5
+        Absolute tolerance; see :func:`numpy.allclose`. Note that the default
+        here is more tolerant than the default for
+        :func:`numpy.testing.assert_allclose`, where `atol=0`.
+
+    Raises
+    ------
+    ValueError
+        Raised when the provided Gram matrix is not consistent.
+    """
+
+    n_features = X.shape[1]
+    f1 = n_features // 2
+    f2 = min(f1 + 1, n_features - 1)
+
+    v1 = (X[:, f1] - X_offset[f1]) * X_scale[f1]
+    v2 = (X[:, f2] - X_offset[f2]) * X_scale[f2]
+
+    expected = np.dot(v1, v2)
+    actual = precompute[f1, f2]
+
+    if not np.isclose(expected, actual, rtol=rtol, atol=atol):
+        raise ValueError("Gram matrix passed in via 'precompute' parameter "
+                         "did not pass validation when a single element was "
+                         "checked - please check that it was computed "
+                         f"properly. For element ({f1},{f2}) we computed "
+                         f"{expected} but the user-supplied value was "
+                         f"{actual}.")
+
+
 def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy,
              check_input=True, sample_weight=None):
     """Aux function used at beginning of fit in linear models

@@ -595,16 +651,22 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy,
                          check_input=check_input, sample_weight=sample_weight)
     if sample_weight is not None:
         X, y = _rescale_data(X, y, sample_weight=sample_weight)
-    if hasattr(precompute, '__array__') and (
-            fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or
-            normalize and not np.allclose(X_scale, np.ones(n_features))):
-        warnings.warn("Gram matrix was provided but X was centered"
-                      " to fit intercept, "
-                      "or X was normalized : recomputing Gram matrix.",
-                      UserWarning)
-        # recompute Gram
-        precompute = 'auto'
-        Xy = None
+    if hasattr(precompute, '__array__'):
+        if (fit_intercept and not np.allclose(X_offset, np.zeros(n_features))
+                or normalize and not np.allclose(X_scale,
+                                                 np.ones(n_features))):
+            warnings.warn(
+                "Gram matrix was provided but X was centered to fit "
+                "intercept, or X was normalized: recomputing Gram matrix.",
+                UserWarning
+            )
+            # recompute Gram
+            precompute = 'auto'
+            Xy = None
+        elif check_input:
+            # If we're going to use the user's precomputed gram matrix, we
+            # do a quick check to make sure it's not totally bogus.
+            _check_precomputed_gram_matrix(X, precompute, X_offset, X_scale)

     # precompute if n_samples > n_features
     if isinstance(precompute, str) and precompute == 'auto':
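The check is deliberately cheap: it recomputes a single off-diagonal Gram entry from `X` (one dot product over `n_samples`) instead of the full `n_features × n_features` product. A standalone sketch of the same idea, with `spot_check_gram` as an illustrative name rather than scikit-learn API:

import numpy as np

def spot_check_gram(X, gram, rtol=1e-7, atol=1e-5):
    # Recompute one off-diagonal entry of X.T @ X and compare it to the
    # candidate matrix; a mismatch means `gram` is inconsistent with X.
    n_features = X.shape[1]
    f1 = n_features // 2
    f2 = min(f1 + 1, n_features - 1)
    expected = np.dot(X[:, f1], X[:, f2])
    if not np.isclose(expected, gram[f1, f2], rtol=rtol, atol=atol):
        raise ValueError(f"gram[{f1}, {f2}] is {gram[f1, f2]}, "
                         f"expected {expected}")

rng = np.random.RandomState(0)
X = rng.randn(100, 5)
spot_check_gram(X, X.T @ X)      # consistent: passes silently
# spot_check_gram(X, np.eye(5))  # inconsistent: raises ValueError

A single element can of course miss errors elsewhere in the matrix, which is why the error message describes it as validation of a single element rather than a full check.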

sklearn/linear_model/_coordinate_descent.py

Lines changed: 2 additions & 1 deletion

@@ -729,7 +729,8 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Target. Will be cast to X's dtype if necessary.

         sample_weight : float or array-like of shape (n_samples,), default=None
-            Sample weight.
+            Sample weight. Internally, the `sample_weight` vector will be
+            rescaled to sum to `n_samples`.

             .. versionadded:: 0.23
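Concretely, the rescaling preserves the relative weights and changes only their overall scale, exactly as in the normalization step of the new example above; a sketch of the arithmetic:

import numpy as np

sample_weight = np.array([0.5, 1.5, 2.0])
n_samples = sample_weight.shape[0]
# Rescale so the weights sum to n_samples.
rescaled = sample_weight * (n_samples / sample_weight.sum())
assert np.isclose(rescaled.sum(), n_samples)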

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 39 additions & 0 deletions

@@ -743,6 +743,45 @@ def test_precompute_invalid_argument():
                          "Got 'auto'", ElasticNet(precompute='auto').fit, X, y)


+def test_elasticnet_precompute_incorrect_gram():
+    # check that passing an invalid precomputed Gram matrix will raise an
+    # error.
+    X, y, _, _ = build_dataset()
+
+    rng = np.random.RandomState(0)
+
+    X_centered = X - np.average(X, axis=0)
+    garbage = rng.standard_normal(X.shape)
+    precompute = np.dot(garbage.T, garbage)
+
+    clf = ElasticNet(alpha=0.01, precompute=precompute)
+    msg = "Gram matrix.*did not pass validation.*"
+    with pytest.raises(ValueError, match=msg):
+        clf.fit(X_centered, y)
+
+
+def test_elasticnet_precompute_gram_weighted_samples():
+    # check the equivalence between passing a precomputed Gram matrix and
+    # internal computation using sample weights.
+    X, y, _, _ = build_dataset()
+
+    rng = np.random.RandomState(0)
+    sample_weight = rng.lognormal(size=y.shape)
+
+    w_norm = sample_weight * (y.shape[0] / np.sum(sample_weight))
+    X_c = X - np.average(X, axis=0, weights=w_norm)
+    X_r = X_c * np.sqrt(w_norm)[:, np.newaxis]
+    gram = np.dot(X_r.T, X_r)
+
+    clf1 = ElasticNet(alpha=0.01, precompute=gram)
+    clf1.fit(X_c, y, sample_weight=sample_weight)
+
+    clf2 = ElasticNet(alpha=0.01, precompute=False)
+    clf2.fit(X, y, sample_weight=sample_weight)
+
+    assert_allclose(clf1.coef_, clf2.coef_)
+
+
 def test_warm_start_convergence():
     X, y, _, _ = build_dataset()
     model = ElasticNet(alpha=1e-3, tol=1e-3).fit(X, y)

sklearn/metrics/tests/test_common.py

Lines changed: 1 addition & 1 deletion

@@ -198,7 +198,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     return np.array([
         precision,
         recall,
-        np.pad(thresholds,
+        np.pad(thresholds.astype(np.float64),
                pad_width=(0, pad_threshholds),
                mode='constant',
                constant_values=[np.nan])
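The cast matters because `np.pad` with `constant_values=np.nan` fails on integer arrays: NaN has no integer representation. A sketch:

import numpy as np

thresholds = np.array([1, 2, 3])  # integer dtype
# np.pad(thresholds, (0, 2), mode='constant', constant_values=np.nan)
# would raise ValueError: cannot convert float NaN to integer
padded = np.pad(thresholds.astype(np.float64), (0, 2),
                mode='constant', constant_values=np.nan)
print(padded)  # [ 1.  2.  3. nan nan]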

sklearn/naive_bayes.py

Lines changed: 25 additions & 8 deletions

@@ -154,7 +154,16 @@ class labels known to the classifier
         absolute additive value to variances

     sigma_ : ndarray of shape (n_classes, n_features)
-        variance of each feature per class
+        Variance of each feature per class.
+
+        .. deprecated:: 1.0
+            `sigma_` is deprecated in 1.0 and will be removed in 1.2.
+            Use `var_` instead.
+
+    var_ : ndarray of shape (n_classes, n_features)
+        Variance of each feature per class.
+
+        .. versionadded:: 1.0

     theta_ : ndarray of shape (n_classes, n_features)
         mean of each feature per class

@@ -377,7 +386,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
             n_features = X.shape[1]
             n_classes = len(self.classes_)
             self.theta_ = np.zeros((n_classes, n_features))
-            self.sigma_ = np.zeros((n_classes, n_features))
+            self.var_ = np.zeros((n_classes, n_features))

             self.class_count_ = np.zeros(n_classes, dtype=np.float64)

@@ -405,7 +414,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
                 msg = "Number of features %d does not match previous data %d."
                 raise ValueError(msg % (X.shape[1], self.theta_.shape[1]))
             # Put epsilon back in each time
-            self.sigma_[:, :] -= self.epsilon_
+            self.var_[:, :] -= self.epsilon_

         classes = self.classes_

@@ -429,14 +438,14 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
             N_i = X_i.shape[0]

             new_theta, new_sigma = self._update_mean_variance(
-                self.class_count_[i], self.theta_[i, :], self.sigma_[i, :],
+                self.class_count_[i], self.theta_[i, :], self.var_[i, :],
                 X_i, sw_i)

             self.theta_[i, :] = new_theta
-            self.sigma_[i, :] = new_sigma
+            self.var_[i, :] = new_sigma
             self.class_count_[i] += N_i

-        self.sigma_[:, :] += self.epsilon_
+        self.var_[:, :] += self.epsilon_

        # Update only if no priors were provided
        if self.priors is None:

@@ -449,14 +458,22 @@ def _joint_log_likelihood(self, X):
         joint_log_likelihood = []
         for i in range(np.size(self.classes_)):
             jointi = np.log(self.class_prior_[i])
-            n_ij = - 0.5 * np.sum(np.log(2. * np.pi * self.sigma_[i, :]))
+            n_ij = - 0.5 * np.sum(np.log(2. * np.pi * self.var_[i, :]))
             n_ij -= 0.5 * np.sum(((X - self.theta_[i, :]) ** 2) /
-                                 (self.sigma_[i, :]), 1)
+                                 (self.var_[i, :]), 1)
             joint_log_likelihood.append(jointi + n_ij)

         joint_log_likelihood = np.array(joint_log_likelihood).T
         return joint_log_likelihood

+    @deprecated(  # type: ignore
+        "Attribute sigma_ was deprecated in 1.0 and will be removed in "
+        "1.2. Use var_ instead."
+    )
+    @property
+    def sigma_(self):
+        return self.var_
+

 _ALPHA_MIN = 1e-10
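From the caller's side the rename is transparent until the alias is removed in 1.2. A sketch of what reading the deprecated attribute looks like (assuming, as with other scikit-learn deprecations, that the `deprecated` decorator emits a `FutureWarning`):

import warnings
from sklearn.naive_bayes import GaussianNB

clf = GaussianNB().fit([[1.0], [2.0], [3.0], [4.0]], [0, 0, 1, 1])

var = clf.var_  # new attribute, added in 1.0

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sigma = clf.sigma_  # deprecated alias of var_, removed in 1.2
assert caught and (sigma == var).all()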

sklearn/neighbors/tests/test_dist_metrics.py

Lines changed: 26 additions & 2 deletions

@@ -55,7 +55,19 @@ def test_cdist(metric):
     keys = argdict.keys()
     for vals in itertools.product(*argdict.values()):
         kwargs = dict(zip(keys, vals))
-        D_true = cdist(X1, X2, metric, **kwargs)
+        if metric == "wminkowski":
+            if sp_version >= parse_version("1.8.0"):
+                pytest.skip("wminkowski will be removed in SciPy 1.8.0")
+
+            # wminkowski is deprecated in SciPy 1.6.0 and removed in 1.8.0
+            ExceptionToAssert = None
+            if sp_version >= parse_version("1.6.0"):
+                ExceptionToAssert = DeprecationWarning
+            with pytest.warns(ExceptionToAssert):
+                D_true = cdist(X1, X2, metric, **kwargs)
+        else:
+            D_true = cdist(X1, X2, metric, **kwargs)
+
         check_cdist(metric, kwargs, D_true)


@@ -83,7 +95,19 @@ def test_pdist(metric):
     keys = argdict.keys()
     for vals in itertools.product(*argdict.values()):
         kwargs = dict(zip(keys, vals))
-        D_true = cdist(X1, X1, metric, **kwargs)
+        if metric == "wminkowski":
+            if sp_version >= parse_version("1.8.0"):
+                pytest.skip("wminkowski will be removed in SciPy 1.8.0")
+
+            # wminkowski is deprecated in SciPy 1.6.0 and removed in 1.8.0
+            ExceptionToAssert = None
+            if sp_version >= parse_version("1.6.0"):
+                ExceptionToAssert = DeprecationWarning
+            with pytest.warns(ExceptionToAssert):
+                D_true = cdist(X1, X1, metric, **kwargs)
+        else:
+            D_true = cdist(X1, X1, metric, **kwargs)
+
         check_pdist(metric, kwargs, D_true)
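The skip/warn/call gate appears twice here and again in test_neighbors.py below, so the pattern could be factored into a small helper. A sketch, with `call_deprecated_metric` as a hypothetical name:

import pytest
from sklearn.utils.fixes import sp_version, parse_version

def call_deprecated_metric(func, *args, **kwargs):
    # Skip where wminkowski is gone, assert its DeprecationWarning where
    # it is merely deprecated, and call it plainly on older SciPy.
    if sp_version >= parse_version("1.8.0"):
        pytest.skip("wminkowski will be removed in SciPy 1.8.0")
    if sp_version >= parse_version("1.6.0"):
        with pytest.warns(DeprecationWarning):
            return func(*args, **kwargs)
    return func(*args, **kwargs)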

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 14 additions & 2 deletions

@@ -26,6 +26,7 @@
 from sklearn.utils._testing import assert_raise_message
 from sklearn.utils._testing import ignore_warnings
 from sklearn.utils.validation import check_random_state
+from sklearn.utils.fixes import sp_version, parse_version

 import joblib

@@ -1244,6 +1245,9 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
     test = rng.rand(n_query_pts, n_features)

     for metric, metric_params in metrics:
+        if metric == "wminkowski" and sp_version >= parse_version("1.8.0"):
+            # wminkowski will be removed in SciPy 1.8.0
+            continue
         results = {}
         p = metric_params.pop('p', 2)
         for algorithm in algorithms:

@@ -1265,8 +1269,16 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
                           if metric == 'haversine' else slice(None))

             neigh.fit(X[:, feature_sl])
-            results[algorithm] = neigh.kneighbors(test[:, feature_sl],
-                                                  return_distance=True)
+
+            # wminkowski is deprecated in SciPy 1.6.0 and removed in 1.8.0
+            ExceptionToAssert = None
+            if (metric == "wminkowski" and algorithm == 'brute'
+                    and sp_version >= parse_version("1.6.0")):
+                ExceptionToAssert = DeprecationWarning
+
+            with pytest.warns(ExceptionToAssert):
+                results[algorithm] = neigh.kneighbors(test[:, feature_sl],
+                                                      return_distance=True)

         assert_array_almost_equal(results['brute'][0], results['ball_tree'][0])
         assert_array_almost_equal(results['brute'][1], results['ball_tree'][1])

0 commit comments