8000 MNT parameter validation for covariance.empirical_covariance (#25146) · npache/scikit-learn@17b8278 · GitHub
[go: up one dir, main page]

Skip to content

Commit 17b8278

Browse files
authored
MNT parameter validation for covariance.empirical_covariance (scikit-learn#25146)
1 parent 0266481 commit 17b8278

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

sklearn/covariance/_empirical_covariance.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .. import config_context
1818
from ..base import BaseEstimator
1919
from ..utils import check_array
20+
from ..utils._param_validation import validate_params
2021
from ..utils.extmath import fast_logdet
2122
from ..metrics.pairwise import pairwise_distances
2223

@@ -48,6 +49,12 @@ def log_likelihood(emp_cov, precision):
4849
return log_likelihood_
4950

5051

52+
@validate_params(
53+
{
54+
"X": ["array-like"],
55+
"assume_centered": ["boolean"],
56+
}
57+
)
5158
def empirical_covariance(X, *, assume_centered=False):
5259
"""Compute the Maximum likelihood covariance estimator.
5360

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def _check_function_param_validation(
9696
PARAM_VALIDATION_FUNCTION_LIST = [
9797
"sklearn.cluster.estimate_bandwidth",
9898
"sklearn.cluster.kmeans_plusplus",
99+
"sklearn.covariance.empirical_covariance",
99100
"sklearn.feature_extraction.grid_to_graph",
100101
"sklearn.feature_extraction.img_to_graph",
101102
"sklearn.metrics.accuracy_score",

0 commit comments

Comments
 (0)
0