Merge pull request #42 from jameskochubasas/master · teoland/python-sasctl@b3c34e6 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b3c34e6

Browse files
authored
Merge pull request sassoftware#42 from jameskochubasas/master
Fixing the MM performance capabilities for all models
2 parents a6e1bd7 + 5ca8491 commit b3c34e6

File tree

6 files changed

+137
-25
lines changed

6 files changed

+137
-25
lines changed

src/sasctl/_services/model_management.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,16 @@ def create_performance_definition(cls,
174174

175175
# Performance data cannot be captured unless certain project properties
176176
# have been configured.
177-
for required in ['targetVariable', 'targetLevel',
178-
'predictionVariable']:
177+
for required in ['targetVariable', 'targetLevel']:
179178
if getattr(project, required, None) is None:
180179
raise ValueError("Project %s must have the '%s' property set."
181180
% (project.name, required))
181+
if project['function'] == 'classification' and project['eventProbabilityVariable'] == None:
182+
raise ValueError("Project %s must have the 'eventProbabilityVariable' property set."
183+
% (project.name))
184+
if project['function'] == 'prediction' and project['predictionVariable'] == None:
185+
raise ValueError("Project %s must have the 'predictionVariable' property set."
186+
% (project.name))
182187

183188
request = {'projectId': project.id,
184189
'name': name or model.name + ' Performance',

src/sasctl/tasks.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def _sklearn_to_dict(model):
4949
'RandomForestClassifier': 'Forest',
5050
'DecisionTreeClassifier': 'Decision tree',
5151
'DecisionTreeRegressor': 'Decision tree',
52-
'classifier': 'Classification',
53-
'regressor': 'Prediction'}
52+
'classifier': 'classification',
53+
'regressor': 'prediction'}
5454

5555
if hasattr(model, '_final_estimator'):
5656
estimator = type(model._final_estimator)
@@ -207,10 +207,26 @@ def get_version(x):
207207
# If model is a CASTable then assume it holds an ASTORE model.
208208
# Import these via a ZIP file.
209209
if 'swat.cas.table.CASTable' in str(type(model)):
210-
zipfile = utils.create_package(model)
210+
zipfile = utils.create_package(model, input=input)
211211

212212
if create_project:
213-
project = mr.create_project(project, repo_obj)
213+
outvar=[]
214+
invar=[]
215+
import zipfile as zp
216+
import copy
217+
zipfilecopy = copy.deepcopy(zipfile)
218+
tmpzip=zp.ZipFile(zipfilecopy)
219+
if "outputVar.json" in tmpzip.namelist():
220+
outvar=json.loads(tmpzip.read("outputVar.json").decode('utf=8')) #added decode for 3.5 and older
221+
for tmp in outvar:
222+
tmp.update({'role':'output'})
223+
if "inputVar.json" in tmpzip.namelist():
224+
invar=json.loads(tmpzip.read("inputVar.json").decode('utf-8')) #added decode for 3.5 and older
225+
for tmp in invar:
226+
if tmp['role'] != 'input':
227+
tmp['role']='input'
228+
vars=invar + outvar
229+
project = mr.create_project(project, repo_obj, variables=vars)
214230

215231
model = mr.import_model_from_zip(name, project, zipfile,
216232
version=version)
@@ -302,17 +318,27 @@ def get_version(x):
302318
else:
303319
prediction_variable = None
304320

305-
project = mr.create_project(project, repo_obj,
321+
# As of Viya 3.4 the 'predictionVariable' parameter is not set during
322+
# project creation. Update the project if necessary.
323+
if function == 'prediction': #Predications require predictionVariable
324+
project = mr.create_project(project, repo_obj,
306325
variables=vars,
307326
function=model.get('function'),
308327
targetLevel=target_level,
309328
predictionVariable=prediction_variable)
310329

311-
# As of Viya 3.4 the 'predictionVariable' parameter is not set during
312-
# project creation. Update the project if necessary.
313-
if project.get('predictionVariable') != prediction_variable:
314-
project['predictionVariable'] = prediction_variable
315-
mr.update_project(project)
330+
if project.get('predictionVariable') != prediction_variable:
331+
project['predictionVariable'] = prediction_variable
332+
mr.update_project(project)
333+
else: #Classifications require eventProbabilityVariable
334+
project = mr.create_project(project, repo_obj,
335+
variables=vars,
336+
function=model.get('function'),
337+
targetLevel=target_level,
338+
eventProbabilityVariable=prediction_variable)
339+
if project.get('eventProbabilityVariable') != prediction_variable:
340+
project['eventProbabilityVariable'] = prediction_variable
341+
mr.update_project(project)
316342

317343
model = mr.create_model(model, project)
318344

@@ -506,9 +532,12 @@ def update_model_performance(data, model, label, refresh=True):
506532
"regression and binary classification projects. "
507533
"Received project with '%s' target level. Should be "
508534
"'Interval' or 'Binary'.", project.get('targetLevel'))
509-
elif project.get('predictionVariable', '') == '':
535+
elif project.get('predictionVariable', '') == '' and project.get('function', '').lower() == 'prediction':
510536
raise ValueError("Project '%s' does not have a prediction variable "
511537
"specified." % project)
538+
elif project.get('eventProbabilityVariable', '') == '' and project.get('function', '').lower() == 'classification':
539+
raise ValueError("Project '%s' does not have an Event Probability variable "
540+
"specified." % project)
512541

