8000 fix MM project properties so performance UI capabilities work · teoland/python-sasctl@62d7558 · GitHub
[go: up one dir, main page]

Skip to content

Commit 62d7558

Browse files
author
Cloud User
committed
fix MM project properties so performance UI capabilities work
1 parent a4ec570 commit 62d7558

File tree

4 files changed

+50
-18
lines changed

4 files changed

+50
-18
lines changed

src/sasctl/_services/model_management.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,16 @@ def create_performance_definition(cls,
174174

175175
# Performance data cannot be captured unless certain project properties
176176
# have been configured.
177-
for required in ['targetVariable', 'targetLevel',
178-
'predictionVariable']:
177+
for required in ['targetVariable', 'targetLevel']:
179178
if getattr(project, required, None) is None:
180179
raise ValueError("Project %s must have the '%s' property set."
181180
% (project.name, required))
181+
if project['function'] == 'classification' and project['eventProbabilityVariable'] == None:
182+
raise ValueError("Project %s must have the 'eventProbabilityVariable' property set."
183+
% (project.name))
184+
if project['function'] == 'prediction' and project['predictionVariable'] == None:
185+
raise ValueError("Project %s must have the 'predictionVariable' property set."
186+
% (project.name))
182187

183188
request = {'projectId': project.id,
184189
'name': name or model.name + ' Performance',

src/sasctl/tasks.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def _sklearn_to_dict(model):
4949
'RandomForestClassifier': 'Forest',
5050
'DecisionTreeClassifier': 'Decision tree',
5151
'DecisionTreeRegressor': 'Decision tree',
52-
'classifier': 'Classification',
53-
'regressor': 'Prediction'}
52+
'classifier': 'classification',
53+
'regressor': 'prediction'}
5454

5555
if hasattr(model, '_final_estimator'):
5656
estimator = type(model._final_estimator)
@@ -302,17 +302,27 @@ def get_version(x):
302302
else:
303303
prediction_variable = None
304304

305-
project = mr.create_project(project, repo_obj,
305+
# As of Viya 3.4 the 'predictionVariable' parameter is not set during
306+
# project creation. Update the project if necessary.
307+
if model.get('function') == 'prediction': #Predications require predictionVariable
308+
project = mr.create_project(project, repo_obj,
306309
variables=vars,
307310
function=model.get('function'),
308311
targetLevel=target_level,
309312
predictionVariable=prediction_variable)
310313

311-
# As of Viya 3.4 the 'predictionVariable' parameter is not set during
312-
# project creation. Update the project if necessary.
313-
if project.get('predictionVariable') != prediction_variable:
314-
project['predictionVariable'] = prediction_variable
315-
mr.update_project(project)
314+
if project.get('predictionVariable') != prediction_variable:
315+
project['predictionVariable'] = prediction_variable
316+
mr.update_project(project)
317+
else: #Classifications require eventProbabilityVariable
318+
project = mr.create_project(project, repo_obj,
319+
variables=vars,
320+
function=model.get('function'),
321+
targetLevel=target_level,
322+
eventProbabilityVariable=prediction_variable)
323+
if project.get('eventProbabilityVariable') != prediction_variable:
324+
project['eventProbabilityVariable'] = prediction_variable
325+
mr.update_project(project)
316326

317327
model = mr.create_model(model, project)
318328

@@ -506,9 +516,12 @@ def update_model_performance(data, model, label, refresh=True):
506516
"regression and binary classification projects. "
507517
"Received project with '%s' target level. Should be "
508518
"'Interval' or 'Binary'.", project.get('targetLevel'))
509-
elif project.get('predictionVariable', '') == '':
519+
elif project.get('predictionVariable', '') == '' and project.get('function', '').lower() == 'prediction':
510520
raise ValueError("Project '%s' does not have a prediction variable "
511521
"specified." % project)
522+
elif project.get('eventProbabilityVariable', '') == '' and project.get('function', '').lower() == 'classification':
523+
raise ValueError("Project '%s' does not have an Event Probability variable "
524+
"specified." % project)
512525

513526
# Find the performance definition for the model
514527
# As of Viya 3.4, no way to search by model or project

tests/unit/test_model_management.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,20 @@ def test_create_performance_definition():
5050
with pytest.raises(ValueError):
5151
# Project missing some required properties
5252
get_project.return_value = copy.deepcopy(PROJECT)
53-
get_project.return_value['predictionVariable'] = 'predicted'
53+
get_project.return_value['function'] = 'classification'
54+
_ = mm.create_performance_definition('model', 'TestLibrary', 'TestData')
55+
56+
with pytest.raises(ValueError):
57+
# Project missing some required properties
58+
get_project.return_value = copy.deepcopy(PROJECT)
59+
get_project.return_value['function'] = 'prediction'
5460
_ = mm.create_performance_definition('model', 'TestLibrary', 'TestData')
5561

5662
get_project.return_value = copy.deepcopy(PROJECT)
5763
get_project.return_value['targetVariable'] = 'target'
5864
get_project.return_value['targetLevel'] = 'interval'
5965
get_project.return_value['predictionVariable'] = 'predicted'
66+
get_project.return_value['function'] = 'prediction'
6067
_ = mm.create_performance_definition('model', 'TestLibrary',
6168
'TestData',
6269
max_bins=3,

tests/unit/test_tasks.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,27 @@ def test_sklearn_metadata():
2222

2323
info = _sklearn_to_dict(LinearRegression())
2424
assert info['algorithm'] == 'Linear regression'
25-
assert info['function'] == 'Prediction'
25+
assert info['function'] == 'prediction'
2626

2727
info = _sklearn_to_dict(LogisticRegression())
2828
assert info['algorithm'] == 'Logistic regression'
29-
assert info['function'] == 'Classification'
29+
assert info['function'] == 'classification'
3030

3131
info = _sklearn_to_dict(SVC())
3232
assert info['algorithm'] == 'Support vector machine'
33-
assert info['function'] == 'Classification'
33+
assert info['function'] == 'classification'
3434

3535
info = _sklearn_to_dict(GradientBoostingClassifier())
3636
assert info['algorithm'] == 'Gradient boosting'
37-
assert info['function'] == 'Classification'
37+
assert info['function'] == 'classification'
3838

3939
info = _sklearn_to_dict(DecisionTreeClassifier())
4040
assert info['algorithm'] == 'Decision tree'
41-
assert info['function'] == 'Classification'
41+
assert info['function'] == 'classification'
4242

4343
info = _sklearn_to_dict(RandomForestClassifier())
4444
assert info['algorithm'] == 'Forest'
45-
assert info['function'] == 'Classification'
45+
assert info['function'] == 'classification'
4646

4747

4848
def test_parse_module_url():
@@ -96,6 +96,13 @@ def test_save_performance_project_types():
9696
project.return_value = {'function': 'Prediction',
9797
'targetLevel': 'Binary'}
9898
update_model_performance(None, None, None)
99+
100+
# Classification variable required
101+
with pytest.raises(ValueError):
102+
project.return_value = {'function': 'classification',
103+
'targetLevel': 'Binary'}
104+
update_model_performance(None, None, None)
105+
99106

100107
# Check projects w/ invalid properties
101108

0 commit comments

Comments
 (0)
0