8000 DOC Improve `_check_sample_weight` docstring (#30908) · scikit-learn/scikit-learn@9ce8be6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9ce8be6

Browse files
authored
DOC Improve _check_sample_weight docstring (#30908)
1 parent fef6202 commit 9ce8be6

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

sklearn/ensemble/_weight_boosting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def fit(self, X, y, sample_weight=None):
139139
)
140140

141141
sample_weight = _check_sample_weight(
142-
sample_weight, X, np.float64, copy=True, ensure_non_negative=True
142+
sample_weight, X, dtype=np.float64, copy=True, ensure_non_negative=True
143143
)
144144
sample_weight /= sample_weight.sum()
145145

sklearn/tree/_classes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def _fit(
358358
)
359359

360360
if sample_weight is not None:
361-
sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
361+
sample_weight = _check_sample_weight(sample_weight, X, dtype=DOUBLE)
362362

363363
if expanded_class_weight is not None:
364364
if sample_weight is not None:

sklearn/utils/validation.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -2127,7 +2127,7 @@ def _check_psd_eigenvalues(lambdas, enable_warnings=False):
21272127

21282128

21292129
def _check_sample_weight(
2130-
sample_weight, X, dtype=None, copy=False, ensure_non_negative=False
2130+
sample_weight, X, *, dtype=None, ensure_non_negative=False, copy=False
21312131
):
21322132
"""Validate sample weights.
21332133
@@ -2144,18 +2144,22 @@ def _check_sample_weight(
21442144
X : {ndarray, list, sparse matrix}
21452145
Input data.
21462146
2147+
dtype : dtype, default=None
2148+
dtype of the validated `sample_weight`.
2149+
If None, and `sample_weight` is an array:
2150+
2151+
- If `sample_weight.dtype` is one of `{np.float64, np.float32}`,
2152+
then the dtype is preserved.
2153+
- Else the output has NumPy's default dtype: `np.float64`.
2154+
2155+
If `dtype` is not `{np.float32, np.float64, None}`, then output will
2156+
be `np.float64`.
2157+
21472158
ensure_non_negative : bool, default=False,
21482159
Whether or not the weights are expected to be non-negative.
21492160
21502161
.. versionadded:: 1.0
21512162
2152-
dtype : dtype, default=None
2153-
dtype of the validated `sample_weight`.
2154-
If None, and the input `sample_weight` is an array, the dtype of the
2155-
input is preserved; otherwise an array with the default numpy dtype
2156-
is be allocated. If `dtype` is not one of `float32`, `float64`,
2157-
`None`, the output will be of dtype `float64`.
2158-
21592163
copy : bool, default=False
21602164
If True, a copy of sample_weight will be created.
21612165

0 commit comments

Comments
 (0)
0