From 0625e0384e320e563e30aab966243f1808759168 Mon Sep 17 00:00:00 2001 From: Thierry Date: Thu, 14 Apr 2016 16:30:47 +0200 Subject: [PATCH 1/2] Integration of the new GaussianMixture class. Depreciation of the GMM class. Modification of the GaussianMixture class. Some functions from the original GSoC code have been removed, renamed or simplified. Some new functions have been introduced (as the 'check_parameters' function). Some parameters names have been changed : - covars_ -> covariances_ : to be coherent with sklearn/covariances Addition of the parameter 'warm_start' allowing to fit data by using the previous computation. The old examples have been modified to replace the deprecated GMM class by the new GaussianMixture class. Every exemple use the eigenvectors norm to solve the scale ellipse problem (Issues 6548). Correction of all commentaries from the PR - Rename MixtureBase -> BaseMixture - Remove n_features_ - Fix some problems - Add some tests Correction of the bic/aic test. Fix the test_check_means and test_check_covariances. Remove all references to the deprecated GMM class. Remove initialized_. Add and correct docstring. Correct the order of random_state. Fix small typo. Some fix in prevision of the integration of the new BayesianGaussianMixture class. Modification in preparation of the integration of the BayesianGaussianMixture class. Add 'best_n_iter' attribute. Fix some bugs and tests. Change the parameter order in the documentation. Change best_n_iter_ name to n_iter_. Fix of the warm_start problem. Fix the divergence error message. Correction of the random state init in the test file. Fix the testing problems. Update and add comments into the monotonic test. --- doc/modules/dp-derivation.rst | 64 +- doc/modules/mixture.rst | 62 +- examples/mixture/plot_gmm.py | 56 +- examples/mixture/plot_gmm_covariances.py | 38 +- examples/mixture/plot_gmm_pdf.py | 14 +- examples/mixture/plot_gmm_selection.py | 24 +- examples/mixture/plot_gmm_sin.py | 77 +- sklearn/base.py | 18 + sklearn/mixture/__init__.py | 6 +- sklearn/mixture/base.py | 442 ++++++++++ sklearn/mixture/dpgmm.py | 6 +- sklearn/mixture/gaussian_mixture.py | 651 ++++++++++++++ sklearn/mixture/gmm.py | 17 +- .../mixture/tests/test_gaussian_mixture.py | 825 ++++++++++++++++++ sklearn/mixture/tests/test_gmm.py | 515 ++++++----- 15 files changed, 2421 insertions(+), 394 deletions(-) create mode 100644 sklearn/mixture/base.py create mode 100644 sklearn/mixture/gaussian_mixture.py create mode 100644 sklearn/mixture/tests/test_gaussian_mixture.py diff --git a/doc/modules/dp-derivation.rst b/doc/modules/dp-derivation.rst index 1585e1488abbd..b02b329472dc1 100644 --- a/doc/modules/dp-derivation.rst +++ b/doc/modules/dp-derivation.rst @@ -13,7 +13,7 @@ as covariance matrices. The inference algorithm is the one from the following paper: * `Variational Inference for Dirichlet Process Mixtures - `_ + `_ David Blei, Michael Jordan. Bayesian Analysis, 2006 While this paper presents the parts of the inference algorithm that @@ -36,7 +36,7 @@ is necessary to invert the covariance/precision matrices and compute its determinant, hence the cubic term). This implementation is expected to scale at least as well as EM for -the mixture of Gaussians. +the Gaussian mixture. Update rules for VB inference ============================== @@ -78,7 +78,7 @@ The variational distribution we'll use is \sigma_k &\sim& Gamma(a_{k}, b_{k}) \\ z_{i} &\sim& Discrete(\nu_{z_i}) \\ \end{array} - + The bound ........... @@ -88,7 +88,7 @@ The variational bound is .. math:: \begin{array}{rcl} - \log P(X) &\ge& + \log P(X) &\ge& \sum_k (E_q[\log P(\phi_k)] - E_q[\log Q(\phi_k)]) \\ && +\sum_k \left( E_q[\log P(\mu_k)] - E_q[\log Q(\mu_k)] \right) \\ @@ -99,16 +99,16 @@ The variational bound is && +\sum_i E_q[\log P(X_t)] \end{array} - - -**The bound for** :math:`\phi_k` + + +**The bound for** :math:`\phi_k` .. math:: \begin{array}{rcl} - E_q[\log Beta(1,\alpha)] - E[\log Beta(\gamma_{k,1},\gamma_{k,2})] + E_q[\log Beta(1,\alpha)] - E[\log Beta(\gamma_{k,1},\gamma_{k,2})] &=& - \log \Gamma(1+\alpha) - \log \Gamma(\alpha) \\ && + \log \Gamma(1+\alpha) - \log \Gamma(\alpha) \\ && +(\alpha-1)(\Psi(\gamma_{k,2})-\Psi(\gamma_{k,1}+\gamma_{k,2})) \\ && - \log \Gamma(\gamma_{k,1}+\gamma_{k,2}) + \log \Gamma(\gamma_{k,1}) + \log \Gamma(\gamma_{k,2}) \\ && @@ -116,11 +116,11 @@ The variational bound is (\gamma_{k,1}-1)(\Psi(\gamma_{k,1})-\Psi(\gamma_{k,1}+\gamma_{k,2})) \\ && - - (\gamma_{k,2}-1)(\Psi(\gamma_{k,2})-\Psi(\gamma_{k,1}+\gamma_{k,2})) + (\gamma_{k,2}-1)(\Psi(\gamma_{k,2})-\Psi(\gamma_{k,1}+\gamma_{k,2})) \end{array} - -**The bound for** :math:`\mu_k` + +**The bound for** :math:`\mu_k` .. math:: @@ -131,11 +131,11 @@ The variational bound is - \int\!d\mu_f q(\mu_f) \log Q(\mu_f) \\ &=& - \frac{D}{2}\log 2\pi - \frac{1}{2} ||\nu_{\mu_k}||^2 - \frac{D}{2} - + \frac{D}{2} \log 2\pi e + + \frac{D}{2} \log 2\pi e \end{array} -**The bound for** :math:`\sigma_k` +**The bound for** :math:`\sigma_k` Here I'll use the inverse scale parametrization of the gamma distribution. @@ -155,21 +155,21 @@ distribution. \begin{array}{rcl} && E_q[\log P(z)] - E_q[\log Q(z)] \\ &=& - \sum_{k} \left( - \left(\sum_{j=k+1}^K \nu_{z_{i,j}}\right)(\Psi(\gamma_{k,2})-\Psi(\gamma_{k,1}+\gamma_{k,2})) + \sum_{k} \left( + \left(\sum_{j=k+1}^K \nu_{z_{i,j}}\right)(\Psi(\gamma_{k,2})-\Psi(\gamma_{k,1}+\gamma_{k,2})) + \nu_{z_{i,k}}(\Psi(\gamma_{k,1})-\Psi(\gamma_{k,1}+\gamma_{k,2})) - \log \nu_{z_{i,k}} \right) \end{array} -**The bound for** :math:`X` +**The bound for** :math:`X` Recall that there is no need for a :math:`Q(X)` so this bound is just .. math:: \begin{array}{rcl} - E_q[\log P(X_i)] &=& \sum_k \nu_{z_k} \left( - \frac{D}{2}\log 2\pi + E_q[\log P(X_i)] &=& \sum_k \nu_{z_k} \left( - \frac{D}{2}\log 2\pi +\frac{D}{2} (\Psi(a_k) - \log(b_k)) -\frac{a_k}{2b_k} (||X_i - \nu_{\mu_k}||^2+D) - \log 2 \pi e \right) \end{array} @@ -186,7 +186,7 @@ The updates \begin{array}{rcl} \gamma_{k,1} &=& 1+\sum_i \nu_{z_{i,k}} \\ - \gamma_{k,2} &=& \alpha + \sum_i \sum_{j > k} \nu_{z_{i,j}}. + \gamma_{k,2} &=& \alpha + \sum_i \sum_{j > k} \nu_{z_{i,j}}. \end{array} @@ -204,7 +204,7 @@ The gradient is so the update is -.. math:: +.. math:: \nu_{\mu_k} = \frac{\sum_i \frac{\nu_{z_{i,k}}b_k}{a_k}X_i}{1+\sum_i \frac{\nu_{z_{i,k}}b_k}{a_k}} @@ -299,9 +299,9 @@ have .. math:: \begin{array}{rcl} - E_q[\log P(X_i)] &=& \sum_k \nu_{z_k} \Big( - \frac{D}{2}\log 2\pi + E_q[\log P(X_i)] &=& \sum_k \nu_{z_k} \Big( - \frac{D}{2}\log 2\pi +\frac{1}{2}\sum_d (\Psi(a_{k,d}) - \log(b_{k,d})) \\ - && + && -\frac{1}{2}((X_i - \nu_{\mu_k})^T\bm{\frac{a_k}{b_k}}(X_i - \nu_{\mu_k})+ \sum_d \sigma_{k,d})- \log 2 \pi e \Big) \end{array} @@ -315,7 +315,7 @@ The updates only chance for :math:`\mu` (to weight them with the new **The update for** :math:`\mu` -.. math:: +.. math:: \nu_{\mu_k} = \left(\mathbf{I}+\sum_i \frac{\nu_{z_{i,k}}\mathbf{b_k}}{\mathbf{a_k}}\right)^{-1}\left(\sum_i \frac{\nu_{z_{i,k}}b_k}{a_k}X_i\right) @@ -328,13 +328,13 @@ of the bound: .. math:: - \log Q(\sigma_{k,d}) = -\sigma_{k,d} + \sum_i \nu_{z_{i,k}}\frac{1}{2}\log \sigma_{k,d} + \log Q(\sigma_{k,d}) = -\sigma_{k,d} + \sum_i \nu_{z_{i,k}}\frac{1}{2}\log \sigma_{k,d} - \frac{\sigma_{k,d}}{2}\sum_i \nu_{z_{i,k}} ((X_{i,d}-\mu_{k,d})^2 + 1) -Hence +Hence -.. math:: +.. math:: a_{k,d} = 1 + \frac{1}{2} \sum_i \nu_{z_{i,k}} @@ -381,7 +381,7 @@ There are two changes in the lower-bound: for :math:`\Sigma` and for :math:`X`. \begin{array}{rcl} \frac{D^2}{2}\log 2 + \sum_d \log \Gamma(\frac{D+1-d}{2}) \\ - \frac{aD}{2}\log 2 + \frac{a}{2} \log |\mathbf{B}| + \sum_d \log \Gamma(\frac{a+1-d}{2}) \\ - + \frac{a-D}{2}\left(\sum_d \Psi\left(\frac{a+1-d}{2}\right) + + \frac{a-D}{2}\left(\sum_d \Psi\left(\frac{a+1-d}{2}\right) + D \log 2 + \log |\mathbf{B}|\right) \\ + \frac{1}{2} a \mathbf{tr}[\mathbf{B}-\mathbf{I}] \end{array} @@ -392,10 +392,10 @@ There are two changes in the lower-bound: for :math:`\Sigma` and for :math:`X`. .. math:: \begin{array}{rcl} - E_q[\log P(X_i)] &=& \sum_k \nu_{z_k} \Big( - \frac{D}{2}\log 2\pi - +\frac{1}{2}\left(\sum_d \Psi\left(\frac{a+1-d}{2}\right) + E_q[\log P(X_i)] &=& \sum_k \nu_{z_k} \Big( - \frac{D}{2}\log 2\pi + +\frac{1}{2}\left(\sum_d \Psi\left(\frac{a+1-d}{2}\right) + D \log 2 + \log |\mathbf{B}|\right) \\ - && + && -\frac{1}{2}((X_i - \nu_{\mu_k})a\mathbf{B}(X_i - \nu_{\mu_k})+ a\mathbf{tr}(\mathbf{B}))- \log 2 \pi e \Big) \end{array} @@ -482,9 +482,9 @@ The updates .............. All that changes in the updates is that the update for mu uses only -the proper sigma and the updates for a and B don't have a sum over K, so +the proper sigma and the updates for a and B don't have a sum over K, so -.. math:: +.. math:: \nu_{\mu_k} = \left(\mathbf{I}+ a_k\mathbf{B_k}\sum_i \nu_{z_{i,k}}\right)^{-1} \left(a_k\mathbf{B_k}\sum_i \nu_{z_{i,k}} X_i\right) diff --git a/doc/modules/mixture.rst b/doc/modules/mixture.rst index c2ee9da243a79..6d6494b9bc4aa 100644 --- a/doc/modules/mixture.rst +++ b/doc/modules/mixture.rst @@ -2,9 +2,9 @@ .. _gmm: -=================================================== +======================= Gaussian mixture models -=================================================== +======================= .. currentmodule:: sklearn.mixture @@ -19,7 +19,7 @@ components are also provided. :align: center :scale: 50% - **Two-component Gaussian mixture model:** *data points, and equi-probability surfaces of + **Two-component Gaussian mixture model:** *data points, and equi-probability surfaces of the model.* A Gaussian mixture model is a probabilistic model that assumes all the @@ -33,25 +33,25 @@ Scikit-learn implements different classes to estimate Gaussian mixture models, that correspond to different estimation strategies, detailed below. -Gaussian Mixture Model -====================== +Gaussian Mixture +================ -The :class:`GMM` object implements the +The :class:`GaussianMixture` object implements the :ref:`expectation-maximization ` (EM) algorithm for fitting mixture-of-Gaussian models. It can also draw confidence ellipsoids for multivariate models, and compute the Bayesian Information Criterion to assess the number of clusters in the -data. A :meth:`GMM.fit` method is provided that learns a Gaussian +data. A :meth:`GaussianMixture.fit` method is provided that learns a Gaussian Mixture Model from train data. Given test data, it can assign to each sample the Gaussian it mostly probably belong to using -the :meth:`GMM.predict` method. +the :meth:`GaussianMixture.predict` method. -.. +.. Alternatively, the probability of each sample belonging to the various Gaussians may be retrieved using the - :meth:`GMM.predict_proba` method. + :meth:`GaussianMixture.predict_proba` method. -The :class:`GMM` comes with different options to constrain the covariance +The :class:`GaussianMixture` comes with different options to constrain the covariance of the difference classes estimated: spherical, diagonal, tied or full covariance. @@ -63,37 +63,37 @@ covariance. .. topic:: Examples: * See :ref:`example_mixture_plot_gmm_covariances.py` for an example of - using a GMM for clustering on the iris dataset. + using the Gaussian mixture as clustering on the iris dataset. - * See :ref:`example_mixture_plot_gmm_pdf.py` for an example on plotting the + * See :ref:`example_mixture_plot_gmm_pdf.py` for an example on plotting the density estimation. -Pros and cons of class :class:`GMM`: expectation-maximization inference ------------------------------------------------------------------------- +Pros and cons of class :class:`GaussianMixture` +----------------------------------------------- Pros ..... -:Speed: it is the fastest algorithm for learning mixture models +:Speed: It is the fastest algorithm for learning mixture models -:Agnostic: as this algorithm maximizes only the likelihood, it +:Agnostic: As this algorithm maximizes only the likelihood, it will not bias the means towards zero, or bias the cluster sizes to have specific structures that might or might not apply. Cons .... -:Singularities: when one has insufficiently many points per +:Singularities: When one has insufficiently many points per mixture, estimating the covariance matrices becomes difficult, and the algorithm is known to diverge and find solutions with infinite likelihood unless one regularizes the covariances artificially. -:Number of components: this algorithm will always use all the +:Number of components: This algorithm will always use all the components it has access to, needing held-out data - or information theoretical criteria to decide how many components to use + or information theoretical criteria to decide how many components to use in the absence of external cues. -Selecting the number of components in a classical GMM +Selecting the number of components in a classical GMM ------------------------------------------------------ The BIC criterion can be used to select the number of components in a GMM @@ -110,7 +110,7 @@ number of components for a Gaussian mixture model. .. topic:: Examples: * See :ref:`example_mixture_plot_gmm_selection.py` for an example - of model selection performed with classical GMM. + of model selection performed with classical Gaussian mixture. .. _expectation_maximization: @@ -131,7 +131,7 @@ origin) and computes for each point a probability of being generated by each component of the model. Then, one tweaks the parameters to maximize the likelihood of the data given those assignments. Repeating this process is guaranteed to always converge -to a local optimum. +to a local optimum. .. _vbgmm: @@ -139,10 +139,7 @@ VBGMM: variational Gaussian mixtures ==================================== The :class:`VBGMM` object implements a variant of the Gaussian mixture -model with :ref:`variational inference ` algorithms. The API is identical to -:class:`GMM`. It is essentially a middle-ground between :class:`GMM` -and :class:`DPGMM`, as it has some of the properties of the Dirichlet -process. +model with :ref:`variational inference ` algorithms. Pros and cons of class :class:`VBGMM`: variational inference ------------------------------------------------------------ @@ -205,7 +202,7 @@ DPGMM: Infinite Gaussian mixtures The :class:`DPGMM` object implements a variant of the Gaussian mixture model with a variable (but bounded) number of components using the -Dirichlet Process. The API is identical to :class:`GMM`. +Dirichlet Process. This class doesn't require the user to choose the number of components, and at the expense of extra computational time the user only needs to specify a loose upper bound on this number and a @@ -228,17 +225,18 @@ components on a dataset composed of 2 clusters. We can see that the DPGMM is able to limit itself to only 2 components whereas the GMM fits the data fit too many components. Note that with very little observations, the DPGMM can take a conservative stand, and fit only one component. **On the right** we are fitting -a dataset not well-depicted by a mixture of Gaussian. Adjusting the `alpha` +a dataset not well-depicted by a Gaussian mixture. Adjusting the `alpha` parameter of the DPGMM controls the number of components used to fit this data. .. topic:: Examples: * See :ref:`example_mixture_plot_gmm.py` for an example on plotting the - confidence ellipsoids for both :class:`GMM` and :class:`DPGMM`. + confidence ellipsoids for both :class:`GaussianMixture` + and :class:`DPGMM`. - * :ref:`example_mixture_plot_gmm_sin.py` shows using :class:`GMM` and - :class:`DPGMM` to fit a sine wave + * :ref:`example_mixture_plot_gmm_sin.py` shows using + :class:`GaussianMixture` and :class:`DPGMM` to fit a sine wave Pros and cons of class :class:`DPGMM`: Dirichlet process mixture model ---------------------------------------------------------------------- diff --git a/examples/mixture/plot_gmm.py b/examples/mixture/plot_gmm.py index c916c5ff06faf..d6b0839c193a7 100644 --- a/examples/mixture/plot_gmm.py +++ b/examples/mixture/plot_gmm.py @@ -20,6 +20,7 @@ per cluster than there are dimensions in the data, due to regularization properties of the inference algorithm. """ + import itertools import numpy as np @@ -29,33 +30,16 @@ from sklearn import mixture -# Number of samples per component -n_samples = 500 - -# Generate random sample, two components -np.random.seed(0) -C = np.array([[0., -0.1], [1.7, .4]]) -X = np.r_[np.dot(np.random.randn(n_samples, 2), C), - .7 * np.random.randn(n_samples, 2) + np.array([-6, 3])] - -# Fit a mixture of Gaussians with EM using five components -gmm = mixture.GMM(n_components=5, covariance_type='full') -gmm.fit(X) - -# Fit a Dirichlet process mixture of Gaussians using five components -dpgmm = mixture.DPGMM(n_components=5, covariance_type='full') -dpgmm.fit(X) - color_iter = itertools.cycle(['navy', 'c', 'cornflowerblue', 'gold', 'darkorange']) -for i, (clf, title) in enumerate([(gmm, 'GMM'), - (dpgmm, 'Dirichlet Process GMM')]): - splot = plt.subplot(2, 1, 1 + i) - Y_ = clf.predict(X) + +def plot_results(X, Y_, means, covariances, index, title): + splot = plt.subplot(2, 1, 1 + index) for i, (mean, covar, color) in enumerate(zip( - clf.means_, clf._get_covars(), color_iter)): + means, covariances, color_iter)): v, w = linalg.eigh(covar) + v = 2. * np.sqrt(2.) * np.sqrt(v) u = w[0] / linalg.norm(w[0]) # as the DP will not use every component it has access to # unless it needs it, we shouldn't plot the redundant @@ -66,16 +50,36 @@ # Plot an ellipse to show the Gaussian component angle = np.arctan(u[1] / u[0]) - angle = 180 * angle / np.pi # convert to degrees - ell = mpl.patches.Ellipse(mean, v[0], v[1], 180 + angle, color=color) + angle = 180. * angle / np.pi # convert to degrees + ell = mpl.patches.Ellipse(mean, v[0], v[1], 180. + angle, color=color) ell.set_clip_box(splot.bbox) ell.set_alpha(0.5) splot.add_artist(ell) - plt.xlim(-10, 10) - plt.ylim(-3, 6) + plt.xlim(-10., 10.) + plt.ylim(-3., 6.) plt.xticks(()) plt.yticks(()) plt.title(title) + +# Number of samples per component +n_samples = 500 + +# Generate random sample, two components +np.random.seed(0) +C = np.array([[0., -0.1], [1.7, .4]]) +X = np.r_[np.dot(np.random.randn(n_samples, 2), C), + .7 * np.random.randn(n_samples, 2) + np.array([-6, 3])] + +# Fit a Gaussian mixture with EM using five components +gmm = mixture.GaussianMixture(n_components=5, covariance_type='full').fit(X) +plot_results(X, gmm.predict(X), gmm.means_, gmm.covariances_, 0, + 'Gaussian Mixture') + +# Fit a Dirichlet process Gaussian mixture using five components +dpgmm = mixture.DPGMM(n_components=5, covariance_type='full').fit(X) +plot_results(X, dpgmm.predict(X), dpgmm.means_, dpgmm._get_covars(), 1, + 'Dirichlet Process GMM') + plt.show() diff --git a/examples/mixture/plot_gmm_covariances.py b/examples/mixture/plot_gmm_covariances.py index 0c6058a44a6cb..e3c8d8b68b43a 100644 --- a/examples/mixture/plot_gmm_covariances.py +++ b/examples/mixture/plot_gmm_covariances.py @@ -25,33 +25,40 @@ dimensions are shown here, and thus some points are separated in other dimensions. """ -print(__doc__) # Author: Ron Weiss , Gael Varoquaux +# Modified by Thierry Guillemot # License: BSD 3 clause -# $Id$ - -import matplotlib.pyplot as plt import matplotlib as mpl +import matplotlib.pyplot as plt + import numpy as np from sklearn import datasets +from sklearn.mixture import GaussianMixture from sklearn.model_selection import StratifiedKFold -from sklearn.externals.six.moves import xrange -from sklearn.mixture import GMM +print(__doc__) colors = ['navy', 'turquoise', 'darkorange'] def make_ellipses(gmm, ax): for n, color in enumerate(colors): - v, w = np.linalg.eigh(gmm._get_covars()[n][:2, :2]) + if gmm.covariance_type == 'full': + covars = gmm.covariances_[n][:2, :2] + elif gmm.covariance_type == 'tied': + covars = gmm.covariances_[:2, :2] + elif gmm.covariance_type == 'diag': + covars = np.diag(gmm.covariances_[n][:2]) + elif gmm.covariance_type == 'spherical': + covars = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n] + v, w = np.linalg.eigh(covars) u = w[0] / np.linalg.norm(w[0]) angle = np.arctan2(u[1], u[0]) angle = 180 * angle / np.pi # convert to degrees - v *= 9 + v = 2. * np.sqrt(2.) * np.sqrt(v) ell = mpl.patches.Ellipse(gmm.means_[n, :2], v[0], v[1], 180 + angle, color=color) ell.set_clip_box(ax.bbox) @@ -75,14 +82,13 @@ def make_ellipses(gmm, ax): n_classes = len(np.unique(y_train)) # Try GMMs using different types of covariances. -estimators = dict((covar_type, - GMM(n_components=n_classes, covariance_type=covar_type, - init_params='wc', n_iter=20)) +estimators = dict((covar_type, GaussianMixture(n_components=n_classes, + covariance_type=covar_type, max_iter=20)) for covar_type in ['spherical', 'diag', 'tied', 'full']) n_estimators = len(estimators) -plt.figure(figsize=(3 * n_estimators / 2, 6)) +plt.figure(figsize=(3 * n_estimators // 2, 6)) plt.subplots_adjust(bottom=.01, top=0.95, hspace=.15, wspace=.05, left=.01, right=.99) @@ -90,13 +96,13 @@ def make_ellipses(gmm, ax): for index, (name, estimator) in enumerate(estimators.items()): # Since we have class labels for the training data, we can # initialize the GMM parameters in a supervised manner. - estimator.means_ = np.array([X_train[y_train == i].mean(axis=0) - for i in xrange(n_classes)]) + estimator.means_init = np.array([X_train[y_train == i].mean(axis=0) + for i in range(n_classes)]) # Train the other parameters using the EM algorithm. estimator.fit(X_train) - h = plt.subplot(2, n_estimators / 2, index + 1) + h = plt.subplot(2, n_estimators // 2, index + 1) make_ellipses(estimator, h) for n, color in enumerate(colors): @@ -122,7 +128,7 @@ def make_ellipses(gmm, ax): plt.yticks(()) plt.title(name) -plt.legend(loc='lower right', prop=dict(size=12)) +plt.legend(scatterpoints=1, loc='lower right', prop=dict(size=12)) plt.show() diff --git a/examples/mixture/plot_gmm_pdf.py b/examples/mixture/plot_gmm_pdf.py index a1b7ea5de202f..4469c36a89625 100644 --- a/examples/mixture/plot_gmm_pdf.py +++ b/examples/mixture/plot_gmm_pdf.py @@ -1,7 +1,7 @@ """ -============================================= -Density Estimation for a mixture of Gaussians -============================================= +========================================= +Density Estimation for a Gaussian mixture +========================================= Plot the density estimation of a mixture of two Gaussians. Data is generated from two Gaussians with different centers and covariance @@ -29,15 +29,15 @@ X_train = np.vstack([shifted_gaussian, stretched_gaussian]) # fit a Gaussian Mixture Model with two components -clf = mixture.GMM(n_components=2, covariance_type='full') +clf = mixture.GaussianMixture(n_components=2, covariance_type='full') clf.fit(X_train) # display predicted scores by the model as a contour plot -x = np.linspace(-20.0, 30.0) -y = np.linspace(-20.0, 40.0) +x = np.linspace(-20., 30.) +y = np.linspace(-20., 40.) X, Y = np.meshgrid(x, y) XX = np.array([X.ravel(), Y.ravel()]).T -Z = -clf.score_samples(XX)[0] +Z = -clf.score_samples(XX) Z = Z.reshape(X.shape) CS = plt.contour(X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), diff --git a/examples/mixture/plot_gmm_selection.py b/examples/mixture/plot_gmm_selection.py index cf7f8426a6b9f..747dc0d8a90c7 100644 --- a/examples/mixture/plot_gmm_selection.py +++ b/examples/mixture/plot_gmm_selection.py @@ -1,7 +1,7 @@ """ -================================= +================================ Gaussian Mixture Model Selection -================================= +================================ This example shows that model selection can be performed with Gaussian Mixture Models using information-theoretic criteria (BIC). @@ -14,17 +14,18 @@ In that case, the model with 2 components and full covariance (which corresponds to the true generative model) is selected. """ -print(__doc__) +import numpy as np import itertools -import numpy as np from scipy import linalg import matplotlib.pyplot as plt import matplotlib as mpl from sklearn import mixture +print(__doc__) + # Number of samples per component n_samples = 500 @@ -40,8 +41,9 @@ cv_types = ['spherical', 'tied', 'diag', 'full'] for cv_type in cv_types: for n_components in n_components_range: - # Fit a mixture of Gaussians with EM - gmm = mixture.GMM(n_components=n_components, covariance_type=cv_type) + # Fit a Gaussian mixture with EM + gmm = mixture.GaussianMixture(n_components=n_components, + covariance_type=cv_type) gmm.fit(X) bic.append(gmm.bic(X)) if bic[-1] < lowest_bic: @@ -73,7 +75,7 @@ # Plot the winner splot = plt.subplot(2, 1, 2) Y_ = clf.predict(X) -for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covars_, +for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covariances_, color_iter)): v, w = linalg.eigh(covar) if not np.any(Y_ == i): @@ -82,15 +84,13 @@ # Plot an ellipse to show the Gaussian component angle = np.arctan2(w[0][1], w[0][0]) - angle = 180 * angle / np.pi # convert to degrees - v *= 4 - ell = mpl.patches.Ellipse(mean, v[0], v[1], 180 + angle, color=color) + angle = 180. * angle / np.pi # convert to degrees + v = 2. * np.sqrt(2.) * np.sqrt(v) + ell = mpl.patches.Ellipse(mean, v[0], v[1], 180. + angle, color=color) ell.set_clip_box(splot.bbox) ell.set_alpha(.5) splot.add_artist(ell) -plt.xlim(-10, 10) -plt.ylim(-3, 6) plt.xticks(()) plt.yticks(()) plt.title('Selected GMM: full model, 2 components') diff --git a/examples/mixture/plot_gmm_sin.py b/examples/mixture/plot_gmm_sin.py index 6538ee80e4626..cc9217740a195 100644 --- a/examples/mixture/plot_gmm_sin.py +++ b/examples/mixture/plot_gmm_sin.py @@ -22,60 +22,69 @@ import matplotlib as mpl from sklearn import mixture -from sklearn.externals.six.moves import xrange -# Number of samples per component -n_samples = 100 - -# Generate random sample following a sine curve -np.random.seed(0) -X = np.zeros((n_samples, 2)) -step = 4 * np.pi / n_samples - -for i in xrange(X.shape[0]): - x = i * step - 6 - X[i, 0] = x + np.random.normal(0, 0.1) - X[i, 1] = 3 * (np.sin(x) + np.random.normal(0, .2)) - -color_iter = itertools.cycle(['navy', 'turquoise', 'cornflowerblue', +color_iter = itertools.cycle(['navy', 'c', 'cornflowerblue', 'gold', 'darkorange']) -for i, (clf, title) in enumerate([ - (mixture.GMM(n_components=10, covariance_type='full', n_iter=100), - "Expectation-maximization"), - (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01, - n_iter=100), - "Dirichlet Process,alpha=0.01"), - (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=100., - n_iter=100), - "Dirichlet Process,alpha=100.")]): - - clf.fit(X) - splot = plt.subplot(3, 1, 1 + i) - Y_ = clf.predict(X) + +def plot_results(X, Y_, means, covariances, index, title): + splot = plt.subplot(3, 1, 1 + index) for i, (mean, covar, color) in enumerate(zip( - clf.means_, clf._get_covars(), color_iter)): + means, covariances, color_iter)): v, w = linalg.eigh(covar) + v = 2. * np.sqrt(2.) * np.sqrt(v) u = w[0] / linalg.norm(w[0]) # as the DP will not use every component it has access to # unless it needs it, we shouldn't plot the redundant # components. if not np.any(Y_ == i): continue - plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], color=color, s=4) + plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], .8, color=color) # Plot an ellipse to show the Gaussian component angle = np.arctan(u[1] / u[0]) - angle = 180 * angle / np.pi # convert to degrees - ell = mpl.patches.Ellipse(mean, v[0], v[1], 180 + angle, color=color) + angle = 180. * angle / np.pi # convert to degrees + ell = mpl.patches.Ellipse(mean, v[0], v[1], 180. + angle, color=color) ell.set_clip_box(splot.bbox) ell.set_alpha(0.5) splot.add_artist(ell) - plt.xlim(-6, 4 * np.pi - 6) - plt.ylim(-5, 5) + plt.xlim(-6., 4. * np.pi - 6.) + plt.ylim(-5., 5.) plt.title(title) plt.xticks(()) plt.yticks(()) + +# Number of samples per component +n_samples = 100 + +# Generate random sample following a sine curve +np.random.seed(0) +X = np.zeros((n_samples, 2)) +step = 4. * np.pi / n_samples + +for i in range(X.shape[0]): + x = i * step - 6. + X[i, 0] = x + np.random.normal(0, 0.1) + X[i, 1] = 3. * (np.sin(x) + np.random.normal(0, .2)) + +# Fit a Gaussian mixture with EM using ten components +gmm = mixture.GaussianMixture(n_components=10, covariance_type='full', + max_iter=100).fit(X) +plot_results(X, gmm.predict(X), gmm.means_, gmm.covariances_, 0, + 'Expectation-maximization') + +# Fit a Dirichlet process Gaussian mixture using ten components +dpgmm = mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01, + n_iter=100).fit(X) +plot_results(X, dpgmm.predict(X), dpgmm.means_, dpgmm._get_covars(), 1, + 'Dirichlet Process,alpha=0.01') + + +# Fit a Dirichlet process Gaussian mixture using ten components +dpgmm = mixture.DPGMM(n_components=10, covariance_type='diag', alpha=100., + n_iter=100).fit(X) +plot_results(X, dpgmm.predict(X), dpgmm.means_, dpgmm._get_covars(), 2, + 'Dirichlet Process,alpha=100.') plt.show() diff --git a/sklearn/base.py b/sklearn/base.py index 45eacf5af2a90..8c3a9a8eba4da 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -463,6 +463,24 @@ def fit_transform(self, X, y=None, **fit_params): return self.fit(X, y, **fit_params).transform(X) +class DensityMixin(object): + """Mixin class for all density estimators in scikit-learn.""" + _estimator_type = "DensityEstimator" + + def score(self, X, y=None): + """Returns the score of the model on the data X + + Parameters + ---------- + X : array-like, shape = (n_samples, n_features) + + Returns + ------- + score: float + """ + pass + + ############################################################################### class MetaEstimatorMixin(object): """Mixin class for all meta estimators in scikit-learn.""" diff --git a/sklearn/mixture/__init__.py b/sklearn/mixture/__init__.py index b3ac173bab9b9..8269ec4a31d91 100644 --- a/sklearn/mixture/__init__.py +++ b/sklearn/mixture/__init__.py @@ -7,10 +7,14 @@ from .gmm import _validate_covars from .dpgmm import DPGMM, VBGMM +from .gaussian_mixture import GaussianMixture + + __all__ = ['DPGMM', 'GMM', 'VBGMM', '_validate_covars', 'distribute_covar_matrix_to_match_covariance_type', 'log_multivariate_normal_density', - 'sample_gaussian'] + 'sample_gaussian', + 'GaussianMixture'] diff --git a/sklearn/mixture/base.py b/sklearn/mixture/base.py new file mode 100644 index 0000000000000..3b4a043534019 --- /dev/null +++ b/sklearn/mixture/base.py @@ -0,0 +1,442 @@ +"""Base class for mixture models.""" + +# Author: Wei Xue +# Modified by Thierry Guillemot + +from __future__ import print_function + +import warnings +from abc import ABCMeta, abstractmethod +from time import time + +import numpy as np + +from .. import cluster +from ..base import BaseEstimator +from ..base import DensityMixin +from ..externals import six +from ..exceptions import ConvergenceWarning +from ..utils import check_array, check_random_state +from ..utils.extmath import logsumexp + + +def _check_shape(param, param_shape, name): + """Validate the shape of the input parameter 'param'. + + Parameters + ---------- + param : array + + param_shape : tuple + + name : string + """ + param = np.array(param) + if param.shape != param_shape: + raise ValueError("The parameter '%s' should have the shape of %s, " + "but got %s" % (name, param_shape, param.shape)) + + +def _check_X(X, n_components=None, n_features=None): + """Check the input data X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + n_components : int + + Returns + ------- + X : array, shape (n_samples, n_features) + """ + X = check_array(X, dtype=[np.float64, np.float32]) + if n_components is not None and X.shape[0] < n_components: + raise ValueError('Expected n_samples >= n_components ' + 'but got n_components = %d, n_samples = %d' + % (n_components, X.shape[0])) + if n_features is not None and X.shape[1] != n_features: + raise ValueError("Expected the input data X have %d features, " + "but got %d features" + % (n_features, X.shape[1])) + return X + + +class BaseMixture(six.with_metaclass(ABCMeta, DensityMixin, BaseEstimator)): + """Base class for mixture models. + + This abstract class specifies an interface for all mixture classes and + provides basic common methods for mixture models. + """ + + def __init__(self, n_components, tol, reg_covar, + max_iter, n_init, init_params, random_state, warm_start, + verbose, verbose_interval): + self.n_components = n_components + self.tol = tol + self.reg_covar = reg_covar + self.max_iter = max_iter + self.n_init = n_init + self.init_params = init_params + self.random_state = random_state + self.warm_start = warm_start + self.verbose = verbose + self.verbose_interval = verbose_interval + + def _check_initial_parameters(self, X): + """Check values of the basic parameters. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + """ + if self.n_components < 1: + raise ValueError("Invalid value for 'n_components': %d " + "Estimation requires at least one component" + % self.n_components) + + if self.tol < 0.: + raise ValueError("Invalid value for 'tol': %.5f " + "Tolerance used by the EM must be non-negative" + % self.tol) + + if self.n_init < 1: + raise ValueError("Invalid value for 'n_init': %d " + "Estimation requires at least one run" + % self.n_init) + + if self.max_iter < 1: + raise ValueError("Invalid value for 'max_iter': %d " + "Estimation requires at least one iteration" + % self.max_iter) + + if self.reg_covar < 0.: + raise ValueError("Invalid value for 'reg_covar': %.5f " + "regularization on covariance must be " + "non-negative" + % self.reg_covar) + + # Check all the parameters values of the derived class + self._check_parameters(X) + + @abstractmethod + def _check_parameters(self, X): + """Check initial parameters of the derived class. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + """ + pass + + def _initialize_parameters(self, X): + """Initialize the model parameters. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + """ + n_samples = X.shape[0] + random_state = check_random_state(self.random_state) + + if self.init_params == 'kmeans': + resp = np.zeros((n_samples, self.n_components)) + label = cluster.KMeans(n_clusters=self.n_components, n_init=1, + random_state=random_state).fit(X).labels_ + resp[np.arange(n_samples), label] = 1 + elif self.init_params == 'random': + resp = random_state.rand(X.shape[0], self.n_components) + resp /= resp.sum(axis=1)[:, np.newaxis] + else: + raise ValueError("Unimplemented initialization method '%s'" + % self.init_params) + + self._initialize(X, resp) + + @abstractmethod + def _initialize(self, X, resp): + """Initialize the model parameters of the derived class. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + resp : array-like, shape (n_samples, n_components) + """ + pass + + def fit(self, X, y=None): + """Estimate model parameters with the EM algorithm. + + The method fit the model `n_init` times and set the parameters with + which the model has the largest likelihood or lower bound. Within each + trial, the method iterates between E-step and M-step for `max_iter` + times until the change of likelihood or lower bound is less than + `tol`, otherwise, a `ConvergenceWarning` is raised. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + self + """ + X = _check_X(X, self.n_components) + self._check_initial_parameters(X) + + # if we enable warm_start, we will have a unique initialisation + do_init = not(self.warm_start and hasattr(self, 'converged_')) + n_init = self.n_init if do_init else 1 + + max_log_likelihood = -np.infty + self.converged_ = False + + for init in range(n_init): + self._print_verbose_msg_init_beg(init) + + if do_init: + self._initialize_parameters(X) + current_log_likelihood, resp = self._e_step(X) + + for n_iter in range(self.max_iter): + prev_log_likelihood = current_log_likelihood + + self._m_step(X, resp) + current_log_likelihood, resp = self._e_step(X) + change = current_log_likelihood - prev_log_likelihood + self._print_verbose_msg_iter_end(n_iter, change) + + if abs(change) < self.tol: + self.converged_ = True + break + + self._print_verbose_msg_init_end(current_log_likelihood) + + if current_log_likelihood > max_log_likelihood: + max_log_likelihood = current_log_likelihood + best_params = self._get_parameters() + best_n_iter = n_iter + + if not self.converged_: + warnings.warn('Initialization %d did not converged. ' + 'Try different init parameters, ' + 'or increase n_init, tol ' + 'or check for degenerate data.' + % (init + 1), ConvergenceWarning) + + self._set_parameters(best_params) + self.n_iter_ = best_n_iter + + return self + + @abstractmethod + def _e_step(self, X): + """E step. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + log-likelihood : scalar + + responsibility : array, shape (n_samples, n_components) + """ + pass + + @abstractmethod + def _m_step(self, X, resp): + """M step. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + resp : array-like, shape (n_samples, n_components) + """ + pass + + @abstractmethod + def _check_is_fitted(self): + pass + + @abstractmethod + def _get_parameters(self): + pass + + @abstractmethod + def _set_parameters(self, params): + pass + + def score_samples(self, X): + """Compute the weighted log probabilities for each sample. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + log_prob : array, shape (n_samples,) + Log probabilities of each data point in X. + """ + self._check_is_fitted() + X = _check_X(X, None, self.means_.shape[1]) + + return logsumexp(self._estimate_weighted_log_prob(X), axis=1) + + def score(self, X, y=None): + """Compute the per-sample average log-likelihood of the given data X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_dimensions) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + log_likelihood : float + Log likelihood of the Gaussian mixture given X. + """ + return self.score_samples(X).mean() + + def predict(self, X, y=None): + """Predict the labels for the data samples in X using trained model. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + labels : array, shape (n_samples,) + Component labels. + """ + self._check_is_fitted() + X = _check_X(X, None, self.means_.shape[1]) + return self._estimate_weighted_log_prob(X).argmax(axis=1) + + def predict_proba(self, X): + """Predict posterior probability of data per each component. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + resp : array, shape (n_samples, n_components) + Returns the probability of the sample for each Gaussian + (state) in the model. + """ + self._check_is_fitted() + X = _check_X(X, None, self.means_.shape[1]) + _, _, log_resp = self._estimate_log_prob_resp(X) + return np.exp(log_resp) + + def _estimate_weighted_log_prob(self, X): + """Estimate the weighted log-probabilities, log P(X | Z) + log weights. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + weighted_log_prob : array, shape (n_features, n_component) + """ + return self._estimate_log_prob(X) + self._estimate_log_weights() + + @abstractmethod + def _estimate_log_weights(self): + """Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm. + + Returns + ------- + log_weight : array, shape (n_components, ) + """ + pass + + @abstractmethod + def _estimate_log_prob(self, X): + """Estimate the log-probabilities log P(X | Z). + + Compute the log-probabilities per each component for each sample. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + log_prob : array, shape (n_samples, n_component) + """ + pass + + def _estimate_log_prob_resp(self, X): + """Estimate log probabilities and responsibilities for each sample. + + Compute the log probabilities, weighted log probabilities per + component and responsibilities for each sample in X with respect to + the current state of the model. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + log_prob_norm : array, shape (n_samples,) + log p(X) + + log_prob : array, shape (n_samples, n_components) + log p(X|Z) + log weights + + log_responsibilities : array, shape (n_samples, n_components) + logarithm of the responsibilities + """ + weighted_log_prob = self._estimate_weighted_log_prob(X) + log_prob_norm = logsumexp(weighted_log_prob, axis=1) + with np.errstate(under='ignore'): + # ignore underflow + log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] + return log_prob_norm, weighted_log_prob, log_resp + + def _print_verbose_msg_init_beg(self, n_init): + """Print verbose message on initialization.""" + if self.verbose == 1: + print("Initialization %d" % n_init) + elif self.verbose >= 2: + print("Initialization %d" % n_init) + self._init_prev_time = time() + self._iter_prev_time = self._init_prev_time + + def _print_verbose_msg_iter_end(self, n_iter, diff_ll): + """Print verbose message on initialization.""" + if n_iter % self.verbose_interval == 0: + if self.verbose == 1: + print(" Iteration %d" % n_iter) + elif self.verbose >= 2: + cur_time = time() + print(" Iteration %d\t time lapse %.5fs\t ll change %.5f" % ( + n_iter, cur_time - self._iter_prev_time, diff_ll)) + self._iter_prev_time = cur_time + + def _print_verbose_msg_init_end(self, ll): + """Print verbose message on the end of iteration.""" + if self.verbose == 1: + print("Initialization converged: %s" % self.converged_) + elif self.verbose >= 2: + print("Initialization converged: %s\t time lapse %.5fs\t ll %.5f" % + (self.converged_, time() - self._init_prev_time, ll)) diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py index 815e900ca26e3..9372d51573c26 100644 --- a/sklearn/mixture/dpgmm.py +++ b/sklearn/mixture/dpgmm.py @@ -20,7 +20,7 @@ from ..utils.extmath import logsumexp, pinvh, squared_norm from ..utils.validation import check_is_fitted from .. import cluster -from .gmm import GMM +from .gmm import _GMMBase def digamma(x): @@ -106,7 +106,7 @@ def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type): return bound -class DPGMM(GMM): +class DPGMM(_GMMBase): """Variational Inference for the Infinite Gaussian Mixture Model. DPGMM stands for Dirichlet Process Gaussian Mixture Model, and it @@ -492,7 +492,7 @@ def _fit(self, X, y=None): A initialization step is performed before entering the em algorithm. If you want to avoid this step, set the keyword - argument init_params to the empty string '' when creating + argument init_params to the empty string '' when when creating the object. Likewise, if you would like just to do an initialization, set n_iter=0. diff --git a/sklearn/mixture/gaussian_mixture.py b/sklearn/mixture/gaussian_mixture.py new file mode 100644 index 0000000000000..90c2ed7ea0ba0 --- /dev/null +++ b/sklearn/mixture/gaussian_mixture.py @@ -0,0 +1,651 @@ +"""Gaussian Mixture Model.""" + +# Author: Wei Xue +# Modified by Thierry Guillemot + +import numpy as np + +from scipy import linalg + +from .base import BaseMixture, _check_shape +from ..externals.six.moves import zip +from ..utils import check_array +from ..utils.validation import check_is_fitted + + +############################################################################### +# Gaussian mixture shape checkers used by the GaussianMixture class + +def _check_weights(weights, n_components): + """Check the user provided 'weights'. + + Parameters + ---------- + weights : array-like, shape (n_components,) + The proportions of components of each mixture. + + n_components : int + Number of components. + + Returns + ------- + weights : array, shape (n_components,) + """ + weights = check_array(weights, dtype=[np.float64, np.float32], + ensure_2d=False) + _check_shape(weights, (n_components,), 'weights') + + # check range + if (any(np.less(weights, 0)) or + any(np.greater(weights, 1))): + raise ValueError("The parameter 'weights' should be in the range " + "[0, 1], but got max value %.5f, min value %.5f" + % (np.min(weights), np.max(weights))) + + # check normalization + if not np.allclose(np.abs(1 - np.sum(weights)), 0.0): + raise ValueError("The parameter 'weights' should be normalized, " + "but got sum(weights) = %.5f" % np.sum(weights)) + return weights + + +def _check_means(means, n_components, n_features): + """Validate the provided 'means'. + + Parameters + ---------- + means : array-like, shape (n_components, n_features) + The centers of the current components. + + n_components : int + Number of components. + + n_features : int + Number of features. + + Returns + ------- + means : array, (n_components, n_features) + """ + means = check_array(means, dtype=[np.float64, np.float32], ensure_2d=False) + _check_shape(means, (n_components, n_features), 'means') + return means + + +def _check_covariance_matrix(covariance, covariance_type): + """Check a covariance matrix is symmetric and positive-definite.""" + if (not np.allclose(covariance, covariance.T) or + np.any(np.less_equal(linalg.eigvalsh(covariance), .0))): + raise ValueError("'%s covariance' should be symmetric, " + "positive-definite" % covariance_type) + + +def _check_covariance_positivity(covariance, covariance_type): + """Check a covariance vector is positive-definite.""" + if np.any(np.less_equal(covariance, 0.0)): + raise ValueError("'%s covariance' should be " + "positive" % covariance_type) + + +def _check_covariances_full(covariances, covariance_type): + """Check the covariance matrices are symmetric and positive-definite.""" + for k, cov in enumerate(covariances): + _check_covariance_matrix(cov, covariance_type) + + +def _check_covariances(covariances, covariance_type, n_components, n_features): + """Validate user provided covariances. + + Parameters + ---------- + covariances : array-like, + 'full' : shape of (n_components, n_features, n_features) + 'tied' : shape of (n_features, n_features) + 'diag' : shape of (n_components, n_features) + 'spherical' : shape of (n_components,) + + covariance_type : string + + n_components : int + Number of components. + + n_features : int + Number of features. + + Returns + ------- + covariances : array + """ + covariances = check_array(covariances, dtype=[np.float64, np.float32], + ensure_2d=False, + allow_nd=covariance_type is 'full') + + covariances_shape = {'full': (n_components, n_features, n_features), + 'tied': (n_features, n_features), + 'diag': (n_components, n_features), + 'spherical': (n_components,)} + _check_shape(covariances, covariances_shape[covariance_type], + '%s covariance' % covariance_type) + + check_functions = {'full': _check_covariances_full, + 'tied': _check_covariance_matrix, + 'diag': _check_covariance_positivity, + 'spherical': _check_covariance_positivity} + check_functions[covariance_type](covariances, covariance_type) + + return covariances + + +############################################################################### +# Gaussian mixture parameters estimators (used by the M-Step) + +def _estimate_gaussian_covariance_full(resp, X, nk, means, reg_covar): + """Estimate the full covariance matrices. + + Parameters + ---------- + resp : array-like, shape (n_samples, n_components) + + X : array-like, shape (n_samples, n_features) + + nk : array-like, shape (n_components,) + + means : array-like, shape (n_components, n_features) + + reg_covar : float + + Returns + ------- + covariances : array, shape (n_components, n_features, n_features) + """ + n_features = X.shape[1] + n_components = means.shape[0] + covariances = np.empty((n_components, n_features, n_features)) + for k in range(n_components): + diff = X - means[k] + covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k] + covariances[k].flat[::n_features + 1] += reg_covar + return covariances + + +def _estimate_gaussian_covariance_tied(resp, X, nk, means, reg_covar): + """Estimate the tied covariance matrix. + + Parameters + ---------- + resp : array-like, shape (n_samples, n_components) + + X : array-like, shape (n_samples, n_features) + + nk : array-like, shape (n_components,) + + means : array-like, shape (n_components, n_features) + + reg_covar : float + + Returns + ------- + covariances : array, shape (n_features, n_features) + """ + avg_X2 = np.dot(X.T, X) + avg_means2 = np.dot(nk * means.T, means) + covariances = avg_X2 - avg_means2 + covariances /= X.shape[0] + covariances.flat[::len(covariances) + 1] += reg_covar + return covariances + + +def _estimate_gaussian_covariance_diag(resp, X, nk, means, reg_covar): + """Estimate the diagonal covariance matrices. + + Parameters + ---------- + responsibilities : array-like, shape (n_samples, n_components) + + X : array-like, shape (n_samples, n_features) + + nk : array-like, shape (n_components,) + + means : array-like, shape (n_components, n_features) + + reg_covar : float + + Returns + ------- + covariances : array, shape (n_components, n_features) + """ + avg_X2 = np.dot(resp.T, X * X) / nk[:, np.newaxis] + avg_means2 = means ** 2 + avg_X_means = means * np.dot(resp.T, X) / nk[:, np.newaxis] + return avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar + + +def _estimate_gaussian_covariance_spherical(resp, X, nk, means, reg_covar): + """Estimate the spherical covariance matrices. + + Parameters + ---------- + responsibilities : array-like, shape (n_samples, n_components) + + X : array-like, shape (n_samples, n_features) + + nk : array-like, shape (n_components,) + + means : array-like, shape (n_components, n_features) + + reg_covar : float + + Returns + ------- + covariances : array, shape (n_components,) + """ + covariances = _estimate_gaussian_covariance_diag(resp, X, nk, means, + reg_covar) + return covariances.mean(axis=1) + + +def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): + """Estimate the Gaussian distribution parameters. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The input data array. + + resp : array-like, shape (n_samples, n_features) + The responsibilities for each data sample in X. + + reg_covar : float + The regularization added to each covariance matrices. + + covariance_type : {'full', 'tied', 'diag', 'spherical'} + The type of covariance matrices. + + Returns + ------- + nk : array, shape (n_components,) + The numbers of data samples in the current components. + + means : array, shape (n_components, n_features) + The centers of the current components. + + covariances : array + The sample covariances of the current components. + The shape depends of the covariance_type. + """ + compute_covariance = { + "full": _estimate_gaussian_covariance_full, + "tied": _estimate_gaussian_covariance_tied, + "diag": _estimate_gaussian_covariance_diag, + "spherical": _estimate_gaussian_covariance_spherical} + + nk = resp.sum(axis=0) + 10 * np.finfo(float).eps + means = np.dot(resp.T, X) / nk[:, np.newaxis] + covariances = compute_covariance[covariance_type]( + resp, X, nk, means, reg_covar) + + return nk, means, covariances + + +############################################################################### +# Gaussian mixture probability estimators + +def _estimate_log_gaussian_prob_full(X, means, covariances): + """Estimate the log Gaussian probability for 'full' covariance. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + means : array-like, shape (n_components, n_features) + + covariances : array-like, shape (n_components, n_features, n_features) + + Returns + ------- + log_prob : array, shape (n_samples, n_components) + """ + n_samples, n_features = X.shape + n_components = means.shape[0] + log_prob = np.empty((n_samples, n_components)) + for k, (mu, cov) in enumerate(zip(means, covariances)): + try: + cov_chol = linalg.cholesky(cov, lower=True) + except linalg.LinAlgError: + raise ValueError("The algorithm has diverged because of too " + "few samples per components. " + "Try to decrease the number of components, or " + "increase reg_covar.") + cv_log_det = 2. * np.sum(np.log(np.diagonal(cov_chol))) + cv_sol = linalg.solve_triangular(cov_chol, (X - mu).T, + lower=True).T + log_prob[:, k] = - .5 * (n_features * np.log(2. * np.pi) + + cv_log_det + + np.sum(np.square(cv_sol), axis=1)) + return log_prob + + +def _estimate_log_gaussian_prob_tied(X, means, covariances): + """Estimate the log Gaussian probability for 'tied' covariance. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + means : array-like, shape (n_components, n_features) + + covariances : array-like, shape (n_features, n_features) + + Returns + ------- + log_prob : array-like, shape (n_samples, n_components) + """ + n_samples, n_features = X.shape + n_components = means.shape[0] + log_prob = np.empty((n_samples, n_components)) + try: + cov_chol = linalg.cholesky(covariances, lower=True) + except linalg.LinAlgError: + raise ValueError("The algorithm has diverged because of too " + "few samples per components. " + "Try to decrease the number of components, or " + "increase reg_covar.") + cv_log_det = 2. * np.sum(np.log(np.diagonal(cov_chol))) + for k, mu in enumerate(means): + cv_sol = linalg.solve_triangular(cov_chol, (X - mu).T, + lower=True).T + log_prob[:, k] = np.sum(np.square(cv_sol), axis=1) + log_prob = - .5 * (n_features * np.log(2. * np.pi) + cv_log_det + log_prob) + return log_prob + + +def _estimate_log_gaussian_prob_diag(X, means, covariances): + """Estimate the log Gaussian probability for 'diag' covariance. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + means : array-like, shape (n_components, n_features) + + covariances : array-like, shape (n_components, n_features) + + Returns + ------- + log_prob : array-like, shape (n_samples, n_components) + """ + if np.any(np.less_equal(covariances, 0.0)): + raise ValueError("The algorithm has diverged because of too " + "few samples per components. " + "Try to decrease the number of components, or " + "increase reg_covar.") + n_samples, n_features = X.shape + log_prob = - .5 * (n_features * np.log(2. * np.pi) + + np.sum(np.log(covariances), 1) + + np.sum((means ** 2 / covariances), 1) - + 2. * np.dot(X, (means / covariances).T) + + np.dot(X ** 2, (1. / covariances).T)) + return log_prob + + +def _estimate_log_gaussian_prob_spherical(X, means, covariances): + """Estimate the log Gaussian probability for 'spherical' covariance. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + means : array-like, shape (n_components, n_features) + + covariances : array-like, shape (n_components, ) + + Returns + ------- + log_prob : array-like, shape (n_samples, n_components) + """ + if np.any(np.less_equal(covariances, 0.0)): + raise ValueError("The algorithm has diverged because of too " + "few samples per components. " + "Try to decrease the number of components, or " + "increase reg_covar.") + n_samples, n_features = X.shape + log_prob = - .5 * (n_features * np.log(2 * np.pi) + + n_features * np.log(covariances) + + np.sum(means ** 2, 1) / covariances - + 2 * np.dot(X, means.T / covariances) + + np.outer(np.sum(X ** 2, axis=1), 1. / covariances)) + return log_prob + + +class GaussianMixture(BaseMixture): + """Gaussian Mixture. + + Representation of a Gaussian mixture model probability distribution. + This class allows to estimate the parameters of a Gaussian mixture + distribution. + + Parameters + ---------- + n_components : int, defaults to 1. + The number of mixture components. + + covariance_type : {'full', 'tied', 'diag', 'spherical'}, + defaults to 'full'. + String describing the type of covariance parameters to use. + Must be one of:: + 'full' (each component has its own general covariance matrix). + 'tied' (all components share the same general covariance matrix), + 'diag' (each component has its own diagonal covariance matrix), + 'spherical' (each component has its own single variance), + + tol : float, defaults to 1e-3. + The convergence threshold. EM iterations will stop when the + log_likelihood average gain is below this threshold. + + reg_covar : float, defaults to 0. + Non-negative regularization added to the diagonal of covariance. + Allows to assure that the covariance matrices are all positive. + + max_iter : int, defaults to 100. + The number of EM iterations to perform. + + n_init : int, defaults to 1. + The number of initializations to perform. The best results is kept. + + init_params : {'kmeans', 'random'}, defaults to 'kmeans'. + The method used to initialize the weights, the means and the + covariances. + Must be one of:: + 'kmeans' : responsibilities are initialized using kmeans. + 'random' : responsibilities are initialized randomly. + + weights_init : array-like, shape (n_components, ), optional + The user-provided initial weights, defaults to None. + If it None, weights are initialized using the `init_params` method. + + means_init: array-like, shape (n_components, n_features), optional + The user-provided initial means, defaults to None, + If it None, means are initialized using the `init_params` method. + + covariances_init: array-like, optional. + The user-provided initial covariances, defaults to None. + If it None, covariances are initialized using the 'init_params' method. + The shape depends on 'covariance_type':: + (n_components,) if 'spherical', + (n_features, n_features) if 'tied', + (n_components, n_features) if 'diag', + (n_components, n_features, n_features) if 'full' + + random_state: RandomState or an int seed, defaults to None. + A random number generator instance. + + warm_start : bool, default to False. + If 'warm_start' is True, the solution of the last fitting is used as + initialization for the next call of fit(). This can speed up + convergence when fit is called several time on similar problems. + + verbose : int, default to 0. + Enable verbose output. If 1 then it prints the current + initialization and each iteration step. If greater than 1 then + it prints also the log probability and the time needed + for each step. + + Attributes + ---------- + weights_ : array, shape (n_components,) + The weights of each mixture components. + `weights_` will not exist before a call to fit. + + means_ : array, shape (n_components, n_features) + The mean of each mixture component. + `means_` will not exist before a call to fit. + + covariances_ : array + The covariance of each mixture component. + The shape depends on `covariance_type`:: + (n_components,) if 'spherical', + (n_features, n_features) if 'tied', + (n_components, n_features) if 'diag', + (n_components, n_features, n_features) if 'full' + `covariances_` will not exist before a call to fit. + + converged_ : bool + True when convergence was reached in fit(), False otherwise. + `converged_` will not exist before a call to fit. + + n_iter_ : int + Number of step used by the best fit of EM to reach the convergence. + `n_iter_` will not exist before a call to fit. + """ + + def __init__(self, n_components=1, covariance_type='full', tol=1e-3, + reg_covar=1e-6, max_iter=100, n_init=1, init_params='kmeans', + weights_init=None, means_init=None, covariances_init=None, + random_state=None, warm_start=False, + verbose=0, verbose_interval=10): + super(GaussianMixture, self).__init__( + n_components=n_components, tol=tol, reg_covar=reg_covar, + max_iter=max_iter, n_init=n_init, init_params=init_params, + random_state=random_state, warm_start=warm_start, + verbose=verbose, verbose_interval=verbose_interval) + + self.covariance_type = covariance_type + self.weights_init = weights_init + self.means_init = means_init + self.covariances_init = covariances_init + + def _check_parameters(self, X): + """Check the Gaussian mixture parameters are well defined.""" + if self.covariance_type not in ['spherical', 'tied', 'diag', 'full']: + raise ValueError("Invalid value for 'covariance_type': %s " + "'covariance_type' should be in " + "['spherical', 'tied', 'diag', 'full']" + % self.covariance_type) + + if self.weights_init is not None: + self.weights_init = _check_weights(self.weights_init, + self.n_components) + + if self.means_init is not None: + self.means_init = _check_means(self.means_init, + self.n_components, X.shape[1]) + + if self.covariances_init is not None: + self.covariances_init = _check_covariances(self.covariances_init, + self.covariance_type, + self.n_components, + X.shape[1]) + + def _initialize(self, X, resp): + """Initialization of the Gaussian mixture parameters. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + resp : array-like, shape (n_samples, n_components) + """ + weights, means, covariances = _estimate_gaussian_parameters( + X, resp, self.reg_covar, self.covariance_type) + weights /= X.shape[0] + + self.weights_ = (weights if self.weights_init is None + else self.weights_init) + self.means_ = means if self.means_init is None else self.means_init + self.covariances_ = (covariances if self.covariances_init is None + else self.covariances_init) + + def _e_step(self, X): + log_prob_norm, _, log_resp = self._estimate_log_prob_resp(X) + return np.mean(log_prob_norm), np.exp(log_resp) + + def _m_step(self, X, resp): + self.weights_, self.means_, self.covariances_ = ( + _estimate_gaussian_parameters(X, resp, self.reg_covar, + self.covariance_type)) + self.weights_ /= X.shape[0] + + def _estimate_log_prob(self, X): + estimate_log_prob_functions = { + "full": _estimate_log_gaussian_prob_full, + "tied": _estimate_log_gaussian_prob_tied, + "diag": _estimate_log_gaussian_prob_diag, + "spherical": _estimate_log_gaussian_prob_spherical + } + return estimate_log_prob_functions[self.covariance_type]( + X, self.means_, self.covariances_) + + def _estimate_log_weights(self): + return np.log(self.weights_) + + def _check_is_fitted(self): + check_is_fitted(self, ['weights_', 'means_', 'covariances_']) + + def _get_parameters(self): + return self.weights_, self.means_, self.covariances_ + + def _set_parameters(self, params): + self.weights_, self.means_, self.covariances_ = params + + def _n_parameters(self): + """Return the number of free parameters in the model.""" + ndim = self.means_.shape[1] + if self.covariance_type == 'full': + cov_params = self.n_components * ndim * (ndim + 1) / 2. + elif self.covariance_type == 'diag': + cov_params = self.n_components * ndim + elif self.covariance_type == 'tied': + cov_params = ndim * (ndim + 1) / 2. + elif self.covariance_type == 'spherical': + cov_params = self.n_components + mean_params = ndim * self.n_components + return int(cov_params + mean_params + self.n_components - 1) + + def bic(self, X): + """Bayesian information criterion for the current model on the input X. + + Parameters + ---------- + X : array of shape (n_samples, n_dimensions) + + Returns + ------- + bic: float + The greater the better. + """ + return (-2 * self.score(X) * X.shape[0] + + self._n_parameters() * np.log(X.shape[0])) + + def aic(self, X): + """Akaike information criterion for the current model on the input X. + + Parameters + ---------- + X : array of shape(n_samples, n_dimensions) + + Returns + ------- + aic: float + The greater the better. + """ + return -2 * self.score(X) * X.shape[0] + 2 * self._n_parameters() diff --git a/sklearn/mixture/gmm.py b/sklearn/mixture/gmm.py index f6b5a095f28ae..85aa384a22e98 100644 --- a/sklearn/mixture/gmm.py +++ b/sklearn/mixture/gmm.py @@ -15,7 +15,7 @@ from time import time from ..base import BaseEstimator -from ..utils import check_random_state, check_array +from ..utils import check_random_state, check_array, deprecated from ..utils.extmath import logsumexp from ..utils.validation import check_is_fitted from .. import cluster @@ -112,7 +112,7 @@ def sample_gaussian(mean, covar, covariance_type='diag', n_samples=1, return (rand.T + mean).T -class GMM(BaseEstimator): +class _GMMBase(BaseEstimator): """Gaussian Mixture Model. Representation of a Gaussian mixture model probability distribution. @@ -649,6 +649,19 @@ def aic(self, X): return - 2 * self.score(X).sum() + 2 * self._n_parameters() +@deprecated("The class GMM is deprecated and " + "will be removed in 0.20. Use class GaussianMixture instead.") +class GMM(_GMMBase): + def __init__(self, n_components=1, covariance_type='diag', + random_state=None, tol=1e-3, min_covar=1e-3, + n_iter=100, n_init=1, params='wmc', init_params='wmc', + verbose=0): + super(GMM, self).__init__( + n_components=n_components, covariance_type=covariance_type, + random_state=random_state, tol=tol, min_covar=min_covar, + n_iter=n_iter, n_init=n_init, params=params, + init_params=init_params, verbose=verbose) + ######################################################################### # some helper routines ######################################################################### diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py new file mode 100644 index 0000000000000..64cdbe54c9f30 --- /dev/null +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -0,0 +1,825 @@ +import sys +import warnings + +import numpy as np + +from scipy import stats + +from sklearn.covariance import EmpiricalCovariance +from sklearn.datasets.samples_generator import make_spd_matrix +from sklearn.externals.six.moves import cStringIO as StringIO +from sklearn.metrics.cluster import adjusted_rand_score +from sklearn.mixture.gaussian_mixture import GaussianMixture +from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_diag +from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_full +from sklearn.mixture.gaussian_mixture import ( + _estimate_gaussian_covariance_spherical) +from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_tied +from sklearn.exceptions import ConvergenceWarning, NotFittedError +from sklearn.utils.extmath import fast_logdet +from sklearn.utils.testing import assert_allclose +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_greater_equal +from sklearn.utils.testing import assert_raise_message +from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_warns_message + + +COVARIANCE_TYPE = ['full', 'tied', 'diag', 'spherical'] + + +def generate_data(n_samples, n_features, weights, means, covariances, + covariance_type): + rng = np.random.RandomState(0) + + X = [] + if covariance_type == 'spherical': + for _, (w, m, c) in enumerate(zip(weights, means, + covariances['spherical'])): + X.append(rng.multivariate_normal(m, c * np.eye(n_features), + int(np.round(w * n_samples)))) + if covariance_type == 'diag': + for _, (w, m, c) in enumerate(zip(weights, means, + covariances['diag'])): + X.append(rng.multivariate_normal(m, np.diag(c), + int(np.round(w * n_samples)))) + if covariance_type == 'tied': + for _, (w, m) in enumerate(zip(weights, means)): + X.append(rng.multivariate_normal(m, covariances['tied'], + int(np.round(w * n_samples)))) + if covariance_type == 'full': + for _, (w, m, c) in enumerate(zip(weights, means, + covariances['full'])): + X.append(rng.multivariate_normal(m, c, + int(np.round(w * n_samples)))) + + X = np.vstack(X) + return X + + +class RandomData(object): + def __init__(self, rng, n_samples=500, n_components=2, n_features=2, + scale=50): + self.n_samples = n_samples + self.n_components = n_components + self.n_features = n_features + + self.weights = rng.rand(n_components) + self.weights = self.weights / self.weights.sum() + self.means = rng.rand(n_components, n_features) * scale + self.covariances = { + 'spherical': .5 + rng.rand(n_components), + 'diag': (.5 + rng.rand(n_components, n_features)) ** 2, + 'tied': make_spd_matrix(n_features, random_state=rng), + 'full': np.array([make_spd_matrix( + n_features, random_state=rng) * .5 + for _ in range(n_components)])} + + self.X = dict(zip(COVARIANCE_TYPE, [generate_data( + n_samples, n_features, self.weights, self.means, self.covariances, + cov_type) for cov_type in COVARIANCE_TYPE])) + self.Y = np.hstack([k * np.ones(int(np.round(w * n_samples))) + for k, w in enumerate(self.weights)]) + + +def test_gaussian_mixture_attributes(): + # test bad parameters + rng = np.random.RandomState(0) + X = rng.rand(10, 2) + + n_components_bad = 0 + gmm = GaussianMixture(n_components=n_components_bad) + assert_raise_message(ValueError, + "Invalid value for 'n_components': %d " + "Estimation requires at least one component" + % n_components_bad, gmm.fit, X) + + # covariance_type should be in [spherical, diag, tied, full] + covariance_type_bad = 'bad_covariance_type' + gmm = GaussianMixture(covariance_type=covariance_type_bad) + assert_raise_message(ValueError, + "Invalid value for 'covariance_type': %s " + "'covariance_type' should be in " + "['spherical', 'tied', 'diag', 'full']" + % covariance_type_bad, + gmm.fit, X) + + tol_bad = -1 + gmm = GaussianMixture(tol=tol_bad) + assert_raise_message(ValueError, + "Invalid value for 'tol': %.5f " + "Tolerance used by the EM must be non-negative" + % tol_bad, gmm.fit, X) + + reg_covar_bad = -1 + gmm = GaussianMixture(reg_covar=reg_covar_bad) + assert_raise_message(ValueError, + "Invalid value for 'reg_covar': %.5f " + "regularization on covariance must be " + "non-negative" % reg_covar_bad, gmm.fit, X) + + max_iter_bad = 0 + gmm = GaussianMixture(max_iter=max_iter_bad) + assert_raise_message(ValueError, + "Invalid value for 'max_iter': %d " + "Estimation requires at least one iteration" + % max_iter_bad, gmm.fit, X) + + n_init_bad = 0 + gmm = GaussianMixture(n_init=n_init_bad) + assert_raise_message(ValueError, + "Invalid value for 'n_init': %d " + "Estimation requires at least one run" + % n_init_bad, gmm.fit, X) + + init_params_bad = 'bad_method' + gmm = GaussianMixture(init_params=init_params_bad) + assert_raise_message(ValueError, + "Unimplemented initialization method '%s'" + % init_params_bad, + gmm.fit, X) + + # test good parameters + n_components, tol, n_init, max_iter, reg_covar = 2, 1e-4, 3, 30, 1e-1 + covariance_type, init_params = 'full', 'random' + gmm = GaussianMixture(n_components=n_components, tol=tol, n_init=n_init, + max_iter=max_iter, reg_covar=reg_covar, + covariance_type=covariance_type, + init_params=init_params).fit(X) + + assert_equal(gmm.n_components, n_components) + assert_equal(gmm.covariance_type, covariance_type) + assert_equal(gmm.tol, tol) + assert_equal(gmm.reg_covar, reg_covar) + assert_equal(gmm.max_iter, max_iter) + assert_equal(gmm.n_init, n_init) + assert_equal(gmm.init_params, init_params) + + +def test_check_X(): + from sklearn.mixture.base import _check_X + rng = np.random.RandomState(0) + + n_samples, n_components, n_features = 10, 2, 2 + + X_bad_dim = rng.rand(n_components - 1, n_features) + assert_raise_message(ValueError, + 'Expected n_samples >= n_components ' + 'but got n_components = %d, n_samples = %d' + % (n_components, X_bad_dim.shape[0]), + _check_X, X_bad_dim, n_components) + + X_bad_dim = rng.rand(n_components, n_features + 1) + assert_raise_message(ValueError, + 'Expected the input data X have %d features, ' + 'but got %d features' + % (n_features, X_bad_dim.shape[1]), + _check_X, X_bad_dim, n_components, n_features) + + X = rng.rand(n_samples, n_features) + assert_array_equal(X, _check_X(X, n_components, n_features)) + + +def test_check_weights(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + + n_components = rand_data.n_components + X = rand_data.X['full'] + + g = GaussianMixture(n_components=n_components) + + # Check bad shape + weights_bad_shape = rng.rand(n_components, 1) + g.weights_init = weights_bad_shape + assert_raise_message(ValueError, + "The parameter 'weights' should have the shape of " + "(%d,), " + "but got %s" % (n_components, + str(weights_bad_shape.shape)), + g.fit, X) + + # Check bad range + weights_bad_range = rng.rand(n_components) + 1 + g.weights_init = weights_bad_range + assert_raise_message(ValueError, + "The parameter 'weights' should be in the range " + "[0, 1], but got max value %.5f, min value %.5f" + % (np.min(weights_bad_range), + np.max(weights_bad_range)), + g.fit, X) + + # Check bad normalization + weights_bad_norm = rng.rand(n_components) + weights_bad_norm = weights_bad_norm / (weights_bad_norm.sum() + 1) + g.weights_init = weights_bad_norm + assert_raise_message(ValueError, + "The parameter 'weights' should be normalized, " + "but got sum(weights) = %.5f" + % np.sum(weights_bad_norm), + g.fit, X) + + # Check good weights matrix + weights = rand_data.weights + g = GaussianMixture(weights_init=weights, n_components=n_components) + g.fit(X) + assert_array_equal(weights, g.weights_init) + + +def test_check_means(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + + n_components, n_features = rand_data.n_components, rand_data.n_features + X = rand_data.X['full'] + + g = GaussianMixture(n_components=n_components) + + # Check means bad shape + means_bad_shape = rng.rand(n_components + 1, n_features) + g.means_init = means_bad_shape + assert_raise_message(ValueError, + "The parameter 'means' should have the shape of ", + g.fit, X) + + # Check good means matrix + means = rand_data.means + g.means_init = means + g.fit(X) + assert_array_equal(means, g.means_init) + + +def test_check_covariances(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + + n_components, n_features = rand_data.n_components, rand_data.n_features + + # Define the bad covariances for each covariance_type + covariances_bad_shape = { + 'full': rng.rand(n_components + 1, n_features, n_features), + 'tied': rng.rand(n_features + 1, n_features + 1), + 'diag': rng.rand(n_components + 1, n_features), + 'spherical': rng.rand(n_components + 1)} + + # Define not positive-definite covariances + covariances_not_pos = rng.rand(n_components, n_features, n_features) + covariances_not_pos[0] = np.eye(n_features) + covariances_not_pos[0, 0, 0] = -1. + + covariances_not_positive = { + 'full': covariances_not_pos, + 'tied': covariances_not_pos[0], + 'diag': -1. * np.ones((n_components, n_features)), + 'spherical': -1. * np.ones(n_components)} + + not_positive_errors = { + 'full': 'symmetric, positive-definite', + 'tied': 'symmetric, positive-definite', + 'diag': 'positive', + 'spherical': 'positive'} + + for cov_type in ['full', 'tied', 'diag', 'spherical']: + X = rand_data.X[cov_type] + g = GaussianMixture(n_components=n_components, + covariance_type=cov_type) + + # Check covariance with bad shapes + g.covariances_init = covariances_bad_shape[cov_type] + assert_raise_message(ValueError, + "The parameter '%s covariance' should have " + "the shape of" % cov_type, + g.fit, X) + + # Check not positive covariances + g.covariances_init = covariances_not_positive[cov_type] + assert_raise_message(ValueError, + "'%s covariance' should be %s" + % (cov_type, not_positive_errors[cov_type]), + g.fit, X) + + # Check the correct init of covariances_init + g.covariances_init = rand_data.covariances[cov_type] + g.fit(X) + assert_array_equal(rand_data.covariances[cov_type], g.covariances_init) + + +def test_suffstat_sk_full(): + # compare the EmpiricalCovariance.covariance fitted on X*sqrt(resp) + # with _sufficient_sk_full, n_components=1 + rng = np.random.RandomState(0) + n_samples, n_features = 500, 2 + + # special case 1, assuming data is "centered" + X = rng.rand(n_samples, n_features) + resp = rng.rand(n_samples, 1) + X_resp = np.sqrt(resp) * X + nk = np.array([n_samples]) + xk = np.zeros((1, n_features)) + covars_pred = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0) + ecov = EmpiricalCovariance(assume_centered=True) + ecov.fit(X_resp) + assert_almost_equal(ecov.error_norm(covars_pred[0], norm='frobenius'), 0) + assert_almost_equal(ecov.error_norm(covars_pred[0], norm='spectral'), 0) + + # special case 2, assuming resp are all ones + resp = np.ones((n_samples, 1)) + nk = np.array([n_samples]) + xk = X.mean().reshape((1, -1)) + covars_pred = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0) + ecov = EmpiricalCovariance(assume_centered=False) + ecov.fit(X) + assert_almost_equal(ecov.error_norm(covars_pred[0], norm='frobenius'), 0) + assert_almost_equal(ecov.error_norm(covars_pred[0], norm='spectral'), 0) + + +def test_suffstat_sk_tied(): + # use equation Nk * Sk / N = S_tied + rng = np.random.RandomState(0) + n_samples, n_features, n_components = 500, 2, 2 + + resp = rng.rand(n_samples, n_components) + resp = resp / resp.sum(axis=1)[:, np.newaxis] + X = rng.rand(n_samples, n_features) + nk = resp.sum(axis=0) + xk = np.dot(resp.T, X) / nk[:, np.newaxis] + covars_pred_full = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0) + covars_pred_full = np.sum(nk[:, np.newaxis, np.newaxis] * covars_pred_full, + 0) / n_samples + + covars_pred_tied = _estimate_gaussian_covariance_tied(resp, X, nk, xk, 0) + ecov = EmpiricalCovariance() + ecov.covariance_ = covars_pred_full + assert_almost_equal(ecov.error_norm(covars_pred_tied, norm='frobenius'), 0) + assert_almost_equal(ecov.error_norm(covars_pred_tied, norm='spectral'), 0) + + +def test_suffstat_sk_diag(): + # test against 'full' case + rng = np.random.RandomState(0) + n_samples, n_features, n_components = 500, 2, 2 + + resp = rng.rand(n_samples, n_components) + resp = resp / resp.sum(axis=1)[:, np.newaxis] + X = rng.rand(n_samples, n_features) + nk = resp.sum(axis=0) + xk = np.dot(resp.T, X) / nk[:, np.newaxis] + covars_pred_full = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0) + covars_pred_full = np.array([np.diag(np.diag(d)) for d in + covars_pred_full]) + covars_pred_diag = _estimate_gaussian_covariance_diag(resp, X, nk, xk, 0) + covars_pred_diag = np.array([np.diag(d) for d in covars_pred_diag]) + ecov = EmpiricalCovariance() + for (cov_full, cov_diag) in zip(covars_pred_full, covars_pred_diag): + ecov.covariance_ = cov_full + assert_almost_equal(ecov.error_norm(cov_diag, norm='frobenius'), 0) + assert_almost_equal(ecov.error_norm(cov_diag, norm='spectral'), 0) + + +def test_gaussian_suffstat_sk_spherical(): + # computing spherical covariance equals to the variance of one-dimension + # data after flattening, n_components=1 + rng = np.random.RandomState(0) + n_samples, n_features = 500, 2 + + X = rng.rand(n_samples, n_features) + X = X - X.mean() + resp = np.ones((n_samples, 1)) + nk = np.array([n_samples]) + xk = X.mean() + covars_pred_spherical = _estimate_gaussian_covariance_spherical(resp, X, + nk, xk, 0) + covars_pred_spherical2 = (np.dot(X.flatten().T, X.flatten()) / + (n_features * n_samples)) + assert_almost_equal(covars_pred_spherical, covars_pred_spherical2) + + +def _naive_lmvnpdf_diag(X, means, covars): + resp = np.empty((len(X), len(means))) + stds = np.sqrt(covars) + for i, (mean, std) in enumerate(zip(means, stds)): + resp[:, i] = stats.norm.logpdf(X, mean, std).sum(axis=1) + return resp + + +def test_gaussian_mixture_log_probabilities(): + from sklearn.mixture.gaussian_mixture import ( + _estimate_log_gaussian_prob_full, + _estimate_log_gaussian_prob_tied, + _estimate_log_gaussian_prob_diag, + _estimate_log_gaussian_prob_spherical) + + # test aginst with _naive_lmvnpdf_diag + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + n_samples = 500 + n_features = rand_data.n_features + n_components = rand_data.n_components + + means = rand_data.means + covars_diag = rng.rand(n_components, n_features) + X = rng.rand(n_samples, n_features) + log_prob_naive = _naive_lmvnpdf_diag(X, means, covars_diag) + + # full covariances + covars_full = np.array([np.diag(x) for x in covars_diag]) + + log_prob = _estimate_log_gaussian_prob_full(X, means, covars_full) + assert_array_almost_equal(log_prob, log_prob_naive) + + # diag covariances + log_prob = _estimate_log_gaussian_prob_diag(X, means, covars_diag) + assert_array_almost_equal(log_prob, log_prob_naive) + + # tied + covars_tied = covars_full.mean(axis=0) + log_prob_naive = _naive_lmvnpdf_diag(X, means, + [np.diag(covars_tied)] * n_components) + log_prob = _estimate_log_gaussian_prob_tied(X, means, covars_tied) + assert_array_almost_equal(log_prob, log_prob_naive) + + # spherical + covars_spherical = covars_diag.mean(axis=1) + log_prob_naive = _naive_lmvnpdf_diag(X, means, + [[k] * n_features for k in + covars_spherical]) + log_prob = _estimate_log_gaussian_prob_spherical(X, means, + covars_spherical) + assert_array_almost_equal(log_prob, log_prob_naive) + +# skip tests on weighted_log_probabilities, log_weights + + +def test_gaussian_mixture_estimate_log_prob_resp(): + # test whether responsibilities are normalized + rng = np.random.RandomState(0) + rand_data = RandomData(rng, scale=5) + n_samples = rand_data.n_samples + n_features = rand_data.n_features + n_components = rand_data.n_components + + X = rng.rand(n_samples, n_features) + for cov_type in COVARIANCE_TYPE: + weights = rand_data.weights + means = rand_data.means + covariances = rand_data.covariances[cov_type] + g = GaussianMixture(n_components=n_components, random_state=rng, + weights_init=weights, means_init=means, + covariances_init=covariances, + covariance_type=cov_type) + g.fit(X) + resp = g.predict_proba(X) + assert_array_almost_equal(resp.sum(axis=1), np.ones(n_samples)) + assert_array_equal(g.weights_init, weights) + assert_array_equal(g.means_init, means) + assert_array_equal(g.covariances_init, covariances) + + +def test_gaussian_mixture_predict_predict_proba(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + for cov_type in COVARIANCE_TYPE: + X = rand_data.X[cov_type] + Y = rand_data.Y + g = GaussianMixture(n_components=rand_data.n_components, + random_state=rng, weights_init=rand_data.weights, + means_init=rand_data.means, + covariances_init=rand_data.covariances[cov_type], + covariance_type=cov_type) + + # Check a warning message arrive if we don't do fit + assert_raise_message(NotFittedError, + "This GaussianMixture instance is not fitted " + "yet. Call 'fit' with appropriate arguments " + "before using this method.", g.predict, X) + + g.fit(X) + Y_pred = g.predict(X) + Y_pred_proba = g.predict_proba(X).argmax(axis=1) + assert_array_equal(Y_pred, Y_pred_proba) + assert_greater(adjusted_rand_score(Y, Y_pred), .95) + + +def test_gaussian_mixture_fit(): + # recover the ground truth + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + n_features = rand_data.n_features + n_components = rand_data.n_components + + for cov_type in COVARIANCE_TYPE: + X = rand_data.X[cov_type] + g = GaussianMixture(n_components=n_components, n_init=20, max_iter=100, + reg_covar=0, random_state=rng, + covariance_type=cov_type) + g.fit(X) + # needs more data to pass the test with rtol=1e-7 + assert_allclose(np.sort(g.weights_), np.sort(rand_data.weights), + rtol=0.1, atol=1e-2) + + arg_idx1 = g.means_[:, 0].argsort() + arg_idx2 = rand_data.means[:, 0].argsort() + assert_allclose(g.means_[arg_idx1], rand_data.means[arg_idx2], + rtol=0.1, atol=1e-2) + + if cov_type == 'spherical': + cov_pred = np.array([np.eye(n_features) * c + for c in g.covariances_]) + cov_test = np.array([np.eye(n_features) * c for c in + rand_data.covariances['spherical']]) + elif cov_type == 'diag': + cov_pred = np.array([np.diag(d) for d in g.covariances_]) + cov_test = np.array([np.diag(d) for d in + rand_data.covariances['diag']]) + elif cov_type == 'tied': + cov_pred = np.array([g.covariances_] * n_components) + cov_test = np.array([rand_data.covariances['tied']] * n_components) + elif cov_type == 'full': + cov_pred = g.covariances_ + cov_test = rand_data.covariances['full'] + arg_idx1 = np.trace(cov_pred, axis1=1, axis2=2).argsort() + arg_idx2 = np.trace(cov_test, axis1=1, axis2=2).argsort() + for k, h in zip(arg_idx1, arg_idx2): + ecov = EmpiricalCovariance() + ecov.covariance_ = cov_test[h] + # the accuracy depends on the number of data and randomness, rng + assert_allclose(ecov.error_norm(cov_pred[k]), 0, atol=0.1) + + +def test_gaussian_mixture_fit_best_params(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + n_components = rand_data.n_components + n_init = 10 + for cov_type in COVARIANCE_TYPE: + X = rand_data.X[cov_type] + g = GaussianMixture(n_components=n_components, n_init=1, + max_iter=100, reg_covar=0, random_state=rng, + covariance_type=cov_type) + ll = [] + for _ in range(n_init): + g.fit(X) + ll.append(g.score(X)) + ll = np.array(ll) + g_best = GaussianMixture(n_components=n_components, + n_init=n_init, max_iter=100, reg_covar=0, + random_state=rng, covariance_type=cov_type) + g_best.fit(X) + assert_almost_equal(ll.min(), g_best.score(X)) + + +def test_gaussian_mixture_fit_convergence_warning(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng, scale=1) + n_components = rand_data.n_components + max_iter = 1 + for cov_type in COVARIANCE_TYPE: + X = rand_data.X[cov_type] + g = GaussianMixture(n_components=n_components, n_init=1, + max_iter=max_iter, reg_covar=0, random_state=rng, + covariance_type=cov_type) + assert_warns_message(ConvergenceWarning, + 'Initialization %d did not converged. ' + 'Try different init parameters, ' + 'or increase n_init, tol ' + 'or check for degenerate data.' + % max_iter, g.fit, X) + + +def test_multiple_init(): + # Test that multiple inits does not much worse than a single one + rng = np.random.RandomState(0) + n_samples, n_features, n_components = 50, 5, 2 + X = rng.randn(n_samples, n_features) + for cv_type in COVARIANCE_TYPE: + train1 = GaussianMixture(n_components=n_components, + covariance_type=cv_type, + random_state=rng).fit(X).score(X) + train2 = GaussianMixture(n_components=n_components, + covariance_type=cv_type, + random_state=rng, n_init=5).fit(X).score(X) + assert_greater_equal(train2, train1) + + +def test_gaussian_mixture_n_parameters(): + # Test that the right number of parameters is estimated + rng = np.random.RandomState(0) + n_samples, n_features, n_components = 50, 5, 2 + X = rng.randn(n_samples, n_features) + n_params = {'spherical': 13, 'diag': 21, 'tied': 26, 'full': 41} + for cv_type in COVARIANCE_TYPE: + g = GaussianMixture( + n_components=n_components, covariance_type=cv_type, + random_state=rng).fit(X) + assert_equal(g._n_parameters(), n_params[cv_type]) + + +def test_bic_1d_1component(): + # Test all of the covariance_types return the same BIC score for + # 1-dimensional, 1 component fits. + rng = np.random.RandomState(0) + n_samples, n_dim, n_components = 100, 1, 1 + X = rng.randn(n_samples, n_dim) + bic_full = GaussianMixture(n_components=n_components, + covariance_type='full', + random_state=rng).fit(X).bic(X) + for covariance_type in ['tied', 'diag', 'spherical']: + bic = GaussianMixture(n_components=n_components, + covariance_type=covariance_type, + random_state=rng).fit(X).bic(X) + assert_almost_equal(bic_full, bic) + + +def test_gaussian_mixture_aic_bic(): + # Test the aic and bic criteria + rng = np.random.RandomState(0) + n_samples, n_features, n_components = 50, 3, 2 + X = rng.randn(n_samples, n_features) + # standard gaussian entropy + sgh = 0.5 * (fast_logdet(np.cov(X.T, bias=1)) + + n_features * (1 + np.log(2 * np.pi))) + for cv_type in COVARIANCE_TYPE: + g = GaussianMixture( + n_components=n_components, covariance_type=cv_type, + random_state=rng, max_iter=200) + g.fit(X) + aic = 2 * n_samples * sgh + 2 * g._n_parameters() + bic = (2 * n_samples * sgh + + np.log(n_samples) * g._n_parameters()) + bound = n_features / np.sqrt(n_samples) + assert_true((g.aic(X) - aic) / n_samples < bound) + assert_true((g.bic(X) - bic) / n_samples < bound) + + +def test_gaussian_mixture_verbose(): + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + n_components = rand_data.n_components + for cov_type in COVARIANCE_TYPE: + X = rand_data.X[cov_type] + g = GaussianMixture(n_components=n_components, n_init=1, + max_iter=100, reg_covar=0, random_state=rng, + covariance_type=cov_type, verbose=1) + h = GaussianMixture(n_components=n_components, n_init=1, + max_iter=100, reg_covar=0, random_state=rng, + covariance_type=cov_type, verbose=2) + old_stdout = sys.stdout + sys.stdout = StringIO() + try: + g.fit(X) + h.fit(X) + finally: + sys.stdout = old_stdout + + +def test_warm_start(): + + random_state = 0 + rng = np.random.RandomState(random_state) + n_samples, n_features, n_components = 500, 2, 2 + X = rng.rand(n_samples, n_features) + + # Assert the warm_start give the same result for the same number of iter + g = GaussianMixture(n_components=n_components, n_init=1, + max_iter=2, reg_covar=0, random_state=random_state, + warm_start=False) + h = GaussianMixture(n_components=n_components, n_init=1, + max_iter=1, reg_covar=0, random_state=random_state, + warm_start=True) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConvergenceWarning) + g.fit(X) + score1 = h.fit(X).score(X) + score2 = h.fit(X).score(X) + + assert_almost_equal(g.weights_, h.weights_) + assert_almost_equal(g.means_, h.means_) + assert_almost_equal(g.covariances_, h.covariances_) + assert_greater(score2, score1) + + # Assert that by using warm_start we can converge to a good solution + g = GaussianMixture(n_components=n_components, n_init=1, + max_iter=5, reg_covar=0, random_state=random_state, + warm_start=False, tol=1e-6) + h = GaussianMixture(n_components=n_components, n_init=1, + max_iter=5, reg_covar=0, random_state=random_state, + warm_start=True, tol=1e-6) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConvergenceWarning) + g.fit(X) + h.fit(X).fit(X) + + assert_true(not g.converged_) + assert_true(h.converged_) + + +def test_score(): + cov_type = 'full' + rng = np.random.RandomState(0) + rand_data = RandomData(rng, scale=7) + n_components = rand_data.n_components + X = rand_data.X[cov_type] + + # Check the error message if we don't call fit + gmm1 = GaussianMixture(n_components=n_components, n_init=1, + max_iter=1, reg_covar=0, random_state=rng, + covariance_type=cov_type) + assert_raise_message(NotFittedError, + "This GaussianMixture instance is not fitted " + "yet. Call 'fit' with appropriate arguments " + "before using this method.", gmm1.score, X) + + # Check score value + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConvergenceWarning) + gmm1.fit(X) + gmm_score = gmm1.score(X) + gmm_score_proba = gmm1.score_samples(X).mean() + assert_almost_equal(gmm_score, gmm_score_proba) + + # Check if the score increase + gmm2 = GaussianMixture(n_components=n_components, n_init=1, + max_iter=1000, reg_covar=0, random_state=rng, + covariance_type=cov_type).fit(X) + assert_greater(gmm2.score(X), gmm1.score(X)) + + +def test_score_samples(): + cov_type = 'full' + rng = np.random.RandomState(0) + rand_data = RandomData(rng, scale=7) + n_components = rand_data.n_components + X = rand_data.X[cov_type] + + # Check the error message if we don't call fit + gmm = GaussianMixture(n_components=n_components, n_init=1, + max_iter=100, reg_covar=0, random_state=rng, + covariance_type=cov_type) + assert_raise_message(NotFittedError, + "This GaussianMixture instance is not fitted " + "yet. Call 'fit' with appropriate arguments " + "before using this method.", gmm.score_samples, X) + + gmm_score_samples = gmm.fit(X).score_samples(X) + assert_equal(gmm_score_samples.shape[0], rand_data.n_samples) + + +def test_monotonic_likelihood(): + # We check that each step of the EM without regularization improve + # monotonically the training set likelihood + rng = np.random.RandomState(0) + rand_data = RandomData(rng, scale=7) + n_components = rand_data.n_components + + for cov_type in COVARIANCE_TYPE: + X = rand_data.X[cov_type] + gmm = GaussianMixture(n_components=n_components, + covariance_type=cov_type, reg_covar=0, + warm_start=True, max_iter=1, random_state=rng, + tol=1e-7) + current_log_likelihood = -np.infty + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConvergenceWarning) + # Do one training iteration at a time so we can make sure that the + # training log likelihood increases after each iteration. + for _ in range(300): + prev_log_likelihood = current_log_likelihood + try: + current_log_likelihood = gmm.fit(X).score(X) + except ConvergenceWarning: + pass + assert_greater_equal(current_log_likelihood, + prev_log_likelihood) + + if gmm.converged_: + break + + +def test_regularisation(): + # We train the GaussianMixture on degenerate data by defining two clusters + # of a 0 covariance. + rng = np.random.RandomState(0) + n_samples, n_features = 10, 5 + + X = np.vstack((np.ones((n_samples // 2, n_features)), + np.zeros((n_samples // 2, n_features)))) + + for cov_type in COVARIANCE_TYPE: + gmm = GaussianMixture(n_components=n_samples, covariance_type=cov_type, + reg_covar=0, random_state=rng) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + assert_raise_message(ValueError, + "The algorithm has diverged because of too " + "few samples per components. " + "Try to decrease the number of components, " + "or increase reg_covar.", gmm.fit, X) + + gmm.set_params(reg_covar=1e-6).fit(X) diff --git a/sklearn/mixture/tests/test_gmm.py b/sklearn/mixture/tests/test_gmm.py index 7d79c3c12abd6..b726ebebb1b60 100644 --- a/sklearn/mixture/tests/test_gmm.py +++ b/sklearn/mixture/tests/test_gmm.py @@ -1,6 +1,9 @@ +# These tests are those of the deprecated GMM class + import unittest import copy import sys +import warnings from nose.tools import assert_true import numpy as np @@ -14,6 +17,7 @@ from sklearn.metrics.cluster import adjusted_rand_score from sklearn.externals.six.moves import cStringIO as StringIO + rng = np.random.RandomState(0) @@ -121,30 +125,33 @@ def test_lvmpdf_full_cv_non_positive_definite(): def test_GMM_attributes(): - n_components, n_features = 10, 4 - covariance_type = 'diag' - g = mixture.GMM(n_components, covariance_type, random_state=rng) - weights = rng.rand(n_components) - weights = weights / weights.sum() - means = rng.randint(-20, 20, (n_components, n_features)) - - assert_true(g.n_components == n_components) - assert_true(g.covariance_type == covariance_type) - - g.weights_ = weights - assert_array_almost_equal(g.weights_, weights) - g.means_ = means - assert_array_almost_equal(g.means_, means) - - covars = (0.1 + 2 * rng.rand(n_components, n_features)) ** 2 - g.covars_ = covars - assert_array_almost_equal(g.covars_, covars) - assert_raises(ValueError, g._set_covars, []) - assert_raises(ValueError, g._set_covars, - np.zeros((n_components - 2, n_features))) - - assert_raises(ValueError, mixture.GMM, n_components=20, - covariance_type='badcovariance_type') + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + n_components, n_features = 10, 4 + covariance_type = 'diag' + g = mixture.GMM(n_components, covariance_type, random_state=rng) + weights = rng.rand(n_components) + weights = weights / weights.sum() + means = rng.randint(-20, 20, (n_components, n_features)) + + assert_true(g.n_components == n_components) + assert_true(g.covariance_type == covariance_type) + + g.weights_ = weights + assert_array_almost_equal(g.weights_, weights) + g.means_ = means + assert_array_almost_equal(g.means_, means) + + covars = (0.1 + 2 * rng.rand(n_components, n_features)) ** 2 + g.covars_ = covars + assert_array_almost_equal(g.covars_, covars) + assert_raises(ValueError, g._set_covars, []) + assert_raises(ValueError, g._set_covars, + np.zeros((n_components - 2, n_features))) + + assert_raises(ValueError, mixture.GMM, n_components=20, + covariance_type='badcovariance_type') class GMMTester(): @@ -169,114 +176,137 @@ def _setUp(self): + 5 * self.I for x in range(self.n_components)])} def test_eval(self): - if not self.do_test_eval: - return # DPGMM does not support setting the means and - # covariances before fitting There is no way of fixing this - # due to the variational parameters being more expressive than - # covariance matrices - g = self.model(n_components=self.n_components, - covariance_type=self.covariance_type, random_state=rng) - # Make sure the means are far apart so responsibilities.argmax() - # picks the actual component used to generate the observations. - g.means_ = 20 * self.means - g.covars_ = self.covars[self.covariance_type] - g.weights_ = self.weights - - gaussidx = np.repeat(np.arange(self.n_components), 5) - n_samples = len(gaussidx) - X = rng.randn(n_samples, self.n_features) + g.means_[gaussidx] - - ll, responsibilities = g.score_samples(X) - - self.assertEqual(len(ll), n_samples) - self.assertEqual(responsibilities.shape, - (n_samples, self.n_components)) - assert_array_almost_equal(responsibilities.sum(axis=1), - np.ones(n_samples)) - assert_array_equal(responsibilities.argmax(axis=1), gaussidx) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + if not self.do_test_eval: + return # DPGMM does not support setting the means and + # covariances before fitting There is no way of fixing this + # due to the variational parameters being more expressive than + # covariance matrices + g = self.model(n_components=self.n_components, + covariance_type=self.covariance_type, random_state=rng) + # Make sure the means are far apart so responsibilities.argmax() + # picks the actual component used to generate the observations. + g.means_ = 20 * self.means + g.covars_ = self.covars[self.covariance_type] + g.weights_ = self.weights + + gaussidx = np.repeat(np.arange(self.n_components), 5) + n_samples = len(gaussidx) + X = rng.randn(n_samples, self.n_features) + g.means_[gaussidx] + + ll, responsibilities = g.score_samples(X) + + self.assertEqual(len(ll), n_samples) + self.assertEqual(responsibilities.shape, + (n_samples, self.n_components)) + assert_array_almost_equal(responsibilities.sum(axis=1), + np.ones(n_samples)) + assert_array_equal(responsibilities.argmax(axis=1), gaussidx) def test_sample(self, n=100): - g = self.model(n_components=self.n_components, - covariance_type=self.covariance_type, random_state=rng) - # Make sure the means are far apart so responsibilities.argmax() - # picks the actual component used to generate the observations. - g.means_ = 20 * self.means - g.covars_ = np.maximum(self.covars[self.covariance_type], 0.1) - g.weights_ = self.weights - - samples = g.sample(n) - self.assertEqual(samples.shape, (n, self.n_features)) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + g = self.model(n_components=self.n_components, + covariance_type=self.covariance_type, + random_state=rng) + # Make sure the means are far apart so responsibilities.argmax() + # picks the actual component used to generate the observations. + g.means_ = 20 * self.means + g.covars_ = np.maximum(self.covars[self.covariance_type], 0.1) + g.weights_ = self.weights + + samples = g.sample(n) + self.assertEqual(samples.shape, (n, self.n_features)) def test_train(self, params='wmc'): - g = mixture.GMM(n_components=self.n_components, - covariance_type=self.covariance_type) - g.weights_ = self.weights - g.means_ = self.means - g.covars_ = 20 * self.covars[self.covariance_type] - - # Create a training set by sampling from the predefined distribution. - X = g.sample(n_samples=100) - g = self.model(n_components=self.n_components, - covariance_type=self.covariance_type, - random_state=rng, min_covar=1e-1, - n_iter=1, init_params=params) - g.fit(X) - - # Do one training iteration at a time so we can keep track of - # the log likelihood to make sure that it increases after each - # iteration. - trainll = [] - for _ in range(5): - g.params = params - g.init_params = '' + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + g = mixture.GMM(n_components=self.n_components, + covariance_type=self.covariance_type) + g.weights_ = self.weights + g.means_ = self.means + g.covars_ = 20 * self.covars[self.covariance_type] + + # Create a training set by sampling from the predefined distribution. + X = g.sample(n_samples=100) + g = self.model(n_components=self.n_components, + covariance_type=self.covariance_type, + random_state=rng, min_covar=1e-1, + n_iter=1, init_params=params) g.fit(X) - trainll.append(self.score(g, X)) - g.n_iter = 10 - g.init_params = '' - g.params = params - g.fit(X) # finish fitting - - # Note that the log likelihood will sometimes decrease by a - # very small amount after it has more or less converged due to - # the addition of min_covar to the covariance (to prevent - # underflow). This is why the threshold is set to -0.5 - # instead of 0. - delta_min = np.diff(trainll).min() - self.assertTrue( - delta_min > self.threshold, - "The min nll increase is %f which is lower than the admissible" - " threshold of %f, for model %s. The likelihoods are %s." - % (delta_min, self.threshold, self.covariance_type, trainll)) + + # Do one training iteration at a time so we can keep track of + # the log likelihood to make sure that it increases after each + # iteration. + trainll = [] + for _ in range(5): + g.params = params + g.init_params = '' + g.fit(X) + trainll.append(self.score(g, X)) + g.n_iter = 10 + g.init_params = '' + g.params = params + g.fit(X) # finish fitting + + # Note that the log likelihood will sometimes decrease by a + # very small amount after it has more or less converged due to + # the addition of min_covar to the covariance (to prevent + # underflow). This is why the threshold is set to -0.5 + # instead of 0. + delta_min = np.diff(trainll).min() + self.assertTrue( + delta_min > self.threshold, + "The min nll increase is %f which is lower than the admissible" + " threshold of %f, for model %s. The likelihoods are %s." + % (delta_min, self.threshold, self.covariance_type, trainll)) def test_train_degenerate(self, params='wmc'): - # Train on degenerate data with 0 in some dimensions - # Create a training set by sampling from the predefined distribution. - X = rng.randn(100, self.n_features) - X.T[1:] = 0 - g = self.model(n_components=2, covariance_type=self.covariance_type, - random_state=rng, min_covar=1e-3, n_iter=5, - init_params=params) - g.fit(X) - trainll = g.score(X) - self.assertTrue(np.sum(np.abs(trainll / 100 / X.shape[1])) < 5) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Train on degenerate data with 0 in some dimensions + # Create a training set by sampling from the predefined + # distribution. + X = rng.randn(100, self.n_features) + X.T[1:] = 0 + g = self.model(n_components=2, + covariance_type=self.covariance_type, + random_state=rng, min_covar=1e-3, n_iter=5, + init_params=params) + g.fit(X) + trainll = g.score(X) + self.assertTrue(np.sum(np.abs(trainll / 100 / X.shape[1])) < 5) def test_train_1d(self, params='wmc'): - # Train on 1-D data - # Create a training set by sampling from the predefined distribution. - X = rng.randn(100, 1) - # X.T[1:] = 0 - g = self.model(n_components=2, covariance_type=self.covariance_type, - random_state=rng, min_covar=1e-7, n_iter=5, - init_params=params) - g.fit(X) - trainll = g.score(X) - if isinstance(g, mixture.DPGMM): - self.assertTrue(np.sum(np.abs(trainll / 100)) < 5) - else: - self.assertTrue(np.sum(np.abs(trainll / 100)) < 2) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Train on 1-D data + # Create a training set by sampling from the predefined + # distribution. + X = rng.randn(100, 1) + # X.T[1:] = 0 + g = self.model(n_components=2, + covariance_type=self.covariance_type, + random_state=rng, min_covar=1e-7, n_iter=5, + init_params=params) + g.fit(X) + trainll = g.score(X) + if isinstance(g, mixture.DPGMM): + self.assertTrue(np.sum(np.abs(trainll / 100)) < 5) + else: + self.assertTrue(np.sum(np.abs(trainll / 100)) < 2) def score(self, g, X): - return g.score(X).sum() + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return g.score(X).sum() class TestGMMWithSphericalCovars(unittest.TestCase, GMMTester): @@ -304,93 +334,111 @@ class TestGMMWithFullCovars(unittest.TestCase, GMMTester): def test_multiple_init(): - # Test that multiple inits does not much worse than a single one - X = rng.randn(30, 5) - X[:10] += 2 - g = mixture.GMM(n_components=2, covariance_type='spherical', - random_state=rng, min_covar=1e-7, n_iter=5) - train1 = g.fit(X).score(X).sum() - g.n_init = 5 - train2 = g.fit(X).score(X).sum() - assert_true(train2 >= train1 - 1.e-2) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Test that multiple inits does not much worse than a single one + X = rng.randn(30, 5) + X[:10] += 2 + g = mixture.GMM(n_components=2, covariance_type='spherical', + random_state=rng, min_covar=1e-7, n_iter=5) + train1 = g.fit(X).score(X).sum() + g.n_init = 5 + train2 = g.fit(X).score(X).sum() + assert_true(train2 >= train1 - 1.e-2) def test_n_parameters(): - # Test that the right number of parameters is estimated - n_samples, n_dim, n_components = 7, 5, 2 - X = rng.randn(n_samples, n_dim) - n_params = {'spherical': 13, 'diag': 21, 'tied': 26, 'full': 41} - for cv_type in ['full', 'tied', 'diag', 'spherical']: - g = mixture.GMM(n_components=n_components, covariance_type=cv_type, - random_state=rng, min_covar=1e-7, n_iter=1) - g.fit(X) - assert_true(g._n_parameters() == n_params[cv_type]) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Test that the right number of parameters is estimated + n_samples, n_dim, n_components = 7, 5, 2 + X = rng.randn(n_samples, n_dim) + n_params = {'spherical': 13, 'diag': 21, 'tied': 26, 'full': 41} + for cv_type in ['full', 'tied', 'diag', 'spherical']: + g = mixture.GMM(n_components=n_components, covariance_type=cv_type, + random_state=rng, min_covar=1e-7, n_iter=1) + g.fit(X) + assert_true(g._n_parameters() == n_params[cv_type]) def test_1d_1component(): - # Test all of the covariance_types return the same BIC score for - # 1-dimensional, 1 component fits. - n_samples, n_dim, n_components = 100, 1, 1 - X = rng.randn(n_samples, n_dim) - g_full = mixture.GMM(n_components=n_components, covariance_type='full', - random_state=rng, min_covar=1e-7, n_iter=1) - g_full.fit(X) - g_full_bic = g_full.bic(X) - for cv_type in ['tied', 'diag', 'spherical']: - g = mixture.GMM(n_components=n_components, covariance_type=cv_type, - random_state=rng, min_covar=1e-7, n_iter=1) - g.fit(X) - assert_array_almost_equal(g.bic(X), g_full_bic) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Test all of the covariance_types return the same BIC score for + # 1-dimensional, 1 component fits. + n_samples, n_dim, n_components = 100, 1, 1 + X = rng.randn(n_samples, n_dim) + g_full = mixture.GMM(n_components=n_components, covariance_type='full', + random_state=rng, min_covar=1e-7, n_iter=1) + g_full.fit(X) + g_full_bic = g_full.bic(X) + for cv_type in ['tied', 'diag', 'spherical']: + g = mixture.GMM(n_components=n_components, covariance_type=cv_type, + random_state=rng, min_covar=1e-7, n_iter=1) + g.fit(X) + assert_array_almost_equal(g.bic(X), g_full_bic) def assert_fit_predict_correct(model, X): - model2 = copy.deepcopy(model) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + model2 = copy.deepcopy(model) - predictions_1 = model.fit(X).predict(X) - predictions_2 = model2.fit_predict(X) + predictions_1 = model.fit(X).predict(X) + predictions_2 = model2.fit_predict(X) - assert adjusted_rand_score(predictions_1, predictions_2) == 1.0 + assert adjusted_rand_score(predictions_1, predictions_2) == 1.0 def test_fit_predict(): """ test that gmm.fit_predict is equivalent to gmm.fit + gmm.predict """ - lrng = np.random.RandomState(101) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + lrng = np.random.RandomState(101) - n_samples, n_dim, n_comps = 100, 2, 2 - mu = np.array([[8, 8]]) - component_0 = lrng.randn(n_samples, n_dim) - component_1 = lrng.randn(n_samples, n_dim) + mu - X = np.vstack((component_0, component_1)) + n_samples, n_dim, n_comps = 100, 2, 2 + mu = np.array([[8, 8]]) + component_0 = lrng.randn(n_samples, n_dim) + component_1 = lrng.randn(n_samples, n_dim) + mu + X = np.vstack((component_0, component_1)) - for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM): - model = m_constructor(n_components=n_comps, covariance_type='full', - min_covar=1e-7, n_iter=5, - random_state=np.random.RandomState(0)) - assert_fit_predict_correct(model, X) + for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM): + model = m_constructor(n_components=n_comps, covariance_type='full', + min_covar=1e-7, n_iter=5, + random_state=np.random.RandomState(0)) + assert_fit_predict_correct(model, X) - model = mixture.GMM(n_components=n_comps, n_iter=0) - z = model.fit_predict(X) - assert np.all(z == 0), "Quick Initialization Failed!" + model = mixture.GMM(n_components=n_comps, n_iter=0) + z = model.fit_predict(X) + assert np.all(z == 0), "Quick Initialization Failed!" def test_aic(): - # Test the aic and bic criteria - n_samples, n_dim, n_components = 50, 3, 2 - X = rng.randn(n_samples, n_dim) - SGH = 0.5 * (X.var() + np.log(2 * np.pi)) # standard gaussian entropy - - for cv_type in ['full', 'tied', 'diag', 'spherical']: - g = mixture.GMM(n_components=n_components, covariance_type=cv_type, - random_state=rng, min_covar=1e-7) - g.fit(X) - aic = 2 * n_samples * SGH * n_dim + 2 * g._n_parameters() - bic = (2 * n_samples * SGH * n_dim + - np.log(n_samples) * g._n_parameters()) - bound = n_dim * 3. / np.sqrt(n_samples) - assert_true(np.abs(g.aic(X) - aic) / n_samples < bound) - assert_true(np.abs(g.bic(X) - bic) / n_samples < bound) + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Test the aic and bic criteria + n_samples, n_dim, n_components = 50, 3, 2 + X = rng.randn(n_samples, n_dim) + SGH = 0.5 * (X.var() + np.log(2 * np.pi)) # standard gaussian entropy + + for cv_type in ['full', 'tied', 'diag', 'spherical']: + g = mixture.GMM(n_components=n_components, covariance_type=cv_type, + random_state=rng, min_covar=1e-7) + g.fit(X) + aic = 2 * n_samples * SGH * n_dim + 2 * g._n_parameters() + bic = (2 * n_samples * SGH * n_dim + + np.log(n_samples) * g._n_parameters()) + bound = n_dim * 3. / np.sqrt(n_samples) + assert_true(np.abs(g.aic(X) - aic) / n_samples < bound) + assert_true(np.abs(g.bic(X) - bic) / n_samples < bound) def check_positive_definite_covars(covariance_type): @@ -412,30 +460,33 @@ def check_positive_definite_covars(covariance_type): This function ensures that some later optimization will not introduce the problem again. """ - rng = np.random.RandomState(1) - # we build a dataset with 2 2d component. The components are unbalanced - # (respective weights 0.9 and 0.1) - X = rng.randn(100, 2) - X[-10:] += (3, 3) # Shift the 10 last points - - gmm = mixture.GMM(2, params="wc", covariance_type=covariance_type, - min_covar=1e-3) - - # This is a non-regression test for issue #2640. The following call used - # to trigger: - # numpy.linalg.linalg.LinAlgError: 2-th leading minor not positive definite - gmm.fit(X) - - if covariance_type == "diag" or covariance_type == "spherical": - assert_greater(gmm.covars_.min(), 0) - else: - if covariance_type == "tied": - covs = [gmm.covars_] + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + rng = np.random.RandomState(1) + # we build a dataset with 2 2d component. The components are unbalanced + # (respective weights 0.9 and 0.1) + X = rng.randn(100, 2) + X[-10:] += (3, 3) # Shift the 10 last points + + gmm = mixture.GMM(2, params="wc", covariance_type=covariance_type, + min_covar=1e-3) + + # This is a non-regression test for issue #2640. The following call used + # to trigger: + # numpy.linalg.linalg.LinAlgError: 2-th leading minor not positive definite + gmm.fit(X) + + if covariance_type == "diag" or covariance_type == "spherical": + assert_greater(gmm.covars_.min(), 0) else: - covs = gmm.covars_ + if covariance_type == "tied": + covs = [gmm.covars_] + else: + covs = gmm.covars_ - for c in covs: - assert_greater(np.linalg.det(c), 0) + for c in covs: + assert_greater(np.linalg.det(c), 0) def test_positive_definite_covars(): @@ -445,28 +496,34 @@ def test_positive_definite_covars(): def test_verbose_first_level(): - # Create sample data - X = rng.randn(30, 5) - X[:10] += 2 - g = mixture.GMM(n_components=2, n_init=2, verbose=1) - - old_stdout = sys.stdout - sys.stdout = StringIO() - try: - g.fit(X) - finally: - sys.stdout = old_stdout + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Create sample data + X = rng.randn(30, 5) + X[:10] += 2 + g = mixture.GMM(n_components=2, n_init=2, verbose=1) + + old_stdout = sys.stdout + sys.stdout = StringIO() + try: + g.fit(X) + finally: + sys.stdout = old_stdout def test_verbose_second_level(): - # Create sample data - X = rng.randn(30, 5) - X[:10] += 2 - g = mixture.GMM(n_components=2, n_init=2, verbose=2) - - old_stdout = sys.stdout - sys.stdout = StringIO() - try: - g.fit(X) - finally: - sys.stdout = old_stdout + # This function tests the deprecated old GMM class + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # Create sample data + X = rng.randn(30, 5) + X[:10] += 2 + g = mixture.GMM(n_components=2, n_init=2, verbose=2) + + old_stdout = sys.stdout + sys.stdout = StringIO() + try: + g.fit(X) + finally: + sys.stdout = old_stdout From 6fb6c63b7b0c5565244018ed55b5c335546a0496 Mon Sep 17 00:00:00 2001 From: Thierry Date: Wed, 20 Apr 2016 18:28:36 +0200 Subject: [PATCH 2/2] Modification of the ignore_warning function and _IgnoreWarning class. --- doc/whats_new.rst | 15 +- sklearn/mixture/gaussian_mixture.py | 2 +- sklearn/mixture/tests/test_gmm.py | 558 ++++++++++++++-------------- sklearn/utils/testing.py | 98 +++-- sklearn/utils/tests/test_testing.py | 95 ++++- 5 files changed, 424 insertions(+), 344 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 42d6290a8cd95..ea21b88bc70ef 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -131,9 +131,12 @@ Enhancements - Add option to show ``indicator features`` in the output of Imputer. By `Mani Teja`_. - - Reduce the memory usage for 32-bit float input arrays of :func:`utils.mean_variance_axis` and + - Reduce the memory usage for 32-bit float input arrays of :func:`utils.mean_variance_axis` and :func:`utils.incr_mean_variance_axis` by supporting cython fused types. By `YenChen Lin`_. + - The :func: `ignore_warnings` now accept a category argument to ignore only + the warnings of a specified type. By `Thierry Guillemot`_. + Bug fixes ......... @@ -201,6 +204,12 @@ API changes summary - Access to public attributes ``.X_`` and ``.y_`` has been deprecated in :class:`isotonic.IsotonicRegression`. By `Jonathan Arfa`_. + - The old :class:`GMM` is deprecated in favor of the new + :class:`GaussianMixture`. The new class compute the Gaussian mixture + faster than before and some of computationnal problems have been solved. + By `Wei Xue`_ and `Thierry Guillemot`_. + + .. _changes_0_17_1: @@ -4151,3 +4160,7 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _JPFrancoia: https://github.com/JPFrancoia .. _Mani Teja: https://github.com/maniteja123 + +.. _Thierry Guillemot: https://github.com/tguillemot + +.. _Wei Xue: https://github.com/xuewei4d \ No newline at end of file diff --git a/sklearn/mixture/gaussian_mixture.py b/sklearn/mixture/gaussian_mixture.py index 90c2ed7ea0ba0..37bb35cbd15ec 100644 --- a/sklearn/mixture/gaussian_mixture.py +++ b/sklearn/mixture/gaussian_mixture.py @@ -279,7 +279,7 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): "diag": _estimate_gaussian_covariance_diag, "spherical": _estimate_gaussian_covariance_spherical} - nk = resp.sum(axis=0) + 10 * np.finfo(float).eps + nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps means = np.dot(resp.T, X) / nk[:, np.newaxis] covariances = compute_covariance[covariance_type]( resp, X, nk, means, reg_covar) diff --git a/sklearn/mixture/tests/test_gmm.py b/sklearn/mixture/tests/test_gmm.py index b726ebebb1b60..0d06320608b51 100644 --- a/sklearn/mixture/tests/test_gmm.py +++ b/sklearn/mixture/tests/test_gmm.py @@ -16,6 +16,7 @@ from sklearn.utils.testing import assert_raise_message from sklearn.metrics.cluster import adjusted_rand_score from sklearn.externals.six.moves import cStringIO as StringIO +from sklearn.utils.testing import ignore_warnings rng = np.random.RandomState(0) @@ -124,34 +125,33 @@ def test_lvmpdf_full_cv_non_positive_definite(): X, mu, cv, 'full') +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_GMM_attributes(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - n_components, n_features = 10, 4 - covariance_type = 'diag' - g = mixture.GMM(n_components, covariance_type, random_state=rng) - weights = rng.rand(n_components) - weights = weights / weights.sum() - means = rng.randint(-20, 20, (n_components, n_features)) - - assert_true(g.n_components == n_components) - assert_true(g.covariance_type == covariance_type) - - g.weights_ = weights - assert_array_almost_equal(g.weights_, weights) - g.means_ = means - assert_array_almost_equal(g.means_, means) - - covars = (0.1 + 2 * rng.rand(n_components, n_features)) ** 2 - g.covars_ = covars - assert_array_almost_equal(g.covars_, covars) - assert_raises(ValueError, g._set_covars, []) - assert_raises(ValueError, g._set_covars, - np.zeros((n_components - 2, n_features))) - - assert_raises(ValueError, mixture.GMM, n_components=20, - covariance_type='badcovariance_type') + n_components, n_features = 10, 4 + covariance_type = 'diag' + g = mixture.GMM(n_components, covariance_type, random_state=rng) + weights = rng.rand(n_components) + weights = weights / weights.sum() + means = rng.randint(-20, 20, (n_components, n_features)) + + assert_true(g.n_components == n_components) + assert_true(g.covariance_type == covariance_type) + + g.weights_ = weights + assert_array_almost_equal(g.weights_, weights) + g.means_ = means + assert_array_almost_equal(g.means_, means) + + covars = (0.1 + 2 * rng.rand(n_components, n_features)) ** 2 + g.covars_ = covars + assert_array_almost_equal(g.covars_, covars) + assert_raises(ValueError, g._set_covars, []) + assert_raises(ValueError, g._set_covars, + np.zeros((n_components - 2, n_features))) + + assert_raises(ValueError, mixture.GMM, n_components=20, + covariance_type='badcovariance_type') class GMMTester(): @@ -175,138 +175,132 @@ def _setUp(self): 'full': np.array([make_spd_matrix(self.n_features, random_state=0) + 5 * self.I for x in range(self.n_components)])} + # This function tests the deprecated old GMM class + @ignore_warnings(category=DeprecationWarning) def test_eval(self): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - if not self.do_test_eval: - return # DPGMM does not support setting the means and - # covariances before fitting There is no way of fixing this - # due to the variational parameters being more expressive than - # covariance matrices - g = self.model(n_components=self.n_components, - covariance_type=self.covariance_type, random_state=rng) - # Make sure the means are far apart so responsibilities.argmax() - # picks the actual component used to generate the observations. - g.means_ = 20 * self.means - g.covars_ = self.covars[self.covariance_type] - g.weights_ = self.weights - - gaussidx = np.repeat(np.arange(self.n_components), 5) - n_samples = len(gaussidx) - X = rng.randn(n_samples, self.n_features) + g.means_[gaussidx] - - ll, responsibilities = g.score_samples(X) - - self.assertEqual(len(ll), n_samples) - self.assertEqual(responsibilities.shape, - (n_samples, self.n_components)) - assert_array_almost_equal(responsibilities.sum(axis=1), - np.ones(n_samples)) - assert_array_equal(responsibilities.argmax(axis=1), gaussidx) + if not self.do_test_eval: + return # DPGMM does not support setting the means and + # covariances before fitting There is no way of fixing this + # due to the variational parameters being more expressive than + # covariance matrices + g = self.model(n_components=self.n_components, + covariance_type=self.covariance_type, random_state=rng) + # Make sure the means are far apart so responsibilities.argmax() + # picks the actual component used to generate the observations. + g.means_ = 20 * self.means + g.covars_ = self.covars[self.covariance_type] + g.weights_ = self.weights + + gaussidx = np.repeat(np.arange(self.n_components), 5) + n_samples = len(gaussidx) + X = rng.randn(n_samples, self.n_features) + g.means_[gaussidx] + + ll, responsibilities = g.score_samples(X) + + self.assertEqual(len(ll), n_samples) + self.assertEqual(responsibilities.shape, + (n_samples, self.n_components)) + assert_array_almost_equal(responsibilities.sum(axis=1), + np.ones(n_samples)) + assert_array_equal(responsibilities.argmax(axis=1), gaussidx) + # This function tests the deprecated old GMM class + @ignore_warnings(category=DeprecationWarning) def test_sample(self, n=100): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - g = self.model(n_components=self.n_components, - covariance_type=self.covariance_type, - random_state=rng) - # Make sure the means are far apart so responsibilities.argmax() - # picks the actual component used to generate the observations. - g.means_ = 20 * self.means - g.covars_ = np.maximum(self.covars[self.covariance_type], 0.1) - g.weights_ = self.weights - - samples = g.sample(n) - self.assertEqual(samples.shape, (n, self.n_features)) + g = self.model(n_components=self.n_components, + covariance_type=self.covariance_type, + random_state=rng) + # Make sure the means are far apart so responsibilities.argmax() + # picks the actual component used to generate the observations. + g.means_ = 20 * self.means + g.covars_ = np.maximum(self.covars[self.covariance_type], 0.1) + g.weights_ = self.weights + + samples = g.sample(n) + self.assertEqual(samples.shape, (n, self.n_features)) + # This function tests the deprecated old GMM class + @ignore_warnings(category=DeprecationWarning) def test_train(self, params='wmc'): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - g = mixture.GMM(n_components=self.n_components, - covariance_type=self.covariance_type) - g.weights_ = self.weights - g.means_ = self.means - g.covars_ = 20 * self.covars[self.covariance_type] - - # Create a training set by sampling from the predefined distribution. - X = g.sample(n_samples=100) - g = self.model(n_components=self.n_components, - covariance_type=self.covariance_type, - random_state=rng, min_covar=1e-1, - n_iter=1, init_params=params) - g.fit(X) - - # Do one training iteration at a time so we can keep track of - # the log likelihood to make sure that it increases after each - # iteration. - trainll = [] - for _ in range(5): - g.params = params - g.init_params = '' - g.fit(X) - trainll.append(self.score(g, X)) - g.n_iter = 10 - g.init_params = '' + g = mixture.GMM(n_components=self.n_components, + covariance_type=self.covariance_type) + g.weights_ = self.weights + g.means_ = self.means + g.covars_ = 20 * self.covars[self.covariance_type] + + # Create a training set by sampling from the predefined distribution. + X = g.sample(n_samples=100) + g = self.model(n_components=self.n_components, + covariance_type=self.covariance_type, + random_state=rng, min_covar=1e-1, + n_iter=1, init_params=params) + g.fit(X) + + # Do one training iteration at a time so we can keep track of + # the log likelihood to make sure that it increases after each + # iteration. + trainll = [] + for _ in range(5): g.params = params - g.fit(X) # finish fitting - - # Note that the log likelihood will sometimes decrease by a - # very small amount after it has more or less converged due to - # the addition of min_covar to the covariance (to prevent - # underflow). This is why the threshold is set to -0.5 - # instead of 0. - delta_min = np.diff(trainll).min() - self.assertTrue( - delta_min > self.threshold, - "The min nll increase is %f which is lower than the admissible" - " threshold of %f, for model %s. The likelihoods are %s." - % (delta_min, self.threshold, self.covariance_type, trainll)) + g.init_params = '' + g.fit(X) + trainll.append(self.score(g, X)) + g.n_iter = 10 + g.init_params = '' + g.params = params + g.fit(X) # finish fitting + + # Note that the log likelihood will sometimes decrease by a + # very small amount after it has more or less converged due to + # the addition of min_covar to the covariance (to prevent + # underflow). This is why the threshold is set to -0.5 + # instead of 0. + delta_min = np.diff(trainll).min() + self.assertTrue( + delta_min > self.threshold, + "The min nll increase is %f which is lower than the admissible" + " threshold of %f, for model %s. The likelihoods are %s." + % (delta_min, self.threshold, self.covariance_type, trainll)) + # This function tests the deprecated old GMM class + @ignore_warnings(category=DeprecationWarning) def test_train_degenerate(self, params='wmc'): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Train on degenerate data with 0 in some dimensions - # Create a training set by sampling from the predefined - # distribution. - X = rng.randn(100, self.n_features) - X.T[1:] = 0 - g = self.model(n_components=2, - covariance_type=self.covariance_type, - random_state=rng, min_covar=1e-3, n_iter=5, - init_params=params) - g.fit(X) - trainll = g.score(X) - self.assertTrue(np.sum(np.abs(trainll / 100 / X.shape[1])) < 5) + # Train on degenerate data with 0 in some dimensions + # Create a training set by sampling from the predefined + # distribution. + X = rng.randn(100, self.n_features) + X.T[1:] = 0 + g = self.model(n_components=2, + covariance_type=self.covariance_type, + random_state=rng, min_covar=1e-3, n_iter=5, + init_params=params) + g.fit(X) + trainll = g.score(X) + self.assertTrue(np.sum(np.abs(trainll / 100 / X.shape[1])) < 5) + # This function tests the deprecated old GMM class + @ignore_warnings(category=DeprecationWarning) def test_train_1d(self, params='wmc'): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Train on 1-D data - # Create a training set by sampling from the predefined - # distribution. - X = rng.randn(100, 1) - # X.T[1:] = 0 - g = self.model(n_components=2, - covariance_type=self.covariance_type, - random_state=rng, min_covar=1e-7, n_iter=5, - init_params=params) - g.fit(X) - trainll = g.score(X) - if isinstance(g, mixture.DPGMM): - self.assertTrue(np.sum(np.abs(trainll / 100)) < 5) - else: - self.assertTrue(np.sum(np.abs(trainll / 100)) < 2) + # Train on 1-D data + # Create a training set by sampling from the predefined + # distribution. + X = rng.randn(100, 1) + # X.T[1:] = 0 + g = self.model(n_components=2, + covariance_type=self.covariance_type, + random_state=rng, min_covar=1e-7, n_iter=5, + init_params=params) + g.fit(X) + trainll = g.score(X) + if isinstance(g, mixture.DPGMM): + self.assertTrue(np.sum(np.abs(trainll / 100)) < 5) + else: + self.assertTrue(np.sum(np.abs(trainll / 100)) < 2) + # This function tests the deprecated old GMM class + @ignore_warnings(category=DeprecationWarning) def score(self, g, X): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return g.score(X).sum() + return g.score(X).sum() class TestGMMWithSphericalCovars(unittest.TestCase, GMMTester): @@ -333,114 +327,107 @@ class TestGMMWithFullCovars(unittest.TestCase, GMMTester): setUp = GMMTester._setUp +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_multiple_init(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Test that multiple inits does not much worse than a single one - X = rng.randn(30, 5) - X[:10] += 2 - g = mixture.GMM(n_components=2, covariance_type='spherical', - random_state=rng, min_covar=1e-7, n_iter=5) - train1 = g.fit(X).score(X).sum() - g.n_init = 5 - train2 = g.fit(X).score(X).sum() - assert_true(train2 >= train1 - 1.e-2) - - + # Test that multiple inits does not much worse than a single one + X = rng.randn(30, 5) + X[:10] += 2 + g = mixture.GMM(n_components=2, covariance_type='spherical', + random_state=rng, min_covar=1e-7, n_iter=5) + train1 = g.fit(X).score(X).sum() + g.n_init = 5 + train2 = g.fit(X).score(X).sum() + assert_true(train2 >= train1 - 1.e-2) + + +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_n_parameters(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Test that the right number of parameters is estimated - n_samples, n_dim, n_components = 7, 5, 2 - X = rng.randn(n_samples, n_dim) - n_params = {'spherical': 13, 'diag': 21, 'tied': 26, 'full': 41} - for cv_type in ['full', 'tied', 'diag', 'spherical']: - g = mixture.GMM(n_components=n_components, covariance_type=cv_type, - random_state=rng, min_covar=1e-7, n_iter=1) - g.fit(X) - assert_true(g._n_parameters() == n_params[cv_type]) - - + n_samples, n_dim, n_components = 7, 5, 2 + X = rng.randn(n_samples, n_dim) + n_params = {'spherical': 13, 'diag': 21, 'tied': 26, 'full': 41} + for cv_type in ['full', 'tied', 'diag', 'spherical']: + g = mixture.GMM(n_components=n_components, covariance_type=cv_type, + random_state=rng, min_covar=1e-7, n_iter=1) + g.fit(X) + assert_true(g._n_parameters() == n_params[cv_type]) + + +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_1d_1component(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Test all of the covariance_types return the same BIC score for - # 1-dimensional, 1 component fits. - n_samples, n_dim, n_components = 100, 1, 1 - X = rng.randn(n_samples, n_dim) - g_full = mixture.GMM(n_components=n_components, covariance_type='full', - random_state=rng, min_covar=1e-7, n_iter=1) - g_full.fit(X) - g_full_bic = g_full.bic(X) - for cv_type in ['tied', 'diag', 'spherical']: - g = mixture.GMM(n_components=n_components, covariance_type=cv_type, - random_state=rng, min_covar=1e-7, n_iter=1) - g.fit(X) - assert_array_almost_equal(g.bic(X), g_full_bic) + # Test all of the covariance_types return the same BIC score for + # 1-dimensional, 1 component fits. + n_samples, n_dim, n_components = 100, 1, 1 + X = rng.randn(n_samples, n_dim) + g_full = mixture.GMM(n_components=n_components, covariance_type='full', + random_state=rng, min_covar=1e-7, n_iter=1) + g_full.fit(X) + g_full_bic = g_full.bic(X) + for cv_type in ['tied', 'diag', 'spherical']: + g = mixture.GMM(n_components=n_components, covariance_type=cv_type, + random_state=rng, min_covar=1e-7, n_iter=1) + g.fit(X) + assert_array_almost_equal(g.bic(X), g_full_bic) def assert_fit_predict_correct(model, X): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - model2 = copy.deepcopy(model) + model2 = copy.deepcopy(model) - predictions_1 = model.fit(X).predict(X) - predictions_2 = model2.fit_predict(X) + predictions_1 = model.fit(X).predict(X) + predictions_2 = model2.fit_predict(X) - assert adjusted_rand_score(predictions_1, predictions_2) == 1.0 + assert adjusted_rand_score(predictions_1, predictions_2) == 1.0 +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_fit_predict(): """ test that gmm.fit_predict is equivalent to gmm.fit + gmm.predict """ - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - lrng = np.random.RandomState(101) + lrng = np.random.RandomState(101) - n_samples, n_dim, n_comps = 100, 2, 2 - mu = np.array([[8, 8]]) - component_0 = lrng.randn(n_samples, n_dim) - component_1 = lrng.randn(n_samples, n_dim) + mu - X = np.vstack((component_0, component_1)) + n_samples, n_dim, n_comps = 100, 2, 2 + mu = np.array([[8, 8]]) + component_0 = lrng.randn(n_samples, n_dim) + component_1 = lrng.randn(n_samples, n_dim) + mu + X = np.vstack((component_0, component_1)) - for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM): - model = m_constructor(n_components=n_comps, covariance_type='full', - min_covar=1e-7, n_iter=5, - random_state=np.random.RandomState(0)) - assert_fit_predict_correct(model, X) + for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM): + model = m_constructor(n_components=n_comps, covariance_type='full', + min_covar=1e-7, n_iter=5, + random_state=np.random.RandomState(0)) + assert_fit_predict_correct(model, X) - model = mixture.GMM(n_components=n_comps, n_iter=0) - z = model.fit_predict(X) - assert np.all(z == 0), "Quick Initialization Failed!" + model = mixture.GMM(n_components=n_comps, n_iter=0) + z = model.fit_predict(X) + assert np.all(z == 0), "Quick Initialization Failed!" +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_aic(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Test the aic and bic criteria - n_samples, n_dim, n_components = 50, 3, 2 - X = rng.randn(n_samples, n_dim) - SGH = 0.5 * (X.var() + np.log(2 * np.pi)) # standard gaussian entropy - - for cv_type in ['full', 'tied', 'diag', 'spherical']: - g = mixture.GMM(n_components=n_components, covariance_type=cv_type, - random_state=rng, min_covar=1e-7) - g.fit(X) - aic = 2 * n_samples * SGH * n_dim + 2 * g._n_parameters() - bic = (2 * n_samples * SGH * n_dim + - np.log(n_samples) * g._n_parameters()) - bound = n_dim * 3. / np.sqrt(n_samples) - assert_true(np.abs(g.aic(X) - aic) / n_samples < bound) - assert_true(np.abs(g.bic(X) - bic) / n_samples < bound) - - + # Test the aic and bic criteria + n_samples, n_dim, n_components = 50, 3, 2 + X = rng.randn(n_samples, n_dim) + SGH = 0.5 * (X.var() + np.log(2 * np.pi)) # standard gaussian entropy + + for cv_type in ['full', 'tied', 'diag', 'spherical']: + g = mixture.GMM(n_components=n_components, covariance_type=cv_type, + random_state=rng, min_covar=1e-7) + g.fit(X) + aic = 2 * n_samples * SGH * n_dim + 2 * g._n_parameters() + bic = (2 * n_samples * SGH * n_dim + + np.log(n_samples) * g._n_parameters()) + bound = n_dim * 3. / np.sqrt(n_samples) + assert_true(np.abs(g.aic(X) - aic) / n_samples < bound) + assert_true(np.abs(g.bic(X) - bic) / n_samples < bound) + + +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def check_positive_definite_covars(covariance_type): r"""Test that covariance matrices do not become non positive definite @@ -460,33 +447,30 @@ def check_positive_definite_covars(covariance_type): This function ensures that some later optimization will not introduce the problem again. """ - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - rng = np.random.RandomState(1) - # we build a dataset with 2 2d component. The components are unbalanced - # (respective weights 0.9 and 0.1) - X = rng.randn(100, 2) - X[-10:] += (3, 3) # Shift the 10 last points - - gmm = mixture.GMM(2, params="wc", covariance_type=covariance_type, - min_covar=1e-3) - - # This is a non-regression test for issue #2640. The following call used - # to trigger: - # numpy.linalg.linalg.LinAlgError: 2-th leading minor not positive definite - gmm.fit(X) - - if covariance_type == "diag" or covariance_type == "spherical": - assert_greater(gmm.covars_.min(), 0) + rng = np.random.RandomState(1) + # we build a dataset with 2 2d component. The components are unbalanced + # (respective weights 0.9 and 0.1) + X = rng.randn(100, 2) + X[-10:] += (3, 3) # Shift the 10 last points + + gmm = mixture.GMM(2, params="wc", covariance_type=covariance_type, + min_covar=1e-3) + + # This is a non-regression test for issue #2640. The following call used + # to trigger: + # numpy.linalg.linalg.LinAlgError: 2-th leading minor not positive definite + gmm.fit(X) + + if covariance_type == "diag" or covariance_type == "spherical": + assert_greater(gmm.covars_.min(), 0) + else: + if covariance_type == "tied": + covs = [gmm.covars_] else: - if covariance_type == "tied": - covs = [gmm.covars_] - else: - covs = gmm.covars_ + covs = gmm.covars_ - for c in covs: - assert_greater(np.linalg.det(c), 0) + for c in covs: + assert_greater(np.linalg.det(c), 0) def test_positive_definite_covars(): @@ -495,35 +479,33 @@ def test_positive_definite_covars(): yield check_positive_definite_covars, covariance_type +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_verbose_first_level(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Create sample data - X = rng.randn(30, 5) - X[:10] += 2 - g = mixture.GMM(n_components=2, n_init=2, verbose=1) - - old_stdout = sys.stdout - sys.stdout = StringIO() - try: - g.fit(X) - finally: - sys.stdout = old_stdout + # Create sample data + X = rng.randn(30, 5) + X[:10] += 2 + g = mixture.GMM(n_components=2, n_init=2, verbose=1) + + old_stdout = sys.stdout + sys.stdout = StringIO() + try: + g.fit(X) + finally: + sys.stdout = old_stdout +# This function tests the deprecated old GMM class +@ignore_warnings(category=DeprecationWarning) def test_verbose_second_level(): - # This function tests the deprecated old GMM class - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - # Create sample data - X = rng.randn(30, 5) - X[:10] += 2 - g = mixture.GMM(n_components=2, n_init=2, verbose=2) - - old_stdout = sys.stdout - sys.stdout = StringIO() - try: - g.fit(X) - finally: - sys.stdout = old_stdout + # Create sample data + X = rng.randn(30, 5) + X[:10] += 2 + g = mixture.GMM(n_components=2, n_init=2, verbose=2) + + old_stdout = sys.stdout + sys.stdout = StringIO() + try: + g.fit(X) + finally: + sys.stdout = old_stdout diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 4db2addda92c0..47b37576348c0 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -8,6 +8,7 @@ # Arnaud Joly # Denis Engemann # Giorgio Patrini +# Thierry Guillemot # License: BSD 3 clause import os import inspect @@ -93,8 +94,7 @@ def assert_not_in(x, container): # for Python 2 def assert_raises_regex(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs): - """Helper function to check for message patterns in exceptions""" - + """Helper function to check for message patterns in exceptions.""" not_raised = False try: callable_obj(*args, **kwargs) @@ -165,7 +165,6 @@ def assert_warns(warning_class, func, *args, **kw): result : the return value of `func` """ - # very important to avoid uncontrolled state propagation clean_warning_registry() with warnings.catch_warnings(record=True) as w: @@ -282,13 +281,18 @@ def assert_no_warnings(func, *args, **kw): return result -def ignore_warnings(obj=None): - """ Context manager and decorator to ignore warnings +def ignore_warnings(obj=None, category=Warning): + """Context manager and decorator to ignore warnings. Note. Using this (in both variants) will clear all warnings from all python modules loaded. In case you need to test cross-module-warning-logging this is not your tool of choice. + Parameters + ---------- + category : warning class, defaults to Warning. + The category to filter. If Warning, all categories will be muted. + Examples -------- >>> with ignore_warnings(): @@ -300,47 +304,44 @@ def ignore_warnings(obj=None): >>> ignore_warnings(nasty_warn)() 42 - """ if callable(obj): - return _ignore_warnings(obj) + return _IgnoreWarnings(category=category)(obj) else: - return _IgnoreWarnings() - - -def _ignore_warnings(fn): - """Decorator to catch and hide warnings without visual nesting""" - @wraps(fn) - def wrapper(*args, **kwargs): - # very important to avoid uncontrolled state propagation - clean_warning_registry() - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - return fn(*args, **kwargs) - w[:] = [] - - return wrapper + return _IgnoreWarnings(category=category) class _IgnoreWarnings(object): + """Improved and simplified Python warnings context manager and decorator. - """Improved and simplified Python warnings context manager - + This class allows to ignore the warnings raise by a function. Copied from Python 2.7.5 and modified as required. + + Parameters + ---------- + category : tuple of warning class, defaut to Warning + The category to filter. By default, all the categories will be muted. + """ - def __init__(self): - """ - Parameters - ========== - category : warning class - The category to filter. Defaults to Warning. If None, - all categories will be muted. - """ + def __init__(self, category): self._record = True self._module = sys.modules['warnings'] self._entered = False self.log = [] + self.category = category + + def __call__(self, fn): + """Decorator to catch and hide warnings without visual nesting.""" + @wraps(fn) + def wrapper(*args, **kwargs): + # very important to avoid uncontrolled state propagation + clean_warning_registry() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", self.category) + return fn(*args, **kwargs) + + return wrapper def __repr__(self): args = [] @@ -353,22 +354,13 @@ def __repr__(self): def __enter__(self): clean_warning_registry() # be safe and not propagate state + chaos - warnings.simplefilter('always') + warnings.simplefilter("ignore", self.category) if self._entered: raise RuntimeError("Cannot enter %r twice" % self) self._entered = True self._filters = self._module.filters self._module.filters = self._filters[:] self._showwarning = self._module.showwarning - if self._record: - self.log = [] - - def showwarning(*args, **kwargs): - self.log.append(warnings.WarningMessage(*args, **kwargs)) - self._module.showwarning = showwarning - return self.log - else: - return None def __exit__(self, *exc_info): if not self._entered: @@ -407,7 +399,7 @@ def _assert_allclose(actual, desired, rtol=1e-7, atol=0, def assert_raise_message(exceptions, message, function, *args, **kwargs): - """Helper function to test error messages in exceptions + """Helper function to test error messages in exceptions. Parameters ---------- @@ -616,8 +608,8 @@ def is_abstract(c): all_classes = set(all_classes) estimators = [c for c in all_classes - if (issubclass(c[1], BaseEstimator) - and c[0] != 'BaseEstimator')] + if (issubclass(c[1], BaseEstimator) and + c[0] != 'BaseEstimator')] # get rid of abstract base classes estimators = [c for c in estimators if not is_abstract(c[1])] @@ -647,7 +639,8 @@ def is_abstract(c): estimators = filtered_estimators if type_filter: raise ValueError("Parameter type_filter must be 'classifier', " - "'regressor', 'transformer', 'cluster' or None, got" + "'regressor', 'transformer', 'cluster' or " + "None, got" " %s." % repr(type_filter)) # drop duplicates, sort for reproducibility @@ -662,7 +655,6 @@ def set_random_state(estimator, random_state=0): Classes for whom random_state is deprecated are ignored. Currently DBSCAN is one such class. """ - if isinstance(estimator, DBSCAN): return @@ -671,8 +663,7 @@ def set_random_state(estimator, random_state=0): def if_matplotlib(func): - """Test decorator that skips test if matplotlib not installed. """ - + """Test decorator that skips test if matplotlib not installed.""" @wraps(func) def run_test(*args, **kwargs): try: @@ -723,7 +714,7 @@ def func(*args, **kwargs): def if_safe_multiprocessing_with_blas(func): - """Decorator for tests involving both BLAS calls and multiprocessing + """Decorator for tests involving both BLAS calls and multiprocessing. Under POSIX (e.g. Linux or OSX), using multiprocessing in conjunction with some implementation of BLAS (or other libraries that manage an internal @@ -740,7 +731,6 @@ def if_safe_multiprocessing_with_blas(func): for multiprocessing to avoid this issue. However it can cause pickling errors on interactively defined functions. It therefore not enabled by default. - """ @wraps(func) def run_test(*args, **kwargs): @@ -752,7 +742,7 @@ def run_test(*args, **kwargs): def clean_warning_registry(): - """Safe way to reset warnings """ + """Safe way to reset warnings.""" warnings.resetwarnings() reg = "__warningregistry__" for mod_name, mod in list(sys.modules.items()): @@ -775,7 +765,9 @@ def check_skip_travis(): def _delete_folder(folder_path, warn=False): """Utility function to cleanup a temporary folder if still existing. - Copy from joblib.pool (for independence)""" + + Copy from joblib.pool (for independence). + """ try: if os.path.exists(folder_path): # This can fail under windows, diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py index b1d1b29b81e20..ea76333a6eafc 100644 --- a/sklearn/utils/tests/test_testing.py +++ b/sklearn/utils/tests/test_testing.py @@ -13,7 +13,8 @@ assert_no_warnings, assert_equal, set_random_state, - assert_raise_message) + assert_raise_message, + ignore_warnings) from sklearn.tree import DecisionTreeClassifier from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -96,6 +97,98 @@ def _no_raise(): "test", _no_raise) +def test_ignore_warning(): + # This check that ignore_warning decorateur and context manager are working + # as expected + def _warning_function(): + warnings.warn("deprecation warning", DeprecationWarning) + + def _multiple_warning_function(): + warnings.warn("deprecation warning", DeprecationWarning) + warnings.warn("deprecation warning") + + # Check the function directly + assert_no_warnings(ignore_warnings(_warning_function)) + assert_no_warnings(ignore_warnings(_warning_function, + category=DeprecationWarning)) + assert_warns(DeprecationWarning, ignore_warnings(_warning_function, + category=UserWarning)) + assert_warns(UserWarning, + ignore_warnings(_multiple_warning_function, + category=DeprecationWarning)) + assert_warns(DeprecationWarning, + ignore_warnings(_multiple_warning_function, + category=UserWarning)) + assert_no_warnings(ignore_warnings(_warning_function, + category=(DeprecationWarning, + UserWarning))) + + # Check the decorator + @ignore_warnings + def decorator_no_warning(): + _warning_function() + _multiple_warning_function() + + @ignore_warnings(category=(DeprecationWarning, UserWarning)) + def decorator_no_warning_multiple(): + _multiple_warning_function() + + @ignore_warnings(category=DeprecationWarning) + def decorator_no_deprecation_warning(): + _warning_function() + + @ignore_warnings(category=UserWarning) + def decorator_no_user_warning(): + _warning_function() + + @ignore_warnings(category=DeprecationWarning) + def decorator_no_deprecation_multiple_warning(): + _multiple_warning_function() + + @ignore_warnings(category=UserWarning) + def decorator_no_user_multiple_warning(): + _multiple_warning_function() + + assert_no_warnings(decorator_no_warning) + assert_no_warnings(decorator_no_warning_multiple) + assert_no_warnings(decorator_no_deprecation_warning) + assert_warns(DeprecationWarning, decorator_no_user_warning) + assert_warns(UserWarning, decorator_no_deprecation_multiple_warning) + assert_warns(DeprecationWarning, decorator_no_user_multiple_warning) + + # Check the context manager + def context_manager_no_warning(): + with ignore_warnings(): + _warning_function() + + def context_manager_no_warning_multiple(): + with ignore_warnings(category=(DeprecationWarning, UserWarning)): + _multiple_warning_function() + + def context_manager_no_deprecation_warning(): + with ignore_warnings(category=DeprecationWarning): + _warning_function() + + def context_manager_no_user_warning(): + with ignore_warnings(category=UserWarning): + _warning_function() + + def context_manager_no_deprecation_multiple_warning(): + with ignore_warnings(category=DeprecationWarning): + _multiple_warning_function() + + def context_manager_no_user_multiple_warning(): + with ignore_warnings(category=UserWarning): + _multiple_warning_function() + + assert_no_warnings(context_manager_no_warning) + assert_no_warnings(context_manager_no_warning_multiple) + assert_no_warnings(context_manager_no_deprecation_warning) + assert_warns(DeprecationWarning, context_manager_no_user_warning) + assert_warns(UserWarning, context_manager_no_deprecation_multiple_warning) + assert_warns(DeprecationWarning, context_manager_no_user_multiple_warning) + + # This class is inspired from numpy 1.7 with an alteration to check # the reset warning filters after calls to assert_warns. # This assert_warns behavior is specific to scikit-learn because