513542
# Find the performance definition for the model
514543
# As of Viya 3.4, no way to search by model or project

src/sasctl/utils/astore.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,21 @@
2020
swat = None
2121

2222

23-
def create_package(table):
23+
def create_package(table, input=None):
2424
"""Create an importable model package from a CAS table.
2525
2626
Parameters
2727
----------
2828
table : swat.CASTable
2929
The CAS table containing an ASTORE or score code.
30+
input : DataFrame, type, list of type, or dict of str: type, optional
31+
The expected type for each input value of the target function.
32+
Can be omitted if target function includes type hints. If a DataFrame
33+
is provided, the columns will be inspected to determine type information.
34+
If a single type is provided, all columns will be assumed to be that type,
35+
otherwise a list of column types or a dictionary of column_name: type
36+
may be provided.
37+
3038
3139
Returns
3240
-------
@@ -45,18 +53,26 @@ def create_package(table):
4553
assert isinstance(table, swat.CASTable)
4654

4755
if 'DataStepSrc' in table.columns:
48-
return create_package_from_datastep(table)
56+
#Input only passed to datastep
57+
return create_package_from_datastep(table, input=input)
4958
else:
5059
return create_package_from_astore(table)
5160

5261

53-
def create_package_from_datastep(table):
62+
def create_package_from_datastep(table, input=None):
5463
"""Create an importable model package from a score code table.
5564
5665
Parameters
5766
----------
5867
table : swat.CASTable
5968
The CAS table containing the score code.
69+
input : DataFrame, type, list of type, or dict of str: type, optional
70+
The expected type for each input value of the target function.
71+
Can be omitted if target function includes type hints. If a DataFrame
72+
is provided, the columns will be inspected to determine type information.
73+
If a single type is provided, all columns will be assumed to be that type,
74+
otherwise a list of column types or a dictionary of column_name: type
75+
may be provided.
6076
6177
Returns
6278
-------
@@ -73,11 +89,59 @@ def create_package_from_datastep(table):
7389

7490
dscode = table.to_frame().loc[0, 'DataStepSrc']
7591

92+
# Extract inputs if provided
93+
input_vars = []
94+
# Workaround because sasdataframe does not like to be check if exist
95+
if str(input) != "None":
96+
from .pymas.python import ds2_variables
97+
vars=None
98+
if hasattr(input, 'columns'):
99+
# Assuming input is a DataFrame representing model inputs. Use to
100+
# get input variables
101+
vars = ds2_variables(input)
102+
elif isinstance(input, type):
103+
params = OrderedDict([(k, input)
104+
for k in target_func.__code__.co_varnames])
105+
vars = ds2_variables(params)
106+
elif isinstance(input, dict):
107+
vars = ds2_variables(input)
108+
if vars:
109+
input_vars = [var.as_model_metadata() for var in vars if not var.out]
110+
111+
#Find outputs from ds code
112+
output_vars=[]
113+
for sasline in dscode.split('\n'):
114+
if sasline.strip().startswith('label'):
115+
output_var=dict()
116+
for tmp in sasline.split('='):
117+
if 'label' in tmp:
118+
ovarname=tmp.split('label')[1].strip()
119+
output_var.update({"name":ovarname})
120+
#Determine type of variable is decimal or string
121+
if "length " + ovarname in dscode:
122+
sastype=dscode.split("length " + ovarname)[1].split(';')[0].strip()
123+
if "$" in sastype:
124+
output_var.update({"type":"string"})
125+
output_var.update({"length":sastype.split("$")[1]})
126+
else:
127+
output_var.update({"type":"decimal"})
128+
output_var.update({"length":sastype})
129+
else:
130+
#If no length for varaible, default is decimal, 8
131+
output_var.update({"type":"decimal"})
132+
output_var.update({"length":8})
133+
else:
134+
output_var.update({"description":tmp.split(';')[0].strip().strip("'")})
135+
output_vars.append(output_var)
136+
76137
file_metadata = [{'role': 'score', 'name': 'dmcas_scorecode.sas'}]
77138

