Add extend testing · scikit-learn/scikit-learn@acd5afb · GitHub
[go: up one dir, main page]

Skip to content

Commit acd5afb

Browse files
committed
Add extend testing
1 parent ba46a28 commit acd5afb

File tree

2 files changed

+84
-36
lines changed

2 files changed

+84
-36
lines changed

sklearn/utils/extmath.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,8 @@ def _safe_accumulator_op(op, x, *args, **kwargs):
707707
return result
708708

709709

710-
def _incremental_weighted_mean_and_var(X, sample_weight, last_weighted_mean,
710+
def _incremental_weighted_mean_and_var(X, sample_weight,
711+
last_weighted_mean,
711712
last_weighted_variance,
712713
last_weight_sum):
713714
"""Calculate weighted mean and variance batch update
@@ -762,32 +763,41 @@ def _incremental_weighted_mean_and_var(X, sample_weight, last_weighted_mean,
762763
# updated = the aggregated stats
763764

764765
M = np.isnan(X)
765-
X = np.where(np.isnan(X), 0, X)
766-
new_weight_sum = np.dot(np.transpose(np.reshape(sample_weight, (-1, 1))), ~M).ravel()
767-
total_weight_sum = np.sum(sample_weight, axis=0)
768-
769-
new_weighted_mean = np.average(X, weights=sample_weight, axis=0)
770-
new_weighted_mean = (new_weighted_mean * total_weight_sum) / new_weight_sum
766+
sample_weight_T = np.transpose(np.reshape(sample_weight, (-1, 1)))
767+
new_weight_sum = _safe_accumulator_op(np.dot, sample_weight_T, ~M).ravel()
768+
total_weight_sum = _safe_accumulator_op(np.sum, sample_weight, axis=0)
769+
770+
X_0 = np.where(np.isnan(X), 0, X)
771+
new_weighted_mean = \
772+
_safe_accumulator_op(np.average, X_0, weights=sample_weight, axis=0)
773+
new_weighted_mean *= total_weight_sum / new_weight_sum
771774
updated_weight_sum = last_weight_sum + new_weight_sum
772-
updated_weighted_mean = (last_weight_sum * last_weighted_mean
773-
+ new_weight_sum * new_weighted_mean) / updated_weight_sum
775+
updated_weighted_mean = (
776+
(last_weight_sum * last_weighted_mean +
777+
new_weight_sum * new_weighted_mean) / updated_weight_sum)
774778

775779
if last_weighted_variance is None:
776780
updated_weighted_variance = None
777781
else:
778-
new_weighted_variance = (np.average(
779-
X ** 2, weights=sample_weight, axis=0) * total_weight_sum / new_weight_sum) - new_weighted_mean ** 2
780-
new_element = new_weight_sum * \
781-
(new_weighted_variance + (new_weighted_mean - updated_weighted_mean) ** 2)
782-
last_element = last_weight_sum * \
783-
(last_weighted_variance + (last_weighted_mean - updated_weighted_mean) ** 2)
782+
X_0 = np.where(np.isnan(X), 0, (X-new_weighted_mean)**2)
783+
new_weighted_variance = \
784+
_safe_accumulator_op(
785+
np.average, X_0, weights=sample_weight, axis=0)
786+
new_weighted_variance *= total_weight_sum / new_weight_sum
787+
new_element = (
788+
new_weight_sum *
789+
(new_weighted_variance +
790+
(new_weighted_mean - updated_weighted_mean) ** 2))
791+
last_element = (
792+
last_weight_sum *
793+
(last_weighted_variance +
794+
(last_weighted_mean - updated_weighted_mean) ** 2))
784795
updated_weighted_variance = (
785796
new_element + last_element) / updated_weight_sum
786797

787798
return updated_weighted_mean, updated_weighted_variance, updated_weight_sum
788799

789800

790-
791801
def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
792802
"""Calculate mean update and a Youngs and Cramer variance update.
793803

sklearn/utils/tests/test_extmath.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.utils._testing import assert_warns_message
2323
from sklearn.utils._testing import skip_if_32bit
2424

25-
from sklearn.utils.extmath import density
25+
from sklearn.utils.extmath import density, _safe_accumulator_op
2626
from sklearn.utils.extmath import randomized_svd
2727
from sklearn.utils.extmath import row_norms
2828
from sklearn.utils.extmath import weighted_mode
@@ -470,32 +470,70 @@ def test_incremental_weighted_mean_and_variance_simple():
470470

471471

472472
def test_incremental_weighted_mean_and_variance():
473-
rng = np.random.RandomState(42)
474-
mult = 10
475-
X = rng.rand(1000, 20)*mult
476-
sample_weight = rng.rand(X.shape[0]) * mult
477473

478-
n = X.shape[0]
479-
last_mean, last_weight_sum, last_var = 0, 0, 0
480-
mean_exp = np.average(X, weights=sample_weight, axis=0)
481-
var_exp = np.average(X ** 2, weights=sample_weight, axis=0) - mean_exp ** 2
482-
for chunk_size in [1, 2, 50, n, n + 42]:
483-
for batch in gen_batches(n, chunk_size):
484-
last_mean, last_var, last_weight_sum = \
485-
_incremental_weighted_mean_and_var(X[batch],
486-
sample_weight[batch],
487-
last_mean,
488-
last_var,
489-
last_weight_sum)
490-
assert_almost_equal(last_mean, mean_exp)
491-
assert_almost_equal(last_var, var_exp)
474+
# Testing of correctness and numerical stability
475+
def test(X, sample_weight, mean_exp=None, var_exp=None):
476+
n = X.shape[0]
477+
if mean_exp is None:
478+
mean_exp = \
479+
_safe_accumulator_op(
480+
np.average, X, weights=sample_weight, axis=0)
481+
if var_exp is None:
482+
var_exp = \
483+
_safe_accumulator_op(
484+
np.average, (X-mean_exp)**2, weights=sample_weight, axis=0)
485+
for chunk_size in [1, n//10 + 1, n//4 + 1, n//2 + 1, n]:
486+
last_mean, last_weight_sum, last_var = 0, 0, 0
487+
for batch in gen_batches(n, chunk_size):
488+
last_mean, last_var, last_weight_sum = \
489+
_incremental_weighted_mean_and_var(X[batch],
490+
sample_weight[batch],
491+
last_mean,
492+
last_var,
493+
last_weight_sum)
494+
assert_almost_equal(last_mean, mean_exp)
495+
assert_almost_equal(last_var, var_exp, 6)
496+
497+
HIGH_MEAN = 10e6
498+
LOW_VAR = 10e-7
499+
NORMAL_MEAN = 0.0
500+
NORMAL_VAR = 1.0
501+
SIZE = (100, 20)
502+
503+
rng = np.random.RandomState(42)
504+
NORMAL_WEIGHT = \
505+
rng.normal(loc=NORMAL_MEAN, scale=NORMAL_VAR, size=(SIZE[0],))
506+
ALMOST_ZERO_WEIGHT = \
507+
rng.normal(loc=NORMAL_MEAN, scale=LOW_VAR, size=(SIZE[0],))
508+
ALMOST_ONES_WEIGHT = rng.normal(loc=1.0, scale=LOW_VAR, size=(SIZE[0],))
509+
JUST_WEIGHT = rng.normal(loc=10.0, scale=NORMAL_VAR, size=(SIZE[0],))
510+
ONES_WEIGHT = np.ones(SIZE[0])
511+
512+
means = [NORMAL_MEAN, HIGH_MEAN]
513+
vars = [NORMAL_VAR, LOW_VAR]
514+
weights = \
515+
[NORMAL_WEIGHT, ALMOST_ONES_WEIGHT, JUST_WEIGHT, ALMOST_ZERO_WEIGHT]
516+
means_vars = ((m, v) for m in means for v in vars)
517+
518+
# Comparing with weighted np.average
519+
for mean, var in means_vars:
520+
X = rng.normal(loc=mean, scale=var, size=SIZE)
521+
print(mean, var)
522+
for weight in weights:
523+
test(X, weight)
524+
525+
# Comparing with unweighted np.average
526+
for mean, var in means_vars:
527+
X = rng.normal(loc=mean, scale=var, size=SIZE)
528+
mean_exp = _safe_accumulator_op(np.mean, X, axis=0)
529+
var_exp = _safe_accumulator_op(np.var, X, axis=0)
530+
test(X, ONES_WEIGHT, mean_exp, var_exp)
492531

493532

494533
def test_incremental_weighted_mean_and_variance_ignore_nan():
495534
old_means = np.array([535., 535., 535., 535.])
496535
old_variances = np.array([4225., 4225., 4225., 4225.])
497536
old_weight_sum = np.array([2, 2, 2, 2], dtype=np.int32)
498-
499537
sample_weights_X = np.ones(3)
500538
sample_weights_X_nan = np.ones(4)
501539

0 commit comments

Comments
 (0)
0