|
11 | 11 | # License: BSD 3 clause
|
12 | 12 |
|
13 | 13 | from math import log, sqrt
|
| 14 | +import numbers |
14 | 15 |
|
15 | 16 | import numpy as np
|
16 | 17 | from scipy import linalg
|
@@ -421,6 +422,12 @@ def _fit_full(self, X, n_components):
|
421 | 422 | "min(n_samples, n_features)=%r with "
|
422 | 423 | "svd_solver='full'"
|
423 | 424 | % (n_components, min(n_samples, n_features)))
|
| 425 | + elif n_components >= 1: |
| 426 | + if not isinstance(n_components, (numbers.Integral, np.integer)): |
| 427 | + raise ValueError("n_components=%r must be of type int " |
| 428 | + "when greater than or equal to 1, " |
| 429 | + "was of type=%r" |
| 430 | + % (n_components, type(n_components))) |
424 | 431 |
|
425 | 432 | # Center data
|
426 | 433 | self.mean_ = np.mean(X, axis=0)
|
@@ -481,6 +488,10 @@ def _fit_truncated(self, X, n_components, svd_solver):
|
481 | 488 | "svd_solver='%s'"
|
482 | 489 | % (n_components, min(n_samples, n_features),
|
483 | 490 | svd_solver))
|
| 491 | + elif not isinstance(n_components, (numbers.Integral, np.integer)): |
| 492 | + raise ValueError("n_components=%r must be of type int " |
| 493 | + "when greater than or equal to 1, was of type=%r" |
| 494 | + % (n_components, type(n_components))) |
484 | 495 | elif svd_solver == 'arpack' and n_components == min(n_samples,
|
485 | 496 | n_features):
|
486 | 497 | raise ValueError("n_components=%r must be strictly less than "
|
|
0 commit comments