8000 implement unit tests for new code · SkuaD01/scikit-learn@cdccae0 · GitHub
[go: up one dir, main page]

Skip to content

Commit cdccae0

Browse files
committed
implement unit tests for new code
1 parent 08b60a9 commit cdccae0

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

sklearn/preprocessing/tests/test_data.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,35 @@ def test_polynomial_features():
126126
assert interact.powers_.shape == (interact.n_output_features_,
127127
interact.n_input_features_)
128128

129+
def test_polynomial_features_min_degree_equals_max_degree():
130+
X = [[2, 3, 4]]
131+
poly = PolynomialFeatures(min_degree=2, max_degree=2, interaction_only=True)
132+
res = poly.fit_transform(X)
133+
assert_array_almost_equal(res, [[6., 8., 12.]])
134+
135+
def test_polynomial_features_include_bias_overrides_min_degree():
136+
X = [[2, 3, 4]]
137+
poly = PolynomialFeatures(include_bias=False, min_degree=0, max_degree=2, interaction_only=True)
138+
res = poly.fit_transform(X)
139+
assert_array_almost_equal(res, [[2., 3., 4., 6., 8., 12.]])
140+
141+
def test_polynomial_features_min_degree_overrides_include_bias():
142+
X = [[2, 3, 4]]
143+
poly = PolynomialFeatures(include_bias=False, min_degree=2, max_degree=2, interaction_only=True)
144+
res = poly.fit_transform(X)
145+
assert_array_almos 909D t_equal(res, [[6., 8., 12.]])
146+
147+
def test_polynomial_features_min_degree_greater_than_max_degree():
148+
X = [[2, 3, 4]]
149+
poly = PolynomialFeatures(min_degree=4, max_degree=2, interaction_only=True)
150+
res = poly.fit_transform(X)
151+
assert_array_almost_equal(res, [[]])
152+
153+
def test_polynomial_features_max_degree_instead_of_degree():
154+
X = [[2, 3, 4]]
155+
poly = PolynomialFeatures(max_degree=2, degree=6, interaction_only=True)
156+
res = poly.fit_transform(X)
157+
assert_array_almost_equal(res, [[1., 2., 3., 4., 6., 8., 12.]])
129158

130159
def test_polynomial_feature_names():
131160
X = np.arange(30).reshape(10, 3)

0 commit comments

Comments
 (0)
0