@@ -955,24 +955,30 @@ def _incremental_mean_and_var(
955
955
# new = the current increment
956
956
# updated = the aggregated stats
957
957
last_sum = last_mean * last_sample_count
958
+ X_nan_mask = np .isnan (X )
959
+ if np .any (X_nan_mask
8000
):
960
+ sum_op = np .nansum
961
+ else :
962
+ sum_op = np .sum
958
963
if sample_weight is not None :
959
964
if np_version >= parse_version ("1.16.6" ):
960
965
# equivalent to np.nansum(X * sample_weight, axis=0)
961
966
# safer because np.float64(X*W) != np.float64(X)*np.float64(W)
962
967
# dtype arg of np.matmul only exists since version 1.16
963
968
new_sum = _safe_accumulator_op (
964
- np .matmul , sample_weight , np .where (np . isnan ( X ) , 0 , X )
969
+ np .matmul , sample_weight , np .where (X_nan_mask , 0 , X )
965
970
)
966
971
else :
967
972
new_sum = _safe_accumulator_op (
968
973
np .nansum , X * sample_weight [:, None ], axis = 0
969
974
)
970
975
new_sample_count = _safe_accumulator_op (
971
- np .sum , sample_weight [:, None ] * (~ np . isnan ( X ) ), axis = 0
976
+ np .sum , sample_weight [:, None ] * (~ X_nan_mask ), axis = 0
972
977
)
973
978
else :
974
- new_sum = _safe_accumulator_op (np .nansum , X , axis = 0 )
975
- new_sample_count = np .sum (~ np .isnan (X ), axis = 0 )
979
+ new_sum = _safe_accumulator_op (sum_op , X , axis = 0 )
980
+ n_samples = X .shape [0 ]
981
+ new_sample_count = n_samples - np .sum (X_nan_mask , axis = 0 )
976
982
977
983
updated_sample_count = last_sample_count + new_sample_count
978
984
@@ -982,29 +988,31 @@ def _incremental_mean_and_var(
982
988
updated_variance = None
983
989
else :
984
990
T = new_sum / new_sample_count
991
+ temp = X - T
985
992
if sample_weight is not None :
986
993
if np_version >= parse_version ("1.16.6" ):
987
994
# equivalent to np.nansum((X-T)**2 * sample_weight, axis=0)
988
995
# safer because np.float64(X*W) != np.float64(X)*np.float64(W)
989
996
# dtype arg of np.matmul only exists since version 1.16
990
- new_unnormalized_variance = _safe_accumulator_op (
991
- np .matmul , sample_weight , np .where (np .isnan (X ), 0 , (X - T ) ** 2 )
992
- )
993
997
correction = _safe_accumulator_op (
994
- np .matmul , sample_weight , np .where (np . isnan ( X ) , 0 , X - T )
998
+ np .matmul , sample_weight , np .where (X_nan_mask , 0 , temp )
995
999
)
996
- else :
1000
+ temp **= 2
997
1001
new_unnormalized_variance = _safe_accumulator_op (
998
- np .nansum , ( X - T ) ** 2 * sample_weight [:, None ], axis = 0
1002
+ np .matmul , sample_weight , np . where ( X_nan_mask , 0 , temp )
999
1003
)
1004
+ else :
1000
1005
correction = _safe_accumulator_op (
1001
- np .nansum , (X - T ) * sample_weight [:, None ], axis = 0
1006
+ sum_op , temp * sample_weight [:, None ], axis = 0
1007
+ )
1008
+ temp *= temp
1009
+ new_unnormalized_variance = _safe_accumulator_op (
1010
+ sum_op , temp * sample_weight [:, None ], axis = 0
1002
1011
)
1003
1012
else :
1004
- new_unnormalized_variance = _safe_accumulator_op (
1005
- np .nansum , (X - T ) ** 2 , axis = 0
1006
- )
1007
- correction = _safe_accumulator_op (np .nansum , X - T , axis = 0 )
1013
+ correction = _safe_accumulator_op (sum_op , temp , axis = 0 )
1014
+ temp **= 2
1015
+ new_unnormalized_variance = _safe_accumulator_op (sum_op , temp , axis = 0 )
1008
1016
1009
1017
# correction term of the corrected 2 pass algorithm.
1010
1018
# See "Algorithms for computing the sample variance: analysis
0 commit comments