8000 MAINT Added automatic validation function for sklearn.neighbors.radiu… · scikit-learn/scikit-learn@1924ffb · GitHub
[go: up one dir, main page]

Skip to content

Commit 1924ffb

Browse files
sqalijeremiedbb
andauthored
MAINT Added automatic validation function for sklearn.neighbors.radius_neighbors_graph (#27245)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent a05eb6b commit 1924ffb

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

sklearn/neighbors/_graph.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def kneighbors_graph(
7474
Parameters
7575
----------
7676
X : array-like of shape (n_samples, n_features)
77-
Sample data, in the form of a numpy array.
77+
Sample data.
7878
7979
n_neighbors : int
8080
Number of neighbors for each sample.
@@ -148,6 +148,19 @@ def kneighbors_graph(
148148
return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode)
149149

150150

151+
@validate_params(
152+
{
153+
"X": ["array-like", RadiusNeighborsMixin],
154+
"radius": [Interval(Real, 0, None, closed="both")],
155+
"mode": [StrOptions({"connectivity", "distance"})],
156+
"metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable],
157+
"p": [Interval(Real, 0, None, closed="right"), None],
158+
"metric_params": [dict, None],
159+
"include_self": ["boolean", StrOptions({"auto"})],
160+
"n_jobs": [Integral, None],
161+
},
162+
prefer_skip_nested_validation=False, # metric is not validated yet
163+
)
151164
def radius_neighbors_graph(
152165
X,
153166
radius,
@@ -168,9 +181,8 @@ def radius_neighbors_graph(
168181
169182
Parameters
170183
----------
171-
X : array-like of shape (n_samples, n_features) or BallTree
172-
Sample data, in the form of a numpy array or a precomputed
173-
:class:`BallTree`.
184+
X : array-like of shape (n_samples, n_features)
185+
Sample data.
174186
175187
radius : float
176188
Radius of neighborhoods.

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def _check_function_param_validation(
306306
"sklearn.model_selection.train_test_split",
307307
"sklearn.model_selection.validation_curve",
308308
"sklearn.neighbors.kneighbors_graph",
309+
"sklearn.neighbors.radius_neighbors_graph",
309310
"sklearn.neighbors.sort_graph_by_row_values",
310311
"sklearn.preprocessing.add_dummy_feature",
311312
"sklearn.preprocessing.binarize",

0 commit comments

Comments
 (0)
0