8000 ENH: Forbid (0, 1) weights in weighted quantile (#9211) · numpy/numpy@ea8df3f · GitHub
[go: up one dir, main page]

Skip to content

Commit ea8df3f

Browse files
committed
ENH: Forbid (0, 1) weights in weighted quantile (#9211)
1 parent cdd65ef commit ea8df3f

File tree

4 files changed

+48
-74
lines changed

4 files changed

+48
-74
lines changed

numpy/lib/function_base.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,7 +2017,7 @@ def disp(mesg, device=None, linefeed=True):
20172017
"(deprecated in NumPy 2.0)",
20182018
DeprecationWarning,
20192019
stacklevel=2
2020-
)
2020+
)
20212021

20222022
if device is None:
20232023
device = sys.stdout
@@ -3984,22 +3984,22 @@ def _median(a, axis=None, out=None, overwrite_input=False):
39843984
return rout
39853985

39863986

3987-
def _percentile_dispatcher(a, q, axis=None, weights=None, out=None,
3987+
def _percentile_dispatcher(a, q, axis=None, out=None,
39883988
overwrite_input=None, method=None, keepdims=None, *,
3989-
interpolation=None):
3989+
weights=None, interpolation=None):
39903990
return (a, q, out)
39913991

39923992

39933993
@array_function_dispatch(_percentile_dispatcher)
39943994
def percentile(a,
39953995
q,
39963996
axis=None,
3997-
weights=None,
39983997
out=None,
39993998
overwrite_input=False,
40003999
method="linear",
40014000
keepdims=False,
40024001
*,
4002+
weights=None,
40034003
interpolation=None):
40044004
"""
40054005
Compute the q-th percentile of the data along the specified axis.
@@ -4319,22 +4319,22 @@ def percentile(a,
43194319
a, q, axis, weights, out, overwrite_input, method, keepdims)
43204320

43214321

4322-
def _quantile_dispatcher(a, q, axis=None, weights=None, out=None,
4322+
def _quantile_dispatcher(a, q, axis=None, out=None,
43234323
overwrite_input=None, method=None, keepdims=None, *,
4324-
interpolation=None):
4324+
weights=None, interpolation=None):
43254325
return (a, q, out)
43264326

43274327

43284328
@array_function_dispatch(_quantile_dispatcher)
43294329
def quantile(a,
43304330
q,
43314331
axis=None,
4332-
weights=None,
43334332
out=None,
43344333
overwrite_input=False,
43354334
method="linear",
43364335
keepdims=False,
43374336
*,
4337+
weights=None,
43384338
interpolation=None):
43394339
"""
43404340
Compute the q-th quantile of the data along the specified axis.
@@ -4647,11 +4647,10 @@ def _validate_and_ureduce_weights(a, axis, wgts):
46474647
46484648
Weights cannot:
46494649
* be negative
4650+
* be (0, 1)
46504651
* sum to 0
46514652
However, they can be
46524653
* 0, as long as they do not sum to 0
4653-
* less than 1. In this case, all weights are re-normalized by
4654-
the lowest non-zero weight prior to computation.
46554654
46564655
Weights will be broadcasted to the shape of a, then reduced as done
46574656
via _ureduce().
@@ -4678,6 +4677,9 @@ def _validate_and_ureduce_weights(a, axis, wgts):
46784677
if (wgts < 0).any():
46794678
raise ValueError("Negative weight not allowed.")
46804679

4680+
if ((0 < wgts) & (wgts < 1)).any():
4681+
raise ValueError("Partial weight (0, 1) not allowed.")
4682+
46814683
# dims to reshape to, before broadcast
46824684
if axis is None:
46834685
dims = tuple(range(a.ndim)) # all axes
@@ -4714,22 +4716,6 @@ def _validate_and_ureduce_weights(a, axis, wgts):
47144716
# Obtain a weights array of the same shape as ureduced a
47154717
wgts = _ureduce(wgts, func=lambda x, **kwargs: x, axis=dims)
47164718

