MAINT Parameters validation for sklearn.datasets.load_files (#26203) · lesteve/scikit-learn@bc3a19d · GitHub
[go: up one dir, main page]

Skip to content

Commit bc3a19d

Browse files
MAINT Parameters validation for sklearn.datasets.load_files (scikit-learn#26203)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 18a4576 commit bc3a19d

File tree

3 files changed

+18
-3
lines changed
  • sklearn
    • datasets
      • _base.py
  • tests
  • tests
  • 3 files changed

    +18
    -3
    lines changed

    sklearn/datasets/_base.py

    Lines changed: 14 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -22,7 +22,7 @@
    2222
    from ..utils import check_random_state
    2323
    from ..utils import check_pandas_support
    2424
    from ..utils.fixes import _open_binary, _open_text, _read_text, _contents
    25-
    from ..utils._param_validation import validate_params, Interval
    25+
    from ..utils._param_validation import validate_params, Interval, StrOptions
    2626

    2727
    import numpy as np
    2828

    @@ -104,6 +104,19 @@ def _convert_data_dataframe(
    104104
    return combined_df, X, y
    105105

    106106

    107+
    @validate_params(
    108+
    {
    109+
    "container_path": [str, os.PathLike],
    110+
    "description": [str, None],
    111+
    "categories": [list, None],
    112+
    "load_content": ["boolean"],
    113+
    "shuffle": ["boolean"],
    114+
    "encoding": [str, None],
    115+
    "decode_error": [StrOptions({"strict", "ignore", "replace"})],
    116+
    "random_state": ["random_state"],
    117+
    "allowed_extensions": [list, None],
    118+
    }
    119+
    )
    107120
    def load_files(
    108121
    container_path,
    109122
    *,

    sklearn/datasets/tests/test_base.py

    Lines changed: 3 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -98,10 +98,11 @@ def test_default_load_files(test_category_dir_1, test_category_dir_2, load_files
    9898
    def test_load_files_w_categories_desc_and_encoding(
    9999
    test_category_dir_1, test_category_dir_2, load_files_root
    100100
    ):
    101-
    category = os.path.abspath(test_category_dir_1).split("/").pop()
    101+
    category = os.path.abspath(test_category_dir_1).split(os.sep).pop()
    102102
    res = load_files(
    103-
    load_files_root, description="test", categories=category, encoding="utf-8"
    103+
    load_files_root, description="test", categories=[category], encoding="utf-8"
    104104
    )
    105+
    105106
    assert len(res.filenames) == 1
    106107
    assert len(res.target_names) == 1
    107108
    assert res.DESCR == "test"

    sklearn/tests/test_public_functions.py

    Lines changed: 1 addition & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -134,6 +134,7 @@ def _check_function_param_validation(
    134134
    "sklearn.datasets.load_breast_cancer",
    135135
    "sklearn.datasets.load_diabetes",
    136136
    "sklearn.datasets.load_digits",
    137+
    "sklearn.datasets.load_files",
    137138
    "sklearn.datasets.load_iris",
    138139
    "sklearn.datasets.load_linnerud",
    139140
    "sklearn.datasets.load_svmlight_file",

    0 commit comments

    Comments
     (0)
    0