8000 MNT Replace kwargs by named args for resample by alfaro96 · Pull Request #17324 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MNT Replace kwargs by named args for resample #17324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions sklearn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,43 +413,48 @@ def _get_column_indices(X, key):
"strings, or boolean mask is allowed")


def resample(*arrays, **options):
"""Resample arrays or sparse matrices in a consistent way
def resample(*arrays,
replace=True,
n_samples=None,
random_state=None,
stratify=None):
"""Resample arrays or sparse matrices in a consistent way.

The default strategy implements one step of the bootstrapping
procedure.

Parameters
----------
*arrays : sequence of indexable data-structures
*arrays : sequence of array-like of shape (n_samples,) or \
(n_samples, n_outputs)
Indexable data-structures can be arrays, lists, dataframes or scipy
sparse matrices with consistent first dimension.

Other Parameters
----------------
replace : boolean, True by default
replace : bool, default=True
Implements resampling with replacement. If False, this will implement
(sliced) random permutations.

n_samples : int, None by default
n_samples : int, default=None
Number of samples to generate. If left to None this is
automatically set to the first dimension of the arrays.
If replace is False it should not be larger than the length of
arrays.

random_state : int, RandomState instance or None, optional (default=None)
random_state : int or RandomState instance, default=None
Determines random number generation for shuffling
the data.
Pass an int for reproducible results across multiple function calls.
See :term:`Glossary <random_state>`.

stratify : array-like or None (default=None)
stratify : array-like of shape (n_samples,) or (n_samples, n_outputs), \
default=None
If not None, data is split in a stratified fashion, using this as
the class labels.

Returns
-------
resampled_arrays : sequence of indexable data-structures
resampled_arrays : sequence of array-like of shape (n_samples,) or \
(n_samples, n_outputs)
Sequence of resampled copies of the collections. The original arrays
are not impacted.

Expand Down Expand Up @@ -492,18 +497,12 @@ def resample(*arrays, **options):
... random_state=0)
[1, 1, 1, 0, 1]


See also
--------
:func:`sklearn.utils.shuffle`
"""

random_state = check_random_state(options.pop('random_state', None))
replace = options.pop('replace', True)
max_n_samples = options.pop('n_samples', None)
stratify = options.pop('stratify', None)
if options:
raise ValueError("Unexpected kw arguments: %r" % options.keys())
max_n_samples = n_samples
random_state = check_random_state(random_state)

if len(arrays) == 0:
return None
Expand Down Expand Up @@ -556,7 +555,6 @@ def resample(*arrays, **options):

indices = random_state.permutation(indices)


# convert sparse matrices to CSR for row-based indexing
arrays = [a.tocsr() if issparse(a) else a for a in arrays]
resampled_arrays = [_safe_indexing(a, indices) for a in arrays]
Expand Down
2 changes: 0 additions & 2 deletions sklearn/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def test_resample():
with pytest.raises(ValueError):
resample([0, 1], [0, 1], replace=False, n_samples=3)

with pytest.raises(ValueError):
4952 resample([0, 1], [0, 1], meaning_of_life=42)
# Issue:6581, n_samples can be more when replace is True (default).
assert len(resample([1, 2], n_samples=5)) == 5

Expand Down
0