4717-
# Now check/renormalize weights if any is (0, 1)
4718-
def _normalize(v):
4719-
inds = v > 0
4720-
if (v[inds] < 1).any():
4721-
vec = v.copy()
4722-
vec[inds] = vec[inds] / vec[inds].min() # renormalization
4723-
return vec
4724-
else:
4725-
return v
4726-
4727-
# perform normalization along reduced axis
4728-
if len(dims) > 1:
4729-
wgts = np.apply_along_axis(_normalize, -1, wgts)
4730-
else:
4731-
wgts = np.apply_along_axis(_normalize, dims[0], wgts)
4732-
47334719
return wgts
47344720

47354721

@@ -5022,11 +5008,14 @@ def _get_weighted_quantile_values(arr1d, wgts1d):
50225008

50235009
# each weight occupies a range in weight space w/ left/right bounds
50245010
left_weight_bound = np.roll(wgts1d_cumsum, 1)
5025-
left_weight_bound[0] = 0 # left-most weight bound fixed at 0
5026-
right_weight_bound = wgts1d_cumsum - 1
5011+
# value i left weight index bound = sum(weights before i) + 1 - 1,
5012+
# the +1 due to neighboring values having an index distance of 1,
5013+
# the -1 due to 0-indexing in Python
5014+
left_weight_bound[0] = 0 # left-most weight bound defined to be 0
5015+
right_weight_bound = wgts1d_cumsum - 1 # -1 due to 0-indexing
50275016

50285017
# now construct a mapping from weight bounds to real indexes
5029-
# for example, arr1d=[1, 2] & wgts1d=[2, 3] ->
5018+
# arr1d=[7, 8] & wgts1d=[2, 3] == [7, 7, 8, 8, 8]
50305019
# -> real_indexes=[0, 0, 1, 1] & w_index_bounds=[0, 1, 2, 4]
50315020
indexes = np.arange(arr1d.size)
50325021
real_indexes = np.zeros(2 * indexes.size)
@@ -5039,29 +5028,35 @@ def _get_weighted_quantile_values(arr1d, wgts1d):
50395028
# first define previous_w_indexes/next_w_indexes as the indexes
50405029
# within w_index_bounds whose values sandwich weight_space_indexes.
50415030
# so if w_index_bounds=[0, 1, 2, 4] and weight_space_index=3.5,
5042-
# then previous_w_indexes = 2 and next_w_indexes = 3
5031+
# then previous_w_indexes = 2 and next_w_indexes = 3,
5032+
# meaning weight_space_indexed is sandwiched by w_index_bounds[2]
5033+
# and w_index_bounds[3].
50435034
previous_w_indexes = np.searchsorted(w_index_bounds,
50445035
weight_space_indexes,
50455036
side="right") - 1
50465037
# leverage _get_index() to deal with out-of-bound indices
50475038
previous_w_indexes, next_w_indexes =\
50485039
_get_indexes(w_index_bounds, previous_w_indexes,
50495040
len(w_index_bounds))
5050-
# now redefine previous_w_indexes/next_w_indexes as the weight
5051-
# space indexes that neighbor weight_space_indexes.
5041+
# following earlier example, we now know weight_space_indexed is
5042+
# sandwiched by w_index_bounds[2] and w_index_bounds[3], which are
5043+
# 2 and 4. We want the 2 and 4.
5044+
# so redefine previous_w_indexes/next_w_indexes as the
5045+
# w_index_bounds that neighbor weight_space_indexes.
50525046
previous_w_indexes = w_index_bounds[previous_w_indexes]
50535047
next_w_indexes = w_index_bounds[next_w_indexes]
50545048

5055-
# map all weight space indexes to real indexes, then compute gamma
5049+
# method-dependent gammas determine interpolation scheme between
5050+
# neighboring values, and are computed in weight space.
5051+
gamma =\
5052+
_get_gamma(weight_space_indexes, previous_w_indexes, method)
5053+
5054+
# map all weight space indexes to real indexes
50565055
previous_indexes =\
50575056
np.interp(previous_w_indexes, w_index_bounds, real_indexes)
50585057
next_indexes =\
50595058
np.interp(next_w_indexes, w_index_bounds, real_indexes)
50605059

