8000 TST Adds numpydoc tests for functions (#21245) · samronsin/scikit-learn@be40443 · GitHub
[go: up one dir, main page]

Skip to content

Commit be40443

Browse files
thomasjpfan and samronsin
authored and committed
TST Adds numpydoc tests for functions (scikit-learn#21245)
1 parent efb5f6c commit be40443

File tree

1 file changed

+298
-0
lines changed

1 file changed

+298
-0
lines changed

maint_tools/test_docstrings.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import re
22
from inspect import signature
3+
import pkgutil
4+
import inspect
5+
import importlib
36
from typing import Optional
47

58
import pytest
69
from sklearn.utils import all_estimators
10+
import sklearn
711

812
numpydoc_validation = pytest.importorskip("numpydoc.validate")
913

@@ -30,6 +34,237 @@
3034
"TransformedTargetRegressor",
3135
]
3236

37+
# Functions known to fail numpydoc validation.  Remove entries as their
# docstrings are fixed so the test starts enforcing them.  Defined directly
# as a set (rather than a list converted afterwards) for O(1) membership
# tests in test_function_docstring.
FUNCTION_DOCSTRING_IGNORE_LIST = {
    "sklearn._config.config_context",
    "sklearn._config.get_config",
    "sklearn.base.clone",
    "sklearn.cluster._affinity_propagation.affinity_propagation",
    "sklearn.cluster._agglomerative.linkage_tree",
    "sklearn.cluster._kmeans.k_means",
    "sklearn.cluster._kmeans.kmeans_plusplus",
    "sklearn.cluster._mean_shift.estimate_bandwidth",
    "sklearn.cluster._mean_shift.get_bin_seeds",
    "sklearn.cluster._mean_shift.mean_shift",
    "sklearn.cluster._optics.cluster_optics_dbscan",
    "sklearn.cluster._optics.cluster_optics_xi",
    "sklearn.cluster._optics.compute_optics_graph",
    "sklearn.cluster._spectral.spectral_clustering",
    "sklearn.compose._column_transformer.make_column_transformer",
    "sklearn.covariance._empirical_covariance.empirical_covariance",
    "sklearn.covariance._empirical_covariance.log_likelihood",
    "sklearn.covariance._graph_lasso.graphical_lasso",
    "sklearn.covariance._robust_covariance.fast_mcd",
    "sklearn.covariance._shrunk_covariance.ledoit_wolf",
    "sklearn.covariance._shrunk_covariance.ledoit_wolf_shrinkage",
    "sklearn.covariance._shrunk_covariance.shrunk_covariance",
    "sklearn.datasets._base.get_data_home",
    "sklearn.datasets._base.load_boston",
    "sklearn.datasets._base.load_breast_cancer",
    "sklearn.datasets._base.load_diabetes",
    "sklearn.datasets._base.load_digits",
    "sklearn.datasets._base.load_files",
    "sklearn.datasets._base.load_iris",
    "sklearn.datasets._base.load_linnerud",
    "sklearn.datasets._base.load_sample_image",
    "sklearn.datasets._base.load_wine",
    "sklearn.datasets._california_housing.fetch_california_housing",
    "sklearn.datasets._covtype.fetch_covtype",
    "sklearn.datasets._kddcup99.fetch_kddcup99",
    "sklearn.datasets._lfw.fetch_lfw_pairs",
    "sklearn.datasets._lfw.fetch_lfw_people",
    "sklearn.datasets._olivetti_faces.fetch_olivetti_faces",
    "sklearn.datasets._openml.fetch_openml",
    "sklearn.datasets._rcv1.fetch_rcv1",
    "sklearn.datasets._samples_generator.make_biclusters",
    "sklearn.datasets._samples_generator.make_blobs",
    "sklearn.datasets._samples_generator.make_checkerboard",
    "sklearn.datasets._samples_generator.make_classification",
    "sklearn.datasets._samples_generator.make_gaussian_quantiles",
    "sklearn.datasets._samples_generator.make_hastie_10_2",
    "sklearn.datasets._samples_generator.make_multilabel_classification",
    "sklearn.datasets._samples_generator.make_regression",
    "sklearn.datasets._samples_generator.make_sparse_coded_signal",
    "sklearn.datasets._samples_generator.make_sparse_spd_matrix",
    "sklearn.datasets._samples_generator.make_spd_matrix",
    "sklearn.datasets._species_distributions.fetch_species_distributions",
    "sklearn.datasets._svmlight_format_io.dump_svmlight_file",
    "sklearn.datasets._svmlight_format_io.load_svmlight_file",
    "sklearn.datasets._svmlight_format_io.load_svmlight_files",
    "sklearn.datasets._twenty_newsgroups.fetch_20newsgroups",
    "sklearn.decomposition._dict_learning.dict_learning",
    "sklearn.decomposition._dict_learning.dict_learning_online",
    "sklearn.decomposition._dict_learning.sparse_encode",
    "sklearn.decomposition._fastica.fastica",
    "sklearn.decomposition._nmf.non_negative_factorization",
    "sklearn.externals._packaging.version.parse",
    "sklearn.feature_extraction.image.extract_patches_2d",
    "sklearn.feature_extraction.image.grid_to_graph",
    "sklearn.feature_extraction.image.img_to_graph",
    "sklearn.feature_extraction.text.strip_accents_ascii",
    "sklearn.feature_extraction.text.strip_accents_unicode",
    "sklearn.feature_extraction.text.strip_tags",
    "sklearn.feature_selection._univariate_selection.chi2",
    "sklearn.feature_selection._univariate_selection.f_oneway",
    "sklearn.feature_selection._univariate_selection.r_regression",
    "sklearn.inspection._partial_dependence.partial_dependence",
    "sklearn.inspection._plot.partial_dependence.plot_partial_dependence",
    "sklearn.isotonic.isotonic_regression",
    "sklearn.linear_model._least_angle.lars_path",
    "sklearn.linear_model._least_angle.lars_path_gram",
    "sklearn.linear_model._omp.orthogonal_mp",
    "sklearn.linear_model._omp.orthogonal_mp_gram",
    "sklearn.linear_model._ridge.ridge_regression",
    "sklearn.manifold._locally_linear.locally_linear_embedding",
    "sklearn.manifold._t_sne.trustworthiness",
    "sklearn.metrics._classification.accuracy_score",
    "sklearn.metrics._classification.balanced_accuracy_score",
    "sklearn.metrics._classification.brier_score_loss",
    "sklearn.metrics._classification.classification_report",
    "sklearn.metrics._classification.cohen_kappa_score",
    "sklearn.metrics._classification.confusion_matrix",
    "sklearn.metrics._classification.f1_score",
    "sklearn.metrics._classification.fbeta_score",
    "sklearn.metrics._classification.hamming_loss",
    "sklearn.metrics._classification.hinge_loss",
    "sklearn.metrics._classification.jaccard_score",
    "sklearn.metrics._classification.log_loss",
    "sklearn.metrics._classification.precision_recall_fscore_support",
    "sklearn.metrics._classification.precision_score",
    "sklearn.metrics._classification.recall_score",
    "sklearn.metrics._classification.zero_one_loss",
    "sklearn.metrics._plot.confusion_matrix.plot_confusion_matrix",
    "sklearn.metrics._plot.det_curve.plot_det_curve",
    "sklearn.metrics._plot.precision_recall_curve.plot_precision_recall_curve",
    "sklearn.metrics._plot.roc_curve.plot_roc_curve",
    "sklearn.metrics._ranking.auc",
    "sklearn.metrics._ranking.average_precision_score",
    "sklearn.metrics._ranking.coverage_error",
    "sklearn.metrics._ranking.dcg_score",
    "sklearn.metrics._ranking.label_ranking_average_precision_score",
    "sklearn.metrics._ranking.label_ranking_loss",
    "sklearn.metrics._ranking.ndcg_score",
    "sklearn.metrics._ranking.precision_recall_curve",
    "sklearn.metrics._ranking.roc_auc_score",
    "sklearn.metrics._ranking.roc_curve",
    "sklearn.metrics._ranking.top_k_accuracy_score",
    "sklearn.metrics._regression.max_error",
    "sklearn.metrics._regression.mean_absolute_error",
    "sklearn.metrics._regression.mean_pinball_loss",
    "sklearn.metrics._scorer.make_scorer",
    "sklearn.metrics.cluster._bicluster.consensus_score",
    "sklearn.metrics.cluster._supervised.adjusted_mutual_info_score",
    "sklearn.metrics.cluster._supervised.adjusted_rand_score",
    "sklearn.metrics.cluster._supervised.completeness_score",
    "sklearn.metrics.cluster._supervised.entropy",
    "sklearn.metrics.cluster._supervised.fowlkes_mallows_score",
    "sklearn.metrics.cluster._supervised.homogeneity_completeness_v_measure",
    "sklearn.metrics.cluster._supervised.homogeneity_score",
    "sklearn.metrics.cluster._supervised.mutual_info_score",
    "sklearn.metrics.cluster._supervised.normalized_mutual_info_score",
    "sklearn.metrics.cluster._supervised.pair_confusion_matrix",
    "sklearn.metrics.cluster._supervised.rand_score",
    "sklearn.metrics.cluster._supervised.v_measure_score",
    "sklearn.metrics.cluster._unsupervised.davies_bouldin_score",
    "sklearn.metrics.cluster._unsupervised.silhouette_samples",
    "sklearn.metrics.cluster._unsupervised.silhouette_score",
    "sklearn.metrics.pairwise.additive_chi2_kernel",
    "sklearn.metrics.pairwise.check_paired_arrays",
    "sklearn.metrics.pairwise.check_pairwise_arrays",
    "sklearn.metrics.pairwise.chi2_kernel",
    "sklearn.metrics.pairwise.cosine_distances",
    "sklearn.metrics.pairwise.cosine_similarity",
    "sklearn.metrics.pairwise.distance_metrics",
    "sklearn.metrics.pairwise.euclidean_distances",
    "sklearn.metrics.pairwise.haversine_distances",
    "sklearn.metrics.pairwise.kernel_metrics",
    "sklearn.metrics.pairwise.laplacian_kernel",
    "sklearn.metrics.pairwise.linear_kernel",
    "sklearn.metrics.pairwise.manhattan_distances",
    "sklearn.metrics.pairwise.nan_euclidean_distances",
    "sklearn.metrics.pairwise.paired_cosine_distances",
    "sklearn.metrics.pairwise.paired_distances",
    "sklearn.metrics.pairwise.paired_euclidean_distances",
    "sklearn.metrics.pairwise.paired_manhattan_distances",
    "sklearn.metrics.pairwise.pairwise_distances_argmin",
    "sklearn.metrics.pairwise.pairwise_distances_argmin_min",
    "sklearn.metrics.pairwise.pairwise_distances_chunked",
    "sklearn.metrics.pairwise.pairwise_kernels",
    "sklearn.metrics.pairwise.polynomial_kernel",
    "sklearn.metrics.pairwise.rbf_kernel",
    "sklearn.metrics.pairwise.sigmoid_kernel",
    "sklearn.model_selection._split.check_cv",
    "sklearn.model_selection._split.train_test_split",
    "sklearn.model_selection._validation.cross_val_predict",
    "sklearn.model_selection._validation.cross_val_score",
    "sklearn.model_selection._validation.cross_validate",
    "sklearn.model_selection._validation.learning_curve",
    "sklearn.model_selection._validation.permutation_test_score",
    "sklearn.model_selection._validation.validation_curve",
    "sklearn.neighbors._graph.kneighbors_graph",
    "sklearn.neighbors._graph.radius_neighbors_graph",
    "sklearn.pipeline.make_union",
    "sklearn.preprocessing._data.binarize",
    "sklearn.preprocessing._data.maxabs_scale",
    "sklearn.preprocessing._data.normalize",
    "sklearn.preprocessing._data.power_transform",
    "sklearn.preprocessing._data.quantile_transform",
    "sklearn.preprocessing._data.robust_scale",
    "sklearn.preprocessing._data.scale",
    "sklearn.preprocessing._label.label_binarize",
    "sklearn.random_projection.johnson_lindenstrauss_min_dim",
    "sklearn.svm._bounds.l1_min_c",
    "sklearn.tree._export.plot_tree",
    "sklearn.utils.axis0_safe_slice",
    "sklearn.utils.check_pandas_support",
    "sklearn.utils.extmath.cartesian",
    "sklearn.utils.extmath.density",
    "sklearn.utils.extmath.fast_logdet",
    "sklearn.utils.extmath.randomized_range_finder",
    "sklearn.utils.extmath.randomized_svd",
    "sklearn.utils.extmath.safe_sparse_dot",
    "sklearn.utils.extmath.squared_norm",
    "sklearn.utils.extmath.stable_cumsum",
    "sklearn.utils.extmath.svd_flip",
    "sklearn.utils.extmath.weighted_mode",
    "sklearn.utils.fixes.delayed",
    "sklearn.utils.fixes.linspace",
    "sklearn.utils.gen_batches",
    "sklearn.utils.gen_even_slices",
    "sklearn.utils.get_chunk_n_rows",
    "sklearn.utils.graph.graph_shortest_path",
    "sklearn.utils.graph.single_source_shortest_path_length",
    "sklearn.utils.is_scalar_nan",
    "sklearn.utils.metaestimators.available_if",
    "sklearn.utils.metaestimators.if_delegate_has_method",
    "sklearn.utils.multiclass.check_classification_targets",
    "sklearn.utils.multiclass.class_distribution",
    "sklearn.utils.multiclass.type_of_target",
    "sklearn.utils.multiclass.unique_labels",
    "sklearn.utils.resample",
    "sklearn.utils.safe_mask",
    "sklearn.utils.safe_sqr",
    "sklearn.utils.shuffle",
    "sklearn.utils.sparsefuncs.count_nonzero",
    "sklearn.utils.sparsefuncs.csc_median_axis_0",
    "sklearn.utils.sparsefuncs.incr_mean_variance_axis",
    "sklearn.utils.sparsefuncs.inplace_swap_column",
    "sklearn.utils.sparsefuncs.inplace_swap_row",
    "sklearn.utils.sparsefuncs.inplace_swap_row_csc",
    "sklearn.utils.sparsefuncs.inplace_swap_row_csr",
    "sklearn.utils.sparsefuncs.mean_variance_axis",
    "sklearn.utils.sparsefuncs.min_max_axis",
    "sklearn.utils.tosequence",
    "sklearn.utils.validation.as_float_array",
    "sklearn.utils.validation.assert_all_finite",
    "sklearn.utils.validation.check_is_fitted",
    "sklearn.utils.validation.check_memory",
    "sklearn.utils.validation.check_random_state",
    "sklearn.utils.validation.column_or_1d",
    "sklearn.utils.validation.has_fit_parameter",
    "sklearn.utils.validation.indexable",
}
267+
33268

34269
def get_all_methods():
35270
estimators = all_estimators()
@@ -50,6 +285,51 @@ def get_all_methods():
50285
yield Estimator, method
51286

52287

288+
def _is_checked_function(item):
289+
if not inspect.isfunction(item):
290+
return False
291+
292+
if item.__name__.startswith("_"):
293+
return False
294+
295+
mod = item.__module__
296+
if not mod.startswith("sklearn.") or mod.endswith("estimator_checks"):
297+
return False
298+
299+
return True
300+
301+
302+
def get_all_functions_names():
    """Get all public functions defined in the sklearn module."""
    # Sub-packages/modules that should never be walked for public functions.
    modules_to_ignore = {
        "tests",
        "externals",
        "setup",
        "conftest",
        "experimental",
        "estimator_checks",
    }

    collected = set()
    for _, module_name, _ in pkgutil.walk_packages(
        path=sklearn.__path__, prefix="sklearn."
    ):
        parts = module_name.split(".")
        skip = any(part in modules_to_ignore for part in parts)
        # "._" in the dotted path means a private module anywhere in the tree.
        if skip or "._" in module_name:
            continue

        module = importlib.import_module(module_name)
        for _, func in inspect.getmembers(module, _is_checked_function):
            collected.add(f"{func.__module__}.{func.__name__}")

    # Sorted for a stable pytest parametrization order.
    return sorted(collected)
331+
332+
53333
def filter_errors(errors, method, Estimator=None):
54334
"""
55335
Ignore some errors based on the method type.
@@ -144,6 +424,24 @@ def repr_errors(res, estimator=None, method: Optional[str] = None) -> str:
144424
return msg
145425

146426

427+
@pytest.mark.parametrize("function_name", get_all_functions_names())
def test_function_docstring(function_name, request):
    """Validate a public function's docstring with numpydoc."""
    if function_name in FUNCTION_DOCSTRING_IGNORE_LIST:
        # Known failure: mark as xfail without running, so the test suite
        # stays green until the docstring is fixed and the entry removed.
        request.applymarker(
            pytest.mark.xfail(run=False, reason="TODO pass numpydoc validation")
        )

    result = numpydoc_validation.validate(function_name)
    remaining = list(filter_errors(result["errors"], method="function"))
    result["errors"] = remaining

    if remaining:
        raise ValueError(
            repr_errors(result, method=f"Tested function: {function_name}")
        )
443+
444+
147445
@pytest.mark.parametrize("Estimator, method", get_all_methods())
148446
def test_docstring(Estimator, method, request):
149447
base_import_path = Estimator.__module__

0 commit comments

Comments
 (0)
0