7
7
import pytest
8
8
9
9
10
- def test_fit_statistics_binary (cancer_dataset ):
10
+ def test_fit_statistics_binary_pandas (cancer_dataset ):
11
+ """Check fit statistics for a binary classification model with Pandas inputs."""
11
12
sklearn = pytest .importorskip ('sklearn' )
12
13
13
14
from sklearn .ensemble import RandomForestClassifier
@@ -43,6 +44,43 @@ def test_fit_statistics_binary(cancer_dataset):
43
44
assert datamap ['_TAU_' ] is not None
44
45
45
46
47
def test_fit_statistics_binary_numpy(cancer_dataset):
    """Check fit statistics for a binary classification model with Numpy inputs."""
    # Skip the whole test when scikit-learn is not installed.
    pytest.importorskip('sklearn')

    from sklearn.ensemble import RandomForestClassifier
    from sasctl.utils import metrics

    classifier = RandomForestClassifier()

    # Hand the model bare arrays (``.values``) rather than the fixture object
    # itself; the 'Type' column is the target, everything else is a feature.
    features = cancer_dataset.drop('Type', axis=1).values
    target = cancer_dataset['Type'].values
    classifier.fit(features, target)

    stats = metrics.fit_statistics(classifier, train=(features, target))
    assert isinstance(stats, dict)

    # Should only contain stats for training data
    entries = stats['data']
    assert len(entries) == 1

    train_entry = entries[0]
    assert train_entry['rowNumber'] == 1

    datamap = train_entry['dataMap']
    n_rows = features.shape[0]
    assert datamap['_DataRole_'] == 'TRAIN'
    assert datamap['_NObs_'] == n_rows
    assert datamap['_DIV_'] == n_rows

    # Every one of these statistics must be populated for the training data.
    expected_stats = (
        '_ASE_',
        '_C_',
        '_GAMMA_',
        '_GINI_',
        '_KS_',
        '_MCE_',
        '_MCLL_',
        '_RASE_',
        '_TAU_',
    )
    for stat_name in expected_stats:
        assert datamap[stat_name] is not None
82
+
83
+
46
84
def test_fit_statistics_regression (boston_dataset ):
47
85
sklearn = pytest .importorskip ('sklearn' )
48
86
@@ -62,12 +100,14 @@ def test_fit_statistics_regression(boston_dataset):
62
100
assert len (stats ['data' ]) == 1
63
101
64
102
assert stats ['data' ][0 ]['rowNumber' ] == 1
65
- assert stats ['data' ][0 ]['dataMap' ]['_DataRole_' ] == 'TRAIN'
66
- assert stats ['data' ][0 ]['dataMap' ]['_NObs_' ] == X .shape [0 ]
67
- assert stats ['data' ][0 ]['dataMap' ]['_DIV_' ] == X .shape [0 ]
103
+ datamap = stats ['data' ][0 ]['dataMap' ]
104
+
105
+ assert datamap ['_DataRole_' ] == 'TRAIN'
106
+ assert datamap ['_NObs_' ] == X .shape [0 ]
107
+ assert datamap ['_DIV_' ] == X .shape [0 ]
68
108
69
- for stat in ( '_ASE_' , ):
70
- assert stats [ 'data' ][ 0 ][ 'dataMap' ][ stat ] is not None
109
+ assert datamap [ '_ASE_' ] is not None
110
+ assert datamap [ '_RASE_' ] is not None
71
111
72
112
73
113
def test_fit_statistics_multiclass (iris_dataset ):
0 commit comments