5061-
# method-dependent gammas determine interpolation scheme between
5062-
# neighboring values, and are computed in weight space.
5063-
gamma =\
5064-
_get_gamma(weight_space_indexes, previous_w_indexes, method)
50655060
previous = take(arr1d, previous_indexes.astype(int))
50665061
next = take(arr1d, next_indexes.astype(int))
50675062
return _lerp(previous, next, gamma, out=out)
@@ -5075,7 +5070,8 @@ def _get_weighted_quantile_values(arr1d, wgts1d):
50755070
result = get_weighted_quantile_values(arr, weights)
50765071

50775072
# now move data to DATA_AXIS to be consistent with no-weights case
5078-
result = np.moveaxis(result, -1, destination=0)
5073+
if axis != -1 and quantiles.ndim:
5074+
result = np.moveaxis(result, -1, destination=0)
50795075

50805076
else:
50815077
values_count = arr.shape[axis]

numpy/lib/nanfunctions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,8 +1219,8 @@ def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=np._NoValu
12191219

12201220

12211221
def _nanpercentile_dispatcher(
1222-
a, q, axis=None, weights=None, out=None, overwrite_input=None,
1223-
method=None, keepdims=None, *, interpolation=None):
1222+
a, q, axis=None, out=None, overwrite_input=None,
1223+
method=None, keepdims=None, *, weights=None, interpolation=None):
12241224
return (a, q, out)
12251225

12261226

@@ -1229,12 +1229,12 @@ def nanpercentile(
12291229
a,
12301230
q,
12311231
axis=None,
1232-
weights=None,
12331232
out=None,
12341233
overwrite_input=False,
12351234
method="linear",
12361235
keepdims=np._NoValue,
12371236
*,
1237+
weights=None,
12381238
interpolation=None,
12391239
):
12401240
"""
@@ -1385,9 +1385,9 @@ def nanpercentile(
13851385
a, q, axis, weights, out, overwrite_input, method, keepdims)
13861386

13871387

1388-
def _nanquantile_dispatcher(a, q, axis=None, weights=None, out=None,
1388+
def _nanquantile_dispatcher(a, q, axis=None, out=None,
13891389
overwrite_input=None, method=None, keepdims=None,
1390-
*, interpolation=None):
1390+
*, weights=None, interpolation=None):
13911391
return (a, q, out)
13921392

13931393

@@ -1396,12 +1396,12 @@ def nanquantile(
13961396
a,
13971397
q,
13981398
axis=None,
1399-
weights=None,
14001399
out=None,
14011400
overwrite_input=False,
14021401
method="linear",
14031402
keepdims=np._NoValue,
14041403
*,
1404+
weights=None,
14051405
interpolation=None,
14061406
):
14071407
"""

numpy/lib/tests/test_function_base.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3037,10 +3037,10 @@ def test_fraction(self):
30373037

30383038
def test_api(self):
30393039
d = np.ones(5)
3040-
np.percentile(d, 5, None, None, None, False)
3041-
np.percentile(d, 5, None, None, None, False, 'linear')
3040+
np.percentile(d, 5, None, None, False)
3041+
np.percentile(d, 5, None, None, False, 'linear')
30423042
o = np.ones((1,))
3043-
np.percentile(d, 5, None, None, o, False, 'linear')
3043+
np.percentile(d, 5, None, o, False, 'linear')
30443044

30453045
def test_complex(self):
30463046
arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='G')
@@ -3861,15 +3861,8 @@ def test_various_weights(self, method):
38613861
assert_almost_equal(actual, expected)
38623862

