@@ -30,29 +30,26 @@ def test_mean_variance_axis0():
30
30
X_lil = sp .lil_matrix (X )
31
31
X_lil [1 , 0 ] = 0
32
32
X [1 , 0 ] = 0
33
- X_csr = sp .csr_matrix (X_lil )
34
33
35
- X_means , X_vars = mean_variance_axis (X_csr , axis = 0 )
36
- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
37
- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
34
+ assert_raises (TypeError , mean_variance_axis , X_lil , axis = 0 )
38
35
36
+ X_csr = sp .csr_matrix (X_lil )
39
37
X_csc = sp .csc_matrix (X_lil )
40
- X_means , X_vars = mean_variance_axis (X_csc , axis = 0 )
41
38
42
- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
43
- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
44
- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 0 )
39
+ expected_dtypes = [(np .float32 , np .float32 ),
40
+ (np .float64 , np .float64 ),
41
+ (np .int32 , np .float64 ),
42
+ (np .int64 , np .float64 )]
45
43
46
- X = X .astype (np .float32 )
47
- X_csr = X_csr .astype (np .float32 )
48
- X_csc = X_csr .astype (np .float32 )
49
- X_means , X_vars = mean_variance_axis (X_csr , axis = 0 )
50
- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
51
- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
52
- X_means , X_vars = mean_variance_axis (X_csc , axis = 0 )
53
- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
54
- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
55
- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 0 )
44
+ for input_dtype , output_dtype in expected_dtypes :
45
+ X_test = X .astype (input_dtype )
46
+ for X_sparse in (X_csr , X_csc ):
47
+ X_sparse = X_sparse .astype (input_dtype )
48
+ X_means , X_vars = mean_variance_axis (X_sparse , axis = 0 )
49
+ assert_equal (X_means .dtype , output_dtype )
50
+ assert_equal (X_vars .dtype , output_dtype )
51
+ assert_array_almost_equal (X_means , np .mean (X_test , axis = 0 ))
52
+ assert_array_almost_equal (X_vars , np .var (X_test , axis = 0 ))
56
53
57
54
58
55
def test_mean_variance_axis1 ():
@@ -64,29 +61,26 @@ def test_mean_variance_axis1():
64
61
X_lil = sp .lil_matrix (X )
65
62
X_lil [1 , 0 ] = 0
66
63
X [1 , 0 ] = 0
67
- X_csr = sp .csr_matrix (X_lil )
68
64
69
- X_means , X_vars = mean_variance_axis (X_csr , axis = 1 )
70
- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
71
- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
65
+ assert_raises (TypeError , mean_variance_axis , X_lil , axis = 1 )
72
66
67
+ X_csr = sp .csr_matrix (X_lil )
73
68
X_csc = sp .csc_matrix (X_lil )
74
- X_means , X_vars = mean_variance_axis (X_csc , axis = 1 )
75
69
76
- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
77
- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
78
- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 1 )
70
+ expected_dtypes = [(np .float32 , np .float32 ),
71
+ (np .float64 , np .float64 ),
72
+ (np .int32 , np .float64 ),
73
+ (np .int64 , np .float64 )]
79
74
80
- X = X .astype (np .float32 )
81
- X_csr = X_csr .astype (np .float32 )
82
- X_csc = X_csr .astype (np .float32 )
83
- X_means , X_vars = mean_variance_axis (X_csr , axis = 1 )
84
- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
85
- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
86
- X_means , X_vars = mean_variance_axis (X_csc , axis = 1 )
87
- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
88
- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
89
- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 1 )
75
+ for input_dtype , output_dtype in expected_dtypes :
76
+ X_test = X .astype (input_dtype )
77
+ for X_sparse in (X_csr , X_csc ):
78
+ X_sparse = X_sparse .astype (input_dtype )
79
+ X_means , X_vars = mean_variance_axis (X_sparse , axis = 0 )
80
+ assert_equal (X_means .dtype , output_dtype )
81
+ assert_equal (X_vars .dtype , output_dtype )
82
+ assert_array_almost_equal (X_means , np .mean (X_test , axis = 0 ))
83
+ assert_array_almost_equal (X_vars , np .var (X_test , axis = 0 ))
90
84
91
85
92
86
def test_incr_mean_variance_axis ():
@@ -132,34 +126,25 @@ def test_incr_mean_variance_axis():
132
126
X = np .vstack (data_chunks )
133
127
X_lil = sp .lil_matrix (X )
134
128
X_csr = sp .csr_matrix (X_lil )
135
- X_means , X_vars = mean_variance_axis (X_csr , axis )
136
- X_means_incr , X_vars_incr , n_incr = \
137
- incr_mean_variance_axis (X_csr , axis , last_mean , last_var , last_n )
138
- assert_array_almost_equal (X_means , X_means_incr )
139
- assert_array_almost_equal (X_vars , X_vars_incr )
140
- assert_equal (X .shape [axis ], n_incr )
141
-
142
129
X_csc = sp .csc_matrix (X_lil )
143
- X_means , X_vars = mean_variance_axis (X_csc , axis )
144
- assert_array_almost_equal (X_means , X_means_incr )
145
- assert_array_almost_equal (X_vars , X_vars_incr )
146
- assert_equal (X .shape [axis ], n_incr )
147
130
148
- # All data but as float
149
- X = X .astype (np .float32 )
150
- X_csr = X_csr .astype (np .float32 )
151
- X_means , X_vars = mean_variance_axis (X_csr , axis )
152
- X_means_incr , X_vars_incr , n_incr = \
153
- incr_mean_variance_axis (X_csr , axis , last_mean , last_var , last_n )
154
- assert_array_almost_equal (X_means , X_means_incr )
155
- assert_array_almost_equal (X_vars , X_vars_incr )
156
- assert_equal (X .shape [axis ], n_incr )
157
-
158
- X_csc = X_csr .astype (np .float32 )
159
- X_means , X_vars = mean_variance_axis (X_csc , axis )
160
- assert_array_almost_equal (X_means , X_means_incr )
161
- assert_array_almost_equal (X_vars , X_vars_incr )
162
- assert_equal (X .shape [axis ], n_incr )
131
+ expected_dtypes = [(np .float32 , np .float32 ),
132
+ (np .float64 , np .float64 ),
133
+ (np .int32 , np .float64 ),
134
+ (np .int64 , np .float64 )]
135
+
136
+ for input_dtype , output_dtype in expected_dtypes :
137
+ for X_sparse in (X_csr , X_csc ):
138
+ X_sparse = X_sparse .astype (input_dtype )
139
+ X_means , X_vars = mean_variance_axis (X_sparse , axis )
140
+ X_means_incr , X_vars_incr , n_incr = \
141
+ incr_mean_variance_axis (X_sparse , axis , last_mean ,
142
+ last_var , last_n )
143
+ assert_equal (X_means_incr .dtype , output_dtype )
144
+ assert_equal (X_vars_incr .dtype , output_dtype )
145
+ assert_array_almost_equal (X_means , X_means_incr )
146
+ assert_array_almost_equal (X_vars , X_vars_incr )
147
+ assert_equal (X .shape [axis ], n_incr )
163
148
164
149
165
150
def test_mean_variance_illegal_axis ():
0 commit comments