8000 [MRG] Add check for n_components in pca (#10042) · scikit-learn/scikit-learn@c3980bc · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c3980bc

Browse files
CoderPatjnothman
authored andcommitted
[MRG] Add check for n_components in pca (#10042)
1 parent 71402ef commit c3980bc

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

sklearn/decomposition/pca.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# License: BSD 3 clause
1212

1313
from math import log, sqrt
14+
import numbers
1415

1516
import numpy as np
1617
from scipy import linalg
@@ -421,6 +422,12 @@ def _fit_full(self, X, n_components):
421422
"min(n_samples, n_features)=%r with "
422423
"svd_solver='full'"
423424
% (n_components, min(n_samples, n_features)))
425+
elif n_components >= 1:
426+
if not isinstance(n_components, (numbers.Integral, np.integer)):
427+
raise ValueError("n_components=%r must be of type int "
428+
"when greater than or equal to 1, "
429+
"was of type=%r"
430+
% (n_components, type(n_components)))
424431

425432
# Center data
426433
self.mean_ = np.mean(X, axis=0)
@@ -481,6 +488,10 @@ def _fit_truncated(self, X, n_components, svd_solver):
481488
"svd_solver='%s'"
482489
% (n_components, min(n_samples, n_features),
483490
svd_solver))
491+
elif not isinstance(n_components, (numbers.Integral, np.integer)):
492+
raise ValueError("n_components=%r must be of type int "
493+
"when greater than or equal to 1, was of type=%r"
494+
% (n_components, type(n_components)))
484495
elif svd_solver == 'arpack' and n_components == min(n_samples,
485496
n_features):
486497
raise ValueError("n_components=%r must be strictly less than "

sklearn/decomposition/tests/test_pca.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.utils.testing import assert_raise_message
1111
from sklearn.utils.testing import assert_raises
1212
from sklearn.utils.testing import assert_raises_regex
13+
from sklearn.utils.testing import assert_raise_message
1314
from sklearn.utils.testing import assert_no_warnings
1415
from sklearn.utils.testing import assert_warns_message
1516
from sklearn.utils.testing import ignore_warnings
@@ -390,6 +391,15 @@ def test_pca_validation():
390391
PCA(n_components, svd_solver=solver)
391392
.fit, data)
392393

394+
n_components = 1.0
395+
type_ncom = type(n_components)
396+
assert_raise_message(ValueError,
397+
"n_components={} must be of type int "
398+
"when greater than or equal to 1, was of type={}"
399+
.format(n_components, type_ncom),
400+
PCA(n_components, svd_solver=solver).fit, data)
401+
402+
393403

394404
def test_n_components_none():
395405
# Ensures that n_components == None is handled correctly

0 commit comments

Comments
 (0)
0