38633863
# mix of numeric types
3864-
# due to renormalization triggered by weight < 1,
38653864
# this is expected to be the same as weights = [1, 2, 3]
3866-
weights = [decimal.Decimal(0.5), 1, 1.5]
3867-
actual = np.quantile(ar, q=q, axis=axis, weights=weights,
3868-
method=method)
3869-
assert_almost_equal(actual, expected)
3870-
3871-
# show that normalization means sum of weights is irrelavant
3872-
weights = [0.1, 0.2, 0.3]
3865+
weights = [decimal.Decimal(1.0), 2, 3.0]
38733866
actual = np.quantile(ar, q=q, axis=axis, weights=weights,
38743867
method=method)
38753868
assert_almost_equal(actual, expected)
@@ -3882,12 +3875,6 @@ def test_various_weights(self, method):
38823875
expected = np.quantile(ar_012, q=q, axis=axis, method=method)
38833876
assert_almost_equal(actual, expected)
38843877

3885-
# weight entries < 1
3886-
weights = [0.0, 0.001, 0.002]
3887-
actual = np.quantile(ar, q=q, axis=axis, weights=weights,
3888-
method=method)
3889-
assert_almost_equal(actual, expected)
3890-
38913878
def test_weights_flags(self):
38923879
"""Test that flags are raised on invalid weights."""
38933880
ar = np.arange(6).reshape(2, 3)
@@ -3902,6 +3889,8 @@ def test_weights_flags(self):
39023889
np.quantile(ar, q=q, axis=axis, weights=[1, np.nan])
39033890
with assert_raises_regex(ValueError, "Negative weight not allowed"):
39043891
np.quantile(ar, q=q, axis=axis, weights=[1, -1])
3892+
with assert_raises_regex(ValueError, "Partial weight"):
3893+
np.quantile(ar, q=q, axis=axis, weights=[1, 0.1])
39053894
with assert_raises_regex(ZeroDivisionError, "Weights sum to zero"):
39063895
np.quantile(ar, q=q, axis=axis, weights=[0, 0])
39073896

numpy/lib/tests/test_nanfunctions.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,15 +1357,8 @@ def test_various_weights(self, method):
13571357
assert_almost_equal(actual, expected)
13581358

13591359
# mix of numeric types
1360-
# due to renormalization triggered by weight < 1,
13611360
# this is expected to be the same as weights = [1, 2, 3]
1362-
weights = [decimal.Decimal(0.5), 1, 1.5]
1363-
actual = np.nanquantile(ar, q=q, axis=axis, weights=weights,
1364-
method=method)
1365-
assert_almost_equal(actual, expected)
1366-
1367-
# show that normalization means sum of weights is irrelavant
1368-
weights = [0.2, 0.4, 0.6]
1361+
weights = [decimal.Decimal(1.0), 2, 3.0]
13691362
actual = np.nanquantile(ar, q=q, axis=axis, weights=weights,
13701363
method=method)
13711364
assert_almost_equal(actual, expected)
@@ -1378,12 +1371,6 @@ def test_various_weights(self, method):
13781371
expected = np.nanquantile(ar_012, q=q, axis=axis, method=method)
13791372
assert_almost_equal(actual, expected)
13801373

1381-
# weight entries < 1
1382-
weights = [0.0, 0.001, 0.002]
1383-
actual = np.nanquantile(ar, q=q, axis=axis, weights=weights,
1384-
method=method)
1385-
assert_almost_equal(actual, expected)
1386-
13871374
def test_weights_flags(self):
13881375
"""Test that flags are raised on invalid weights."""
13891376
ar = np.arange(6).reshape(2, 3).astype(float)
@@ -1399,6 +1386,8 @@ def test_weights_flags(self):
13991386
np.quantile(ar, q=q, axis=axis, weights=[1, np.nan])
14001387
with assert_raises_regex(ValueError, "Negative weight not allowed"):
14011388
np.quantile(ar, q=q, axis=axis, weights=[1, -1])
1389+
with assert_raises_regex(ValueError, "Partial weight"):
1390+
np.quantile(ar, q=q, axis=axis, weights=[1, 0.1])
14021391
with assert_raises_regex(ZeroDivisionError, "Weights sum to zero"):
14031392
np.quantile(ar, q=q, axis=axis, weights=[0, 0])
14041393

0 commit comments

Comments
 (0)
0