8000 working lift stats · Sivateja0689/python-sasctl@a3ae0a9 · GitHub
[go: up one dir, main page]

Skip to content

Commit a3ae0a9

Browse files
committed
working lift stats
1 parent bb29c00 commit a3ae0a9

File tree

2 files changed

+219
-142
lines changed

2 files changed

+219
-142
lines changed

src/sasctl/tasks.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _create_project(project_name, model, repo, input_vars=None,
167167

168168
def register_model(model, name, project, repository=None, input=None,
169169
version=None, files=None, force=False,
170+
train=None,
170171
record_packages=True):
171172
"""Register a model in the model repository.
172173
@@ -318,28 +319,44 @@ def register_model(model, name, project, repository=None, input=None,
318319
target_funcs = [f for f in ('predict', 'predict_proba')
319 10000 320
if hasattr(model, f)]
320321

322+
# Save actual model instance
323+
model_obj = model
324+
321325
# Extract model properties
322-
model = _sklearn_to_dict(model)
326+
model = _sklearn_to_dict(model_obj)
323327
model['name'] = name
324328

329+
from .utils.metrics import lift_statistics, roc_statistics, fit_statistics
330+
for name, func in (('dmcas_lift.json', lift_statistics),
331+
('dmcas_fitstats.json', fit_statistics)):
332+
if not any(f['name'] == name for f in files):
333+
stats = func(model_obj, train=train)
334+
files.append({'name': name,
335+
'file': json.dumps(stats)})
336+
337+
# stats = roc_statistics(model_obj, train=train)
338+
# files.append({'name': 'dmcas_roc.json',
339+
# 'file': json.dumps(stats)})
340+
325341
# Get package versions in environment
326-
packages = installed_packages()
327-
if record_packages and packages is not None:
328-
model.setdefault('properties', [])
329-
330-
# Define a custom property to capture each package version
331-
# NOTE: some packages may not conform to the 'name==version' format
332-
# expected here (e.g those installed with pip install -e). Such
333-
# packages also generally contain characters that are not allowed
334-
# in custom properties, so they are excluded here.
335-
for p in packages:
336-
if '==' in p:
337-
n, v = p.split('==')
338-
model['properties'].append(_property('env_%s' % n, v))
339-
340-
# Generate and upload a requirements.txt file
341-
files.append({'name': 'requirements.txt',
342-
'file': '\n'.join(packages)})
342+
if record_packages and not any(f['name'] == 'requirements.txt' for f in files):
343+
packages = installed_packages()
344+
if packages is not None:
345+
model.setdefault('properties', [])
346+
347+
# Define a custom property to capture each package version
348+
# NOTE: some packages may not conform to the 'name==version' format
349+
# expected here (e.g those installed with pip install -e). Such
350+
# packages also generally contain characters that are not allowed
351+
# in custom properties, so they are excluded here.
352+
for p in packages:
353+
if '==' in p:
354+
n, v = p.split('==')
355+
model['properties'].append(_property('env_%s' % n, v))
356+
357+
# Generate and upload a requirements.txt file
358+
files.append({'name': 'requirements.txt',
359+
'file': '\n'.join(packages)})
343360

344361
# Generate PyMAS wrapper
345362
try:

0 commit comments

Comments
 (0)
0