diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py index f0e070d485e81..2da32c9b0d907 100644 --- a/sklearn/decomposition/_dict_learning.py +++ b/sklearn/decomposition/_dict_learning.py @@ -893,7 +893,8 @@ def dict_learning_online( dictionary = dictionary[:n_components, :] else: dictionary = np.r_[ - dictionary, np.zeros((n_components - r, dictionary.shape[1])) + dictionary, + np.zeros((n_components - r, dictionary.shape[1]), dtype=dictionary.dtype), ] if verbose == 1: @@ -905,25 +906,23 @@ def dict_learning_online( else: X_train = X - # Fortran-order dict better suited for the sparse coding which is the - # bottleneck of this algorithm. - dictionary = check_array( - dictionary, order="F", dtype=[np.float64, np.float32], copy=False - ) - dictionary = np.require(dictionary, requirements="W") - X_train = check_array( X_train, order="C", dtype=[np.float64, np.float32], copy=False ) + # Fortran-order dict better suited for the sparse coding which is the + # bottleneck of this algorithm. + dictionary = check_array(dictionary, order="F", dtype=X_train.dtype, copy=False) + dictionary = np.require(dictionary, requirements="W") + batches = gen_batches(n_samples, batch_size) batches = itertools.cycle(batches) # The covariance of the dictionary if inner_stats is None: - A = np.zeros((n_components, n_components)) + A = np.zeros((n_components, n_components), dtype=X_train.dtype) # The data approximation - B = np.zeros((n_features, n_components)) + B = np.zeros((n_features, n_components), dtype=X_train.dtype) else: A = inner_stats[0].copy() B = inner_stats[1].copy() diff --git a/sklearn/decomposition/tests/test_dict_learning.py b/sklearn/decomposition/tests/test_dict_learning.py index b7ee66d0e78cb..456379eee6d08 100644 --- a/sklearn/decomposition/tests/test_dict_learning.py +++ b/sklearn/decomposition/tests/test_dict_learning.py @@ -740,6 +740,10 @@ def test_dictionary_learning_dtype_match( assert dict_learner.components_.dtype == expected_type assert dict_learner.transform(X.astype(data_type)).dtype == expected_type + if dictionary_learning_transformer is MiniBatchDictionaryLearning: + assert dict_learner.inner_stats_[0].dtype == expected_type + assert dict_learner.inner_stats_[1].dtype == expected_type + @pytest.mark.parametrize("method", ("lars", "cd")) @pytest.mark.parametrize(