78139
zip_file = _build_zip_from_files({
79140
'fileMetadata.json': file_metadata,
80-
'dmcas_scorecode.sas': dscode
141+
'dmcas_scorecode.sas': dscode,
142+
'ModelProperties.json': {"scoreCodeType":"dataStep"},
143+
'outputVar.json': output_vars,
144+
'inputVar.json': input_vars
81145
})
82146

83147
return zip_file

src/sasctl/utils/pymas/ds2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def _map_type(cls, mapping, t):
339339

340340
def as_model_metadata(self):
341341
viya_type = self._map_type(self.DS2_TYPE_TO_VIYA, self.type)
342-
role = 'Output' if self.out else 'Input'
342+
role = 'Output' if self.out else 'input'
343343

344344
return OrderedDict(
345345
[('name', self.name), ('role', role), ('type', viya_type)])

tests/unit/test_model_management.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,20 @@ def test_create_performance_definition():
5050
with pytest.raises(ValueError):
5151
# Project missing some required properties
5252
get_project.return_value = copy.deepcopy(PROJECT)
53-
get_project.return_value['predictionVariable'] = 'predicted'
53+
get_project.return_value['function'] = 'classification'
54+
_ = mm.create_performance_definition('model', 'TestLibrary', 'TestData')
55+
56+
with pytest.raises(ValueError):
57+
# Project missing some required properties
58+
get_project.return_value = copy.deepcopy(PROJECT)
59+
get_project.return_value['function'] = 'prediction'
5460
_ = mm.create_performance_definition('model', 'TestLibrary', 'TestData')
5561

5662
get_project.return_value = copy.deepcopy(PROJECT)
5763
get_project.return_value['targetVariable'] = 'target'
5864
get_project.return_value['targetLevel'] = 'interval'
5965
get_project.return_value['predictionVariable'] = 'predicted'
66+
get_project.return_value['function'] = 'prediction'
6067
_ = mm.create_performance_definition('model', 'TestLibrary',
6168
'TestData',
6269
max_bins=3,

tests/unit/test_tasks.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,27 @@ def test_sklearn_metadata():
2222

2323
info = _sklearn_to_dict(LinearRegression())
2424
assert info['algorithm'] == 'Linear regression'
25-
assert info['function'] == 'Prediction'
25+
assert info['function'] == 'prediction'
2626

2727
info = _sklearn_to_dict(LogisticRegression())
2828
assert info['algorithm'] == 'Logistic regression'
29-
assert info['function'] == 'Classification'
29+
assert info['function'] == 'classification'
3030

3131
info = _sklearn_to_dict(SVC())
3232
assert info['algorithm'] == 'Support vector machine'
33-
assert info['function'] == 'Classification'
33+
assert info['function'] == 'classification'
3434

3535
info = _sklearn_to_dict(GradientBoostingClassifier())
3636
assert info['algorithm'] == 'Gradient boosting'
37-
assert info['function'] == 'Classification'
37+
assert info['function'] == 'classification'
3838

3939
info = _sklearn_to_dict(DecisionTreeClassifier())
4040
assert info['algorithm'] == 'Decision tree'
41-
assert info['function'] == 'Classification'
41+
assert info['function'] == 'classification'
4242

4343
info = _sklearn_to_dict(RandomForestClassifier())
4444
assert info['algorithm'] == 'Forest'
45-
assert info['function'] == 'Classification'
45+
assert info['function'] == 'classification'
4646

4747

4848
def test_parse_module_url():
@@ -96,6 +96,13 @@ def test_save_performance_project_types():
9696
project.return_value = {'function': 'Prediction',
9797
'targetLevel': 'Binary'}
9898
update_model_performance(None, None, None)
99+
100+
# Classification variable required
101+
with pytest.raises(ValueError):
102+
project.return_value = {'function': 'classification',
103+
'targetLevel': 'Binary'}
104+
update_model_performance(None, None, None)
105+
99106

100107
# Check projects w/ invalid properties
101108

0 commit comments

Comments
 (0)
0