8000 MAINT parameter validation for sklearn.datasets.dump_svmlight_file (#… · sortofamudkip/scikit-learn@00f49eb · GitHub
[go: up one dir, main page]

Skip to content

Commit 00f49eb

Browse files
MAINT parameter validation for sklearn.datasets.dump_svmlight_file (scikit-learn#25726)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent fa0866a commit 00f49eb

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

sklearn/datasets/_svmlight_format_io.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .. import __version__
2626

2727
from ..utils import check_array, IS_PYPY
28+
from ..utils._param_validation import validate_params, HasMethods
2829

2930
if not IS_PYPY:
3031
from ._svmlight_format_fast import (
@@ -404,6 +405,17 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
404405
)
405406

406407

408+
@vali 10000 date_params(
409+
{
410+
"X": ["array-like", "sparse matrix"],
411+
"y": ["array-like", "sparse matrix"],
412+
"f": [str, HasMethods(["write"])],
413+
"zero_based": ["boolean"],
414+
"comment": [str, bytes, None],
415+
"query_id": ["array-like", None],
416+
"multilabel": ["boolean"],
417+
}
418+
)
407419
def dump_svmlight_file(
408420
X,
409421
y,
@@ -428,7 +440,7 @@ def dump_svmlight_file(
428440
Training vectors, where `n_samples` is the number of samples and
429441
`n_features` is the number of features.
430442
431-
y : {array-like, sparse matrix}, shape = [n_samples (, n_labels)]
443+
y : {array-like, sparse matrix}, shape = (n_samples,) or (n_samples, n_labels)
432444
Target values. Class labels must be an
433445
integer or float, or array-like objects of integer or float for
434446
multilabel classifications.
@@ -442,7 +454,7 @@ def dump_svmlight_file(
442454
Whether column indices should be written zero-based (True) or one-based
443455
(False).
444456
445-
comment : str, default=None
457+
comment : str or bytes, default=None
446458
Comment to insert at the top of the file. This should be either a
447459
Unicode string, which will be encoded as UTF-8, or an ASCII byte
448460
string.
@@ -459,7 +471,7 @@ def dump_svmlight_file(
459471
https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html).
460472
461473
.. versionadded:: 0.17
462-
parameter *multilabel* to support multilabel datasets.
474+
parameter `multilabel` to support multilabel datasets.
463475
"""
464476
if comment is not None:
465477
# Convert comment string to list of lines in UTF-8.

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def _check_function_param_validation(
102102
"sklearn.cluster.ward_tree",
103103
"sklearn.covariance.empirical_covariance",
104104
"sklearn.covariance.shrunk_covariance",
105+
"sklearn.datasets.dump_svmlight_file",
105106
"sklearn.datasets.fetch_california_housing",
106107
"sklearn.datasets.fetch_kddcup99",
107108
"sklearn.datasets.make_classification",

0 commit comments

Comments
 (0)
0