@@ -1252,7 +1252,8 @@ def test_svm_probA_proB_deprecated(SVMClass, data, deprecated_prob):
1252
1252
getattr (clf , deprecated_prob )
1253
1253
1254
1254
1255
- def test_custom_kernel_not_array_input ():
1255
+ @pytest .mark .parametrize ("Estimator" , [svm .SVC , svm .SVR ])
1256
+ def test_custom_kernel_not_array_input (Estimator ):
1256
1257
"""Test using a custom kernel that is not fed with array-like for floats"""
1257
1258
data = ["A A" , "A" , "B" , "B B" , "A B" ]
1258
1259
X = np .array ([[2 , 0 ], [1 , 0 ], [0 , 1 ], [0 , 2 ], [1 , 1 ]]) # count encoding
@@ -1273,15 +1274,21 @@ def string_kernel(X1, X2):
1273
1274
K = string_kernel (data , data )
1274
1275
assert_array_equal (np .dot (X , X .T ), K )
1275
1276
1276
- svc1 = svm . SVC (kernel = string_kernel ).fit (data , y )
1277
- svc2 = svm . SVC (kernel = 'linear' ).fit (X , y )
1278
- svc3 = svm . SVC (kernel = 'precomputed' ).fit (K , y )
1277
+ svc1 = Estimator (kernel = string_kernel ).fit (data , y )
1278
+ svc2 = Estimator (kernel = 'linear' ).fit (X , y )
1279
+ svc3 = Estimator (kernel = 'precomputed' ).fit (K , y )
1279
1280
1280
1281
assert svc1 .score (data , y ) == svc3 .score (K , y )
1281
1282
assert svc1 .score (data , y ) == svc2 .score (X , y )
1282
- assert_array_almost_equal (svc1 .decision_function (data ),
1283
- svc2 .decision_function (X ))
1284
- assert_array_almost_equal (svc1 .decision_function (data ),
1285
- svc3 .decision_function (K ))
1286
- assert_array_equal (svc1 .predict (data ), svc2 .predict (X ))
1287
- assert_array_equal (svc1 .predict (data ), svc3 .predict (K ))
1283
+ if hasattr (svc1 , 'decision_function' ): # classifier
1284
+ assert_array_almost_equal (svc1 .decision_function (data ),
1285
+ svc2 .decision_function (X ))
1286
+ assert_array_almost_equal (svc1 .decision_function (data ),
1287
+ svc3 .decision_function (K ))
1288
+ assert_array_equal (svc1 .predict (data ), svc2 .predict (X ))
1289
+ assert_array_equal (svc1 .predict (data ), svc3 .predict (K ))
1290
+ else : # regressor
1291
+ assert_array_almost_equal (svc1 .predict (data ),
1292
+ svc2 .predict (X ))
1293
+ assert_array_almost_equal (svc1 .predict (data ),
1294
+ svc3 .predict (K ))
0 commit comments