8000 test fit stats with numpy arrays · Sivateja0689/python-sasctl@ce2eb14 · GitHub
[go: up one dir, main page]

Skip to content

Commit ce2eb14

Browse files
committed
test fit stats with numpy arrays
1 parent 580e0d8 commit ce2eb14

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
Unreleased
22
----------
3-
-
3+
- Added metrics module
4+
- train, test, valid inputs to register_model
5+
- overwrite register_model files
6+
- metrics included by default
7+
48

59
v1.5 (2020-2-23)
610
----------------

tests/unit/test_metrics.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import pytest
88

99

10-
def test_fit_statistics_binary(cancer_dataset):
10+
def test_fit_statistics_binary_pandas(cancer_dataset):
11+
"""Check fit statistics for a binary classification model with Pandas inputs."""
1112
sklearn = pytest.importorskip('sklearn')
1213

1314
from sklearn.ensemble import RandomForestClassifier
@@ -43,6 +44,43 @@ def test_fit_statistics_binary(cancer_dataset):
4344
assert datamap['_TAU_'] is not None
4445

4546

47+
def test_fit_statistics_binary_numpy(cancer_dataset):
48+
"""Check fit statistics for a binary classification model with Numpy inputs."""
49+
sklearn = pytest.importorskip('sklearn')
50+
51+
from sklearn.ensemble import RandomForestClassifier
52+
from sasctl.utils import metrics
53+
54+
model = RandomForestClassifier()
55+
X = cancer_dataset.drop('Type', axis=1).values
56+
y = cancer_dataset['Type'].values
57+
model.fit(X, y)
58+
59+
stats = metrics.fit_statistics(model, train=(X, y))
60+
61+
assert isinstance(stats, dict)
62+
63+
# Should only contain stats for training data
64+
assert len(stats['data']) == 1
65+
66+
assert stats['data'][0]['rowNumber'] == 1
67+
datamap = stats['data'][0]['dataMap']
68+
69+
assert datamap['_DataRole_'] == 'TRAIN'
70+
assert datamap['_NObs_'] == X.shape[0]
71+
assert datamap['_DIV_'] == X.shape[0]
72+
73+
assert datamap['_ASE_'] is not None
74+
assert datamap['_C_'] is not None
75+
assert datamap['_GAMMA_'] is not None
76+
assert datamap['_GINI_'] is not None
77+
assert datamap['_KS_'] is not None
78+
assert datamap['_MCE_'] is not None
79+
assert datamap['_MCLL_'] is not None
80+
assert datamap['_RASE_'] is not None
81+
assert datamap['_TAU_'] is not None
82+
83+
4684
def test_fit_statistics_regression(boston_dataset):
4785
sklearn = pytest.importorskip('sklearn')
4886

@@ -62,12 +100,14 @@ def test_fit_statistics_regression(boston_dataset):
62100
assert len(stats['data']) == 1
63101

64102
assert stats['data'][0]['rowNumber'] == 1
65-
assert stats['data'][0]['dataMap']['_DataRole_'] == 'TRAIN'
66-
assert stats['data'][0]['dataMap']['_NObs_'] == X.shape[0]
67-
assert stats['data'][0]['dataMap']['_DIV_'] == X.shape[0]
103+
datamap = stats['data'][0]['dataMap']
104+
105+
assert datamap['_DataRole_'] == 'TRAIN'
106+
assert datamap['_NObs_'] == X.shape[0]
107+
assert datamap['_DIV_'] == X.shape[0]
68108

69-
for stat in ('_ASE_', ):
70-
assert stats['data'][0]['dataMap'][stat] is not None
109+
assert datamap['_ASE_'] is not None
110+
assert datamap['_RASE_'] is not None
71111

72112

73113
def test_fit_statistics_multiclass(iris_dataset):

0 commit comments

Comments
 (0)
0