|
16 | 16 | from sklearn.utils.testing import assert_less
|
17 | 17 | from sklearn.utils.testing import assert_warns
|
18 | 18 | from sklearn.utils.testing import if_safe_multiprocessing_with_blas
|
19 |
| -from sklearn.utils.testing import if_not_mac_os |
20 | 19 | from sklearn.utils.testing import assert_raise_message
|
21 | 20 |
|
22 | 21 |
|
@@ -272,14 +271,18 @@ def test_k_means_explicit_init_shape():
|
272 | 271 | msg = "does not match the number of features of the data"
|
273 | 272 | assert_raises_regex(ValueError, msg, km.fit, X)
|
274 | 273 | # for callable init
|
275 |
| - km = Class(n_init=1, init=lambda X_, k, random_state: X_[:, :2], n_clusters=len(X)) |
| 274 | + km = Class(n_init=1, |
| 275 | + init=lambda X_, k, random_state: X_[:, :2], |
| 276 | + n_clusters=len(X)) |
276 | 277 | assert_raises_regex(ValueError, msg, km.fit, X)
|
277 | 278 | # mismatch of number of clusters
|
278 | 279 | msg = "does not match the number of clusters"
|
279 | 280 | km = Class(n_init=1, init=X[:2, :], n_clusters=3)
|
280 | 281 | assert_raises_regex(ValueError, msg, km.fit, X)
|
281 | 282 | # for callable init
|
282 |
| - km = Class(n_init=1, init=lambda X_, k, random_state: X_[:2, :], n_clusters=3) |
| 283 | + km = Class(n_init=1, |
| 284 | + init=lambda X_, k, random_state: X_[:2, :], |
| 285 | + n_clusters=3) |
283 | 286 | assert_raises_regex(ValueError, msg, km.fit, X)
|
284 | 287 |
|
285 | 288 |
|
@@ -730,4 +733,122 @@ def test_x_squared_norms_init_centroids():
|
def test_max_iter_error():
    # A negative ``max_iter`` has to be rejected by ``fit`` with a ValueError
    # whose message mentions the number of iterations.
    expected_msg = 'Number of iterations should be'
    estimator = KMeans(max_iter=-1)
    assert_raise_message(ValueError, expected_msg, estimator.fit, X)
| 738 | + |
| 739 | + |
def test_kmeans_float32_64():
    # Check that KMeans keeps the expected dtype for its cluster centers and
    # gives consistent results when fitted on float32 and on float64 data,
    # for both dense and sparse input.
    km = KMeans(n_init=1, random_state=11)

    def fit_and_summarize(data, sample, expected_dtype):
        # Fit ``km`` on ``data`` and collect everything needed for the later
        # comparisons right away: ``km`` is refitted several times in this
        # test, so its attributes only ever describe the most recent fit.
        km.fit(data)
        assert_equal(km.cluster_centers_.dtype, expected_dtype)
        # ``sample`` is a one-row 2d slice, so ``predict`` gets 2d input and
        # returns a length-1 array from which the scalar label is extracted.
        return (km.inertia_,
                km.transform(km.cluster_centers_),
                km.predict(sample)[0],
                km.labels_[0])

    # dense data: dtype of cluster centers has to be the dtype of the
    # input data
    inertia64, X_new64, pred64, label64 = fit_and_summarize(
        X, X[:1], np.float64)
    inertia32, X_new32, pred32, label32 = fit_and_summarize(
        np.float32(X), X[:1], np.float32)

    # compare arrays with low precision since the difference between
    # 32 and 64 bit sometimes makes a difference up to the 4th decimal place
    assert_array_almost_equal(inertia32, inertia64, decimal=4)
    assert_array_almost_equal(X_new32, X_new64, decimal=4)
    # both predictions have to be the same and each has to correspond to the
    # label assigned to the first sample by its own fit
    assert_equal(pred32, pred64)
    assert_equal(pred32, label32)
    assert_equal(pred64, label64)

    # sparse data
    # Note: at the moment sparse data is always processed as float64
    # internally, so the cluster centers are float64 for both input dtypes
    inertia64, X_new64, pred64, label64 = fit_and_summarize(
        X_csr, X_csr[:1], np.float64)
    inertia32, X_new32, pred32, label32 = fit_and_summarize(
        sp.csr_matrix(X_csr, dtype=np.float32), X_csr[:1], np.float64)

    assert_array_almost_equal(inertia32, inertia64)
    assert_array_almost_equal(X_new32, X_new64)
    # both predictions have to be the same and each has to correspond to the
    # label assigned to the first sample by its own fit
    assert_equal(pred32, pred64)
    assert_equal(pred32, label32)
    assert_equal(pred64, label64)
| 790 | + |
| 791 | + |
def test_mb_k_means_float32_64():
    # Check that MiniBatchKMeans keeps the expected dtype for its cluster
    # centers (after ``fit`` and after a following ``partial_fit``) and gives
    # consistent results when fitted on float32 and on float64 data, for both
    # dense and sparse input.
    km = MiniBatchKMeans(n_init=1, random_state=30)

    def fit_and_summarize(data, sample, batch, expected_dtype):
        # Fit ``km`` on ``data`` and collect everything needed for the later
        # comparisons before ``partial_fit`` or a later refit changes the
        # estimator's attributes.
        km.fit(data)
        assert_equal(km.cluster_centers_.dtype, expected_dtype)
        # ``sample`` is a one-row 2d slice, so ``predict`` gets 2d input and
        # returns a length-1 array from which the scalar label is extracted.
        summary = (km.inertia_,
                   km.transform(km.cluster_centers_),
                   km.predict(sample)[0],
                   km.labels_[0])
        km.partial_fit(batch)
        # dtype of cluster centers has to stay the same after partial_fit
        assert_equal(km.cluster_centers_.dtype, expected_dtype)
        return summary

    # dense data: dtype of cluster centers has to be the dtype of the
    # input data
    inertia64, X_new64, pred64, label64 = fit_and_summarize(
        X, X[:1], X[0:3], np.float64)
    inertia32, X_new32, pred32, label32 = fit_and_summarize(
        np.float32(X), X[:1], X[0:3], np.float32)

    # compare arrays with low precision since the difference between
    # 32 and 64 bit sometimes makes a difference up to the 4th decimal place
    assert_array_almost_equal(inertia32, inertia64, decimal=4)
    assert_array_almost_equal(X_new32, X_new64, decimal=4)
    # both predictions have to be the same and each has to correspond to the
    # label assigned to the first sample by its own fit
    assert_equal(pred32, pred64)
    assert_equal(pred32, label32)
    assert_equal(pred64, label64)

    # sparse data
    # Note: at the moment sparse data is always processed as float64
    # internally, so the cluster centers are float64 for both input dtypes
    inertia64, X_new64, pred64, label64 = fit_and_summarize(
        X_csr, X_csr[:1], X_csr[0:3], np.float64)
    inertia32, X_new32, pred32, label32 = fit_and_summarize(
        sp.csr_matrix(X_csr, dtype=np.float32), X_csr[:1], X_csr[0:3],
        np.float64)

    assert_array_almost_equal(inertia32, inertia64)
    assert_array_almost_equal(X_new32, X_new64)
    # both predictions have to be the same and each has to correspond to the
    # label assigned to the first sample by its own fit
    assert_equal(pred32, pred64)
    assert_equal(pred32, label32)
    assert_equal(pred64, label64)
0 commit comments