diff --git a/.github/workflows/build-test-deploy.yml b/.github/workflows/build-test-deploy.yml index a0072348..87fac0e2 100644 --- a/.github/workflows/build-test-deploy.yml +++ b/.github/workflows/build-test-deploy.yml @@ -22,8 +22,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11'] - os-version: ['ubuntu-20.04', 'windows-latest', 'macos-latest'] -# Pinned Ubuntu version to 20.04 since no Python 3.6 builds available on ubuntu-latest (22.04) as of 2022-12-7. + os-version: ['ubuntu-latest', 'windows-latest', 'macos-latest'] # os-version: [ubuntu-latest, windows-latest, macos-latest] steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index f955f4e9..706e8dcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,24 @@ +v1.11.6 (2025-11-18) +-------------------- +**Improvements** +- Added `create_requirements_txt` parameter to `create_requirements_json()` function in `write_json_files.py` to optionally generate a requirements.txt file alongside the requirements.json file. + +v1.11.5 (2025-06-27) +-------------------- +**Improvements** +- Added model versioning methods to `model_repository.py` to handle model version endpoints. +- Allow for user to set custom timeout length for score testing in `score_model_with_cas`. + +v1.11.4 (2025-05-02) +-------------------- +**Improvements** +- Improved `upload_local_model` to allow for SAS Model Manager to properly intake local ASTORE models. + +v1.11.3 (2025-04-29) +-------------------- +**Improvements** +- Added `upload_local_model` to `tasks.py`, which can be used to upload local directories to SAS Model Manager without any file generation. 
+ v1.11.2 (2025-04-08) -------------------- **Bugfixes** diff --git a/examples/pzmm_generate_complete_model_card.ipynb b/examples/pzmm_generate_complete_model_card.ipynb index 3a68271b..60124580 100644 --- a/examples/pzmm_generate_complete_model_card.ipynb +++ b/examples/pzmm_generate_complete_model_card.ipynb @@ -578,7 +578,7 @@ " \"MartialStatus_Married_AF_spouse\", 'MartialStatus_Married_civ_spouse', 'MartialStatus_Never_married', 'MartialStatus_Divorced', 'MartialStatus_Separated', \n", " 'MartialStatus_Widowed', 'Race_White', 'Race_Black', 'Race_Asian_Pac_Islander', 'Race_Amer_Indian_Eskimo', 'Race_Other', 'Relationship_Husband', \n", " 'Relationship_Not_in_family', 'Relationship_Own_child', 'Relationship_Unmarried', 'Relationship_Wife', 'Relationship_Other_relative', 'WorkClass_Private',\n", - " 'Education_Bachelors'\n", + " 'Education_Bachelors', 'Education_Some_college', 'Education_HS_grad'\n", " ]\n", " # OHE columns must be removed after data combination\n", " predictor_columns = ['Age', 'HoursPerWeek', 'WorkClass_Private', 'WorkClass_Self', 'WorkClass_Gov', \n", @@ -1716,12 +1716,14 @@ ], "source": [ "# Step 13: Generate requirements files\n", - "requirements_json = pzmm.JSONFiles.create_requirements_json(output_path)\n", + "requirements_json = pzmm.JSONFiles.create_requirements_json(output_path, create_requirements_txt=False)\n", "\n", "import json\n", "print(json.dumps(requirements_json, sort_keys=True, indent=4))\n", "\n", "for requirement in requirements_json:\n", + " # Example: Replace sklearn with scikit-learn in requirements\n", + " # (This is redundant in newer versions but shows how to modify package names)\n", " if 'sklearn' in requirement['step']:\n", " requirement['command'] = requirement[\"command\"].replace('sklearn', 'scikit-learn')\n", " requirement['step'] = requirement['step'].replace('sklearn', 'scikit-learn')\n", diff --git a/examples/pzmm_generate_requirements_json.ipynb b/examples/pzmm_generate_requirements_json.ipynb index 
604ae800..afd6096c 100644 --- a/examples/pzmm_generate_requirements_json.ipynb +++ b/examples/pzmm_generate_requirements_json.ipynb @@ -14,16 +14,18 @@ "id": "e9b8cb7c-1974-4af5-8992-d51f90fcfe5b", "metadata": {}, "source": [ - "# Automatic Generation of the requirements.json File\n", + "# Automatic Generation of the requirements.json or requirements.txt File\n", "In order to validate Python models within a container publishing destination, the Python packages which contain the modules that are used in the Python score code file and its score resource files must be installed in the run-time container. You can install the packages when you publish a Python model or decision that contains a Python model to a container publishing destination by adding a `requirements.json` file that includes the package install statements to your model.\n", "\n", "This notebook provides an example execution and assessment of the create_requirements_json() function added in python-sasctl v1.8.0. The aim of this function is help to create the instructions (aka the `requirements.json` file) for a lightweight Python container in SAS Model Manager. Lightweight here meaning that the container will only install the packages found in the model's pickle files and python scripts.\n", "\n", + "Additionally, the create_requirements_json() function provides an optional parameter `create_requirements_txt` which when set to `True` will generate a requirements.txt file alongside the requirements.json file. By default this option is set to `False`. The requirements.txt file is needed when consuming Python models in SAS Event Stream Processing, which requires this format for package installation in their environment. While SAS Model Manager continues to use the requirements.json format, adding the requirements.txt file ensures compatibility across both platforms. 
\n", + "\n", "### **User Warnings**\n", "The methods utilized in this function can determine package dependencies and versions from provided scripts and pickle files, but there are some stipulations that need to be considered:\n", "\n", "1. If run outside of the development environment that the model was created in, the create_requirements_json() function **CANNOT** determine the required package _versions_ accurately. \n", - "2. Not all Python packages have matching import and install names and as such some of the packages added to the requirements.json file may be incorrectly named (i.e. `import sklearn` vs `pip install scikit-learn`).\n", + "2. Not all Python packages have matching import and install names and as such some of the packages added to the requirements.json file may be incorrectly named (i.e. `import sklearn` vs `pip install scikit-learn`). Some of the major packages with differing import and install names are automatically converted. \n", "\n", "As such, it is recommended that the user check over the requirements.json file for package name and version accuracy before deploying to a run-time container in SAS Model Manager." 
] @@ -63,7 +65,7 @@ "outputs": [], "source": [ "model_dir = Path.cwd() / \"data/hmeqModels/DecisionTreeClassifier\"\n", - "requirements_json = pzmm.JSONFiles.create_requirements_json(model_dir)" + "requirements_json = pzmm.JSONFiles.create_requirements_json(model_dir, create_requirements_txt=False)" ] }, { @@ -145,6 +147,8 @@ ], "source": [ "for requirement in requirements_json:\n", + " # Example: Replace sklearn with scikit-learn in requirements\n", + " # (This is redundant in newer versions but shows how to modify package names)\n", " if 'sklearn' in requirement['step']:\n", " requirement['command'] = requirement[\"command\"].replace('sklearn', 'scikit-learn')\n", " requirement['step'] = requirement['step'].replace('sklearn', 'scikit-learn')\n", diff --git a/src/sasctl/__init__.py b/src/sasctl/__init__.py index e4c7a0e1..be3f3c75 100644 --- a/src/sasctl/__init__.py +++ b/src/sasctl/__init__.py @@ -4,7 +4,7 @@ # Copyright © 2019, SAS Institute Inc., Cary, NC, USA. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -__version__ = "1.11.2" +__version__ = "1.11.6" __author__ = "SAS" __credits__ = [ "Yi Jian Ching", @@ -16,6 +16,7 @@ "Scott Lindauer", "DJ Moore", "Samya Potlapalli", + "Samuel Babak", ] __license__ = "Apache 2.0" __copyright__ = ( diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index e91b1aa6..7950d95e 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -28,7 +28,13 @@ class ModelManagement(Service): # TODO: set ds2MultiType @classmethod def publish_model( - cls, model, destination, name=None, force=False, reload_model_table=False + cls, + model, + destination, + model_version="latest", + name=None, + force=False, + reload_model_table=False, ): """ @@ -38,6 +44,8 @@ def publish_model( The name or id of the model, or a dictionary representation of the model. destination : str Name of destination to publish the model to. 
+ model_version : str or dict, optional + Provide the version id, name, or dict to publish. Defaults to 'latest'. name : str, optional Provide a custom name for the published model. Defaults to None. force : bool, optional @@ -68,6 +76,23 @@ def publish_model( # TODO: Verify allowed formats by destination type. # As of 19w04 MAS throws HTTP 500 if name is in invalid format. + if model_version != "latest": + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version_name = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError("Model version is not recognized.") + elif isinstance(model_version, str) and cls.is_uuid(model_version): + model_version_name = mr.get_model_or_version(model, model_version)[ + "modelVersionName" + ] + else: + model_version_name = model_version + else: + model_version_name = "" + model_name = name or "{}_{}".format( model_obj["name"].replace(" ", ""), model_obj["id"] ).replace("-", "") @@ -79,6 +104,7 @@ def publish_model( { "modelName": mp._publish_name(model_name), "sourceUri": model_uri.get("uri"), + "modelVersionID": model_version_name, "publishLevel": "model", } ], @@ -104,6 +130,7 @@ def create_performance_definition( table_prefix, project=None, models=None, + modelVersions=None, library_name="Public", name=None, description=None, @@ -136,6 +163,8 @@ def create_performance_definition( The name or id of the model(s), or a dictionary representation of the model(s). For multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all models in the project specified will be used. Defaults to None. + modelVersions: str, list, optional + The name of the model version(s). Defaults to None, so all models are latest. library_name : str The library containing the input data, default is 'Public'. 
name : str, optional @@ -239,10 +268,13 @@ def create_performance_definition( "property set." % project.name ) + # Creating the new array of modelIds with version names appended + updated_models = cls.check_model_versions(models, modelVersions) + request = { "projectId": project.id, "name": name or project.name + " Performance", - "modelIds": [model.id for model in models], + "modelIds": updated_models, "championMonitored": monitor_champion, "challengerMonitored": monitor_challenger, "maxBins": max_bins, @@ -279,7 +311,6 @@ def create_performance_definition( for v in project.get("variables", []) if v.get("role") == "output" ] - return cls.post( "/performanceTasks", json=request, @@ -288,6 +319,57 @@ def create_performance_definition( }, ) + @classmethod + def check_model_versions(cls, models, modelVersions): + """ + Checking if the model version(s) are valid and append to model id accordingly. + + Parameters + ---------- + models: list of str + List of models. + modelVersions : list of str + List of model versions associated with models. + + Returns + ------- + String list + """ + if not modelVersions: + return [model.id for model in models] + + updated_models = [] + if not isinstance(modelVersions, list): + modelVersions = [modelVersions] + + if len(models) < len(modelVersions): + raise ValueError( + "There are too many versions for the amount of models specified." 
+ ) + + modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) + for model, modelVersionName in zip(models, modelVersions): + + if ( + isinstance(modelVersionName, dict) + and "modelVersionName" in modelVersionName + ): + + modelVersionName = modelVersionName["modelVersionName"] + elif ( + isinstance(modelVersionName, dict) + and "modelVersionName" not in modelVersionName + ): + + raise ValueError("Model version is not recognized.") + + if modelVersionName != "": + updated_models.append(model.id + ":" + modelVersionName) + else: + updated_models.append(model.id) + + return updated_models + @classmethod def execute_performance_definition(cls, definition): """Launches a job to run a performance definition. diff --git a/src/sasctl/_services/model_publish.py b/src/sasctl/_services/model_publish.py index c3fa225f..90f665ad 100644 --- a/src/sasctl/_services/model_publish.py +++ b/src/sasctl/_services/model_publish.py @@ -10,6 +10,7 @@ from .model_repository import ModelRepository from .service import Service +from ..utils.decorators import deprecated class ModelPublish(Service): @@ -90,7 +91,7 @@ def delete_destination(cls, item): return cls.delete("/destinations/{name}".format(name=item)) - @classmethod + @deprecated("Use publish_model in model_management.py instead.", "1.11.5") def publish_model(cls, model, destination, name=None, code=None, notes=None): """Publish a model to an existing publishing destination. 
diff --git a/src/sasctl/_services/model_repository.py b/src/sasctl/_services/model_repository.py index dfbbb95d..74fb6446 100644 --- a/src/sasctl/_services/model_repository.py +++ b/src/sasctl/_services/model_repository.py @@ -8,10 +8,17 @@ import datetime from warnings import warn +import requests +from requests.exceptions import HTTPError +import urllib -from ..core import HTTPError, current_session, delete, get, sasctl_command +# import traceback +# import sys + +from ..core import current_session, delete, get, sasctl_command, RestObj from .service import Service + FUNCTIONS = { "Analytical", "Classification", @@ -615,11 +622,222 @@ def list_model_versions(cls, model): list """ - model = cls.get_model(model) - if cls.get_model_link(model, "modelVersions") is None: - raise ValueError("Unable to retrieve versions for model '%s'" % model) - return cls.request_link(model, "modelVersions") + if current_session().version_info() < 4: + model = cls.get_model(model) + if cls.get_model_link(model, "modelVersions") is None: + raise ValueError("Unable to retrieve versions for model '%s'" % model) + + return cls.request_link(model, "modelVersions") + else: + link = cls.get_model_link(model, "modelHistory") + if link is None: + raise ValueError( + "Cannot find link for version history for model '%s'" % model + ) + + modelHistory = cls.request_link( + link, + "modelHistory", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + if isinstance(modelHistory, RestObj): + return [modelHistory] + return modelHistory + + @classmethod + def get_model_version(cls, model, version_id): + """Get a specific version of a model. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. 
+ + Returns + ------- + RestObj + + """ + + model_history = cls.list_model_versions(model) + + for item in model_history: + if item["id"] == version_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.version+json"}, + ) + + raise ValueError("The version id specified could not be found.") + + @classmethod + def get_model_with_versions(cls, model): + """Get the current model with its version history. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + + Returns + ------- + list + + """ + + if cls.is_uuid(model): + model_id = model + elif isinstance(model, dict) and "id" in model: + model_id = model["id"] + else: + model = cls.get_model(model) + if not model: + raise HTTPError( + "This model may not exist in a project or the model may not exist at all." + ) + model_id = model["id"] + + versions_uri = f"/models/{model_id}/versions" + try: + version_history = cls.request( + "GET", + versions_uri, + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + except urllib.error.HTTPError as e: + raise HTTPError( + f"Request failed: Model id may be referencing a non-existing model." + ) from None + + if isinstance(version_history, RestObj): + return [version_history] + + return version_history + + @classmethod + def get_model_or_version(cls, model, version_id): + """Get a specific version of a model but if model id and version id are the same, the current model is returned. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. 
+ + Returns + ------- + RestObj + + """ + + version_history = cls.get_model_with_versions(model) + + for item in version_history: + if item["id"] == version_id: + return cls.request_link( + item, + "self", + headers={ + "Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json" + }, + ) + + raise ValueError("The version id specified could not be found.") + + @classmethod + def get_model_version_contents(cls, model, version_id): + """Get the contents of a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + list + + """ + model_version = cls.get_model_version(model, version_id) + version_contents = cls.request_link( + model_version, + "contents", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + if isinstance(version_contents, RestObj): + return [version_contents] + + return version_contents + + @classmethod + def get_model_version_content_metadata(cls, model, version_id, content_id): + """Get the content metadata header information for a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + content_id: str + The id of the content file. + + Returns + ------- + RestObj + + """ + model_version_contents = cls.get_model_version_contents(model, version_id) + + for item in model_version_contents: + if item["id"] == content_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.content+json"}, + ) + + raise ValueError("The content id specified could not be found.") + + @classmethod + def get_model_version_content(cls, model, version_id, content_id): + """Get the specific content inside the content file for a model version. 
+ + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + content_id: str + The id of the specific content file. + + Returns + ------- + list + + """ + + metadata = cls.get_model_version_content_metadata(model, version_id, content_id) + version_content_file = cls.request_link( + metadata, "content", headers={"Accept": "text/plain"} + ) + + if version_content_file is None: + raise HTTPError("Something went wrong while accessing the metadata file.") + + if isinstance(version_content_file, RestObj): + return [version_content_file] + return version_content_file @classmethod def copy_analytic_store(cls, model): diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index 448a28c5..05733d2b 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -46,7 +46,7 @@ def create_score_definition( description: str = "", server_name: str = "cas-shared-default", library_name: str = "Public", - model_version: str = "latest", + model_version: Union[str, dict] = "latest", ): """Creates the score definition service. @@ -69,7 +69,7 @@ library_name: str, optional The library within the CAS server the table exists in. Defaults to "Public". model_version: str, optional - The user-chosen version of the model with the specified model_id. Defaults to "latest". + The user-chosen version of the model. Defaults to "latest". 
Returns ------- @@ -83,9 +83,7 @@ def create_score_definition( else: object_descriptor_type = "sas.models.model.ds2" - if cls._model_repository.is_uuid(model): - model_id = model - elif isinstance(model, dict) and "id" in model: + if isinstance(model, dict) and "id" in model: model_id = model["id"] else: model = cls._model_repository.get_model(model) @@ -118,7 +116,7 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table and not table_file: raise HTTPError( - f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist." + "This table may not exist in CAS. Include the `table_file` argument." ) elif not table and table_file: cls._cas_management.upload_file( @@ -127,16 +125,19 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table: raise HTTPError( - f"The file failed to upload properly or another error occurred." + "The file failed to upload properly or another error occurred." 
) # Checks if the inputted table exists, and if not, uploads a file to create a new table + object_uri, model_version = cls.check_model_version(model_id, model_version) + # Checks if the model version is valid and how to find the name + save_score_def = { "name": model_name, # used to be score_def_name "description": description, "objectDescriptor": { - "uri": f"/modelManagement/models/{model_id}", - "name": f"{model_name}({model_version})", + "uri": object_uri, + "name": f"{model_name} ({model_version})", "type": f"{object_descriptor_type}", }, "inputData": { @@ -151,7 +152,7 @@ def create_score_definition( "projectUri": f"/modelRepository/projects/{model_project_id}", "projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}", "publishDestination": "", - "versionedModel": f"{model_name}({model_version})", + "versionedModel": f"{model_name} ({model_version})", }, "mappings": inputMapping, } @@ -163,3 +164,37 @@ def create_score_definition( "/definitions", data=json.dumps(save_score_def), headers=headers_score_def ) # The response information of the score definition can be seen as a JSON as well as a RestOBJ + + @classmethod + def check_model_version(cls, model_id: str, model_version: Union[str, dict]): + """Checks if the model version is valid. + + Parameters + ---------- + model_version : str or dict + The model version to check. 
+ + Returns + ------- + String tuple + """ + if model_version != "latest": + + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError("Model version cannot be found.") + elif isinstance(model_version, str) and cls.is_uuid(model_version): + model_version = cls._model_repository.get_model_or_version( + model_id, model_version + )["modelVersionName"] + + object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}" + + else: + object_uri = f"/modelManagement/models/{model_id}" + + return object_uri, model_version diff --git a/src/sasctl/pzmm/write_json_files.py b/src/sasctl/pzmm/write_json_files.py index 1c0c560d..8eb98bf9 100644 --- a/src/sasctl/pzmm/write_json_files.py +++ b/src/sasctl/pzmm/write_json_files.py @@ -1614,6 +1614,7 @@ def create_requirements_json( cls, model_path: Union[str, Path, None] = Path.cwd(), output_path: Union[str, Path, None] = None, + create_requirements_txt: bool = False, ) -> Union[dict, None]: """ Searches the model directory for Python scripts and pickle files and @@ -1636,7 +1637,11 @@ def create_requirements_json( environment. When provided with an output_path argument, this function outputs a JSON file - named "requirements.json". Otherwise, a list of dicts is returned. + named "requirements.json". If create_requirements_txt is True, it will also + create a requirements.txt file. Otherwise, a list of dicts is returned. + + Note: requirements.txt file is only created when both output_path and + create_requirements_txt are specified. Parameters ---------- @@ -1644,6 +1649,10 @@ def create_requirements_json( The path to a Python project, by default the current working directory. output_path : str or pathlib.Path, optional The path for the output requirements.json file. The default value is None. 
+ create_requirements_txt : bool, optional + Whether to also create a requirements.txt file in addition to the + requirements.json file. This is useful for SAS Event Stream Processing + environments. The default value is False. Returns ------- @@ -1662,11 +1671,57 @@ def create_requirements_json( package_list = list(set(list(_flatten(package_list)))) package_list = cls.remove_standard_library_packages(package_list) package_and_version = cls.get_local_package_version(package_list) + # Identify packages with missing versions missing_package_versions = [ item[0] for item in package_and_version if not item[1] ] + IMPORT_TO_INSTALL_MAPPING = { + # Data Science & ML Core + "sklearn": "scikit-learn", + "skimage": "scikit-image", + "cv2": "opencv-python", + "PIL": "Pillow", + # Data Formats & Parsing + "yaml": "PyYAML", + "bs4": "beautifulsoup4", + "docx": "python-docx", + "pptx": "python-pptx", + # Date & Time Utilities + "dateutil": "python-dateutil", + # Database Connectors + "MySQLdb": "MySQL-python", + "psycopg2": "psycopg2-binary", + # System & Platform + "win32api": "pywin32", + "win32com": "pywin32", + # Scientific Libraries + "Bio": "biopython", + } + + # Map import names to their corresponding package installation names + package_and_version = [ + (IMPORT_TO_INSTALL_MAPPING.get(name, name), version) + for name, version in package_and_version + ] + + if create_requirements_txt: + requirements_txt = "" + if missing_package_versions: + requirements_txt += "# Warning- The existence and/or versions for the following packages could not be determined:\n" + requirements_txt += "# " + ", ".join(missing_package_versions) + "\n" + + for package, version in package_and_version: + if version: + requirements_txt += f"{package}=={version}\n" + + if output_path: + with open( # skipcq: PTC-W6004 + Path(output_path) / "requirements.txt", "w" + ) as file: + file.write(requirements_txt) + # Create a list of dicts related to each package or warning json_dicts = [] if 
missing_package_versions: @@ -1800,16 +1855,16 @@ def find_imports(file_path: Union[str, Path]) -> List[str]: file_text = file.read() # Parse the file to get the abstract syntax tree representation tree = ast.parse(file_text) - modules = [] + modules = set() # Walk through each node in the ast to find import calls for node in ast.walk(tree): # Determine parent module for `from * import *` calls if isinstance(node, ast.ImportFrom): - modules.append(node.module) + modules.add(node.module.split(".")[0]) elif isinstance(node, ast.Import): for name in node.names: - modules.append(name.name) + modules.add(name.name.split(".")[0]) modules = list(set(modules)) try: diff --git a/src/sasctl/tasks.py b/src/sasctl/tasks.py index d466c10f..ca659630 100644 --- a/src/sasctl/tasks.py +++ b/src/sasctl/tasks.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Union from warnings import warn +import zipfile import pandas as pd @@ -264,10 +265,9 @@ def _register_sas_model( out_var = [] in_var = [] import copy - import zipfile as zp zip_file_copy = copy.deepcopy(zip_file) - tmp_zip = zp.ZipFile(zip_file_copy) + tmp_zip = zipfile.ZipFile(zip_file_copy) if "outputVar.json" in tmp_zip.namelist(): out_var = json.loads( tmp_zip.read("outputVar.json").decode("utf=8") @@ -327,8 +327,8 @@ def _register_sas_model( if current_session().version_info() < 4: # Upload the model as a ZIP file if using Viya 3. - zipfile = utils.create_package(model, input=input) - model = mr.import_model_from_zip(name, project, zipfile, version=version) + zip_file = utils.create_package(model, input=input) + model = mr.import_model_from_zip(name, project, zip_file, version=version) else: # If using Viya 4, just upload the raw AStore and Model Manager will handle inspection. 
astore = cas.astore.download(rstore=model) @@ -981,6 +981,7 @@ def score_model_with_cas( library_name: str = "Public", model_version: str = "latest", use_cas_gateway: bool = False, + timeout: int = 300, ): score_definition = sd.create_score_definition( score_def_name, @@ -994,7 +995,85 @@ def score_model_with_cas( use_cas_gateway=use_cas_gateway, ) score_execution = se.create_score_execution(score_definition.id) - score_execution_poll = se.poll_score_execution_state(score_execution) + score_execution_poll = se.poll_score_execution_state(score_execution, timeout) print(score_execution_poll) score_results = se.get_score_execution_results(score_execution, use_cas_gateway) return score_results + + +def upload_local_model( + path: Union[str, Path], + model_name: str, + project_name: str, + repo_name: Union[str, dict] = None, + version: str = "latest", +): + """A function to upload a model and any associated files to the model repository. + Parameters + ---------- + path : Union[str, Path] + The path to the model and any associated files. + model_name : str + The name of the model. + project_name : str + The name of the project to which the model will be uploaded. + repo_name : Union[str, dict], optional + repository in which to create the project + version: str, optional + The version of the model being uploaded. Defaults to 'latest'. For new model version, use 'new'. + """ + # Use default repository if not specified + try: + if repo_name is None: + repository = mr.default_repository() + else: + repository = mr.get_repository(repo_name) + except HTTPError as e: + if e.code == 403: + raise AuthorizationError( + "Unable to register model. User account does not have read permissions " + "for the /modelRepository/repositories/ URL. Please contact your SAS " + "Viya administrator." + ) + raise e + + # Unable to find or create the repo. 
+ if not repository and not repo_name: + raise ValueError("Unable to find a default repository") + if not repository: + raise ValueError(f"Unable to find repository '{repo_name}'") + + # Get project from repo if it exists; if it doesn't, create a new one + p = mr.get_project(project_name) + if p is None: + p = mr.create_project(project_name, repository) + + # zip up all files in directory (except any previous zip files) + zip_name = str(Path(path) / (model_name + ".zip")) + file_names = sorted(Path(path).glob("*[!(zip|sasast)]")) + sasast_file = next(Path(path).glob("*.sasast"), None) + if sasast_file: + # If a sasast file is present, upload it as well + with open(sasast_file, "rb") as sasast: + sasast_model = sasast.read() + data = { + "name": model_name, + "projectId": p.id, + "type": "ASTORE", + "versionOption": version, + } + files = {"files": (sasast_file.name, sasast_model)} + model = mr.post("/models", files=files, data=data) + for file in file_names: + with open(file, "r") as f: + mr.add_model_content(model, f, file.name) + else: + with zipfile.ZipFile(str(zip_name), mode="w") as zFile: + for file in file_names: + zFile.write(str(file), arcname=file.name) + # upload zipped model + with open(zip_name, "rb") as zip_file: + model = mr.import_model_from_zip( + model_name, project_name, zip_file, version=version + ) + return model diff --git a/tests/unit/test_model_management.py b/tests/unit/test_model_management.py index fbd4fc36..834b0ecc 100644 --- a/tests/unit/test_model_management.py +++ b/tests/unit/test_model_management.py @@ -23,6 +23,8 @@ def test_create_performance_definition(): RestObj({"name": "Test Model 2", "id": "67890", "projectId": PROJECT["id"]}), ] USER = "username" + VERSION_MOCK = {"modelVersionName": "1.0"} + VERSION_MOCK_NONAME = {} with mock.patch("sasctl.core.Session._get_authorization_token"): current_session("example.com", USER, "password") @@ -111,6 +113,32 @@ def test_create_performance_definition(): table_prefix="TestData", ) + with 
pytest.raises(ValueError): + # Number of model versions exceeds the number of models + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=["1.0", "2.0", "3.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + with pytest.raises(ValueError): + # Model version dictionary missing modelVersionName + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=VERSION_MOCK_NONAME, + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + get_project.return_value = copy.deepcopy(PROJECT) + get_project.return_value["targetVariable"] = "target" + get_project.return_value["targetLevel"] = "interval" @@ -125,21 +153,68 @@ def test_create_performance_definition(): monitor_challenger=True, monitor_champion=True, ) + url, data = post_models.call_args + assert post_models.call_count == 1 + assert PROJECT["id"] == data["json"]["projectId"] + assert MODELS[0]["id"] in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + assert "TestLibrary" == data["json"]["dataLibrary"] + assert "TestData" == data["json"]["dataPrefix"] + assert "cas-shared-default" == data["json"]["casServerId"] + assert data["json"]["name"] + assert data["json"]["description"] + assert data["json"]["maxBins"] == 3 + assert data["json"]["championMonitored"] is True + assert data["json"]["challengerMonitored"] is True - assert post_models.call_count == 1 - url, data = post_models.call_args - - assert PROJECT["id"] == data["json"]["projectId"] - assert MODELS[0]["id"] in data["json"]["modelIds"] - assert MODELS[1]["id"] in data["json"]["modelIds"] - assert "TestLibrary" == data["json"]["dataLibrary"] - assert "TestData" == data["json"]["dataPrefix"] - assert "cas-shared-default" == data["json"]["casServerId"] -
assert data["json"]["name"] - assert data["json"]["description"] - assert data["json"]["maxBins"] == 3 - assert data["json"]["championMonitored"] is True - assert data["json"]["challengerMonitored"] is True + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + # One model version as a string name + models=["model1", "model2"], + modelVersions="1.0", + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + assert post_models.call_count == 2 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + + get_model.side_effect = copy.deepcopy(MODELS) + # List of string type model versions + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=["1.0", "2.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + assert post_models.call_count == 3 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"] + + get_model.side_effect = copy.deepcopy(MODELS) + # List of dictionary type and string type model versions + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=[VERSION_MOCK, "2.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + assert post_models.call_count == 4 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"] with mock.patch( "sasctl._services.model_management.ModelManagement" ".post" @@ -160,20 +235,39 @@ def test_create_performance_definition(): monitor_champion=True, ) - assert post_project.call_count == 1 - url, data = 
post_project.call_args - - assert PROJECT["id"] == data["json"]["projectId"] - assert MODELS[0]["id"] in data["json"]["modelIds"] - assert MODELS[1]["id"] in data["json"]["modelIds"] - assert "TestLibrary" == data["json"]["dataLibrary"] - assert "TestData" == data["json"]["dataPrefix"] - assert "cas-shared-default" == data["json"]["casServerId"] - assert data["json"]["name"] - assert data["json"]["description"] - assert data["json"]["maxBins"] == 3 - assert data["json"]["championMonitored"] is True - assert data["json"]["challengerMonitored"] is True + # one extra test for project with version id + + assert post_project.call_count == 1 + url, data = post_project.call_args + + assert PROJECT["id"] == data["json"]["projectId"] + assert MODELS[0]["id"] in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + assert "TestLibrary" == data["json"]["dataLibrary"] + assert "TestData" == data["json"]["dataPrefix"] + assert "cas-shared-default" == data["json"]["casServerId"] + assert data["json"]["name"] + assert data["json"]["description"] + assert data["json"]["maxBins"] == 3 + assert data["json"]["championMonitored"] is True + assert data["json"]["challengerMonitored"] is True + + get_model.side_effect = copy.deepcopy(MODELS) + # Project with model version + _ = mm.create_performance_definition( + project="project", + modelVersions="2.0", + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + assert post_project.call_count == 2 + url, data = post_project.call_args + assert f"{MODELS[0]['id']}:2.0" in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] def test_table_prefix_format(): with pytest.raises(ValueError): diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index 9232896b..bf4f9284 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -13,6 +13,10 @@ from sasctl 
import current_session from sasctl.services import model_repository as mr +from sasctl.core import RestObj, VersionInfo, request +from requests import HTTPError +import urllib.error + def test_create_model(): MODEL_NAME = "Test Model" @@ -230,3 +234,343 @@ def test_add_model_content(): assert post.call_args[1]["files"] == { "files": ("test.pkl", binary_data, "application/image") } + + +def test_create_model_version(): + model_mock = {"id": 12345} + new_model_mock = {"id": 34567} + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model", + side_effect=[ + model_mock, + model_mock, + new_model_mock, + model_mock, + new_model_mock, + ], + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_link" + ) as get_model_link: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } + + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.create_model_version(model=model_mock, minor=False) + + get_model_link.return_value = get_model_link_mock + response = mr.create_model_version(model=model_mock, minor=False) + + request_link.assert_called_with( + model_mock, "addModelVersion", json={"option": "major"} + ) + assert response == new_model_mock + + response = mr.create_model_version(model=model_mock, minor=True) + request_link.assert_called_with( + model_mock, "addModelVersion", json={"option": "minor"} + ) + assert response == new_model_mock + + +def test_list_model_versions(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_link" + ) as get_model_link: + with mock.patch( + 
"sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + with mock.patch("sasctl.core.Session.version_info") as version: + version.return_value = VersionInfo(4) + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.list_model_versions( + model="12345", + ) + + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } + + get_model_link.return_value = get_model_link_mock + + response = mr.list_model_versions(model="12345") + assert response + + request_link.return_value = RestObj({"id": "12345"}) + response = mr.list_model_versions(model="12345") + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.list_model_versions(model="12345") + assert isinstance(response, list) + + +def test_get_model_version(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.list_model_versions" + ) as list_model_versions: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + list_model_versions_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + list_model_versions.return_value = list_model_versions_mock + + with pytest.raises(ValueError): + mr.get_model_version(model="000", version_id="000") + + response = mr.get_model_version(model="000", version_id="123") + request_link.assert_called_once_with( + list_model_versions_mock[0], + "self", + headers={"Accept": "application/vnd.sas.models.model.version+json"}, + ) + + +def 
test_get_model_with_versions(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model" + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request" + ) as request: + + is_uuid.return_value = True + response = mr.get_model_with_versions(model="12345") + assert response + + is_uuid.return_value = False + get_model.return_value = None + response = mr.get_model_with_versions(model={"id": "12345"}) + assert response + + is_uuid.return_value = False + get_model.return_value = None + with pytest.raises(HTTPError): + mr.get_model_with_versions(model=RestObj) + + is_uuid.return_value = False + get_model.return_value = RestObj({"id": "123456"}) + request.side_effect = urllib.error.HTTPError( + url="http://demo.sas.com", + code=404, + msg="Not Found", + hdrs=None, + fp=None, + ) + with pytest.raises(HTTPError): + mr.get_model_with_versions(model=RestObj) + + request.side_effect = None + request.return_value = RestObj({"id": "12345"}) + response = mr.get_model_with_versions(model=RestObj) + assert isinstance(response, list) + + request.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_with_versions(model=RestObj) + assert isinstance(response, list) + + request.assert_any_call( + "GET", + "/models/123456/versions", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + request.assert_any_call( + "GET", + "/models/12345/versions", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + +def test_get_model_or_version(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_with_versions" + ) as get_model_with_versions: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_with_versions_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", 
+ "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + get_model_with_versions.return_value = [] + with pytest.raises(ValueError): + mr.get_model_or_version(model="000", version_id="000") + + get_model_with_versions.return_value = get_model_with_versions_mock + with pytest.raises(ValueError): + mr.get_model_or_version(model="000", version_id="000") + + response = mr.get_model_or_version(model="000", version_id="123") + request_link.assert_called_once_with( + get_model_with_versions_mock[0], + "self", + headers={ + "Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json" + }, + ) + + +def test_get_model_version_contents(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version" + ) as get_model_version: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_version.return_value = {"id": "000"} + request_link.return_value = RestObj({"id": "12345"}) + response = mr.get_model_version_contents(model="12345", version_id="3456") + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_version_contents(model="12345", version_id="3456") + assert isinstance(response, list) + + request_link.assert_any_call( + {"id": "000"}, + "contents", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + +def test_get_model_version_content_metadata(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version_contents" + ) as get_model_version_contents: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_with_metadata_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": 
"/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + get_model_version_contents.return_value = [] + with pytest.raises(ValueError): + mr.get_model_version_content_metadata( + model="000", version_id="123", content_id="000" + ) + + get_model_version_contents.return_value = get_model_with_metadata_mock + with pytest.raises(ValueError): + mr.get_model_version_content_metadata( + model="abc", version_id="123", content_id="000" + ) + + response = mr.get_model_version_content_metadata( + model="abc", version_id="123", content_id="345" + ) + assert response + request_link.assert_called_once_with( + get_model_with_metadata_mock[1], + "self", + headers={"Accept": "application/vnd.sas.models.model.content+json"}, + ) + + +def test_get_model_version_content(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version_content_metadata" + ) as get_model_version_content_metadata: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_version_content_metadata.return_value = {"id": 000} + request_link.return_value = None + with pytest.raises(HTTPError): + mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + + request_link.return_value = RestObj({"id": "12345"}) + response = mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + assert isinstance(response, list) + + request_link.assert_any_call( + {"id": 000}, + "content", + headers={"Accept": "text/plain"}, + ) diff --git a/tests/unit/test_score_definitions.py b/tests/unit/test_score_definitions.py index d1210866..1ebdc462 100644 --- 
a/tests/unit/test_score_definitions.py +++ b/tests/unit/test_score_definitions.py @@ -63,89 +63,190 @@ def test_create_score_definition(): "sasctl._services.cas_management.CASManagement.upload_file" ) as upload_file: with mock.patch( - "sasctl._services.score_definitions.ScoreDefinitions.post" - ) as post: - # Invalid model id test case - get_model.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - # Valid model id but invalid table name with no table_file argument test case - get_model_mock = { - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - } - get_model.return_value = get_model_mock - get_table.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - - # Invalid table name with a table_file argument that doesn't work test case - get_table.return_value = None - upload_file.return_value = None - get_table.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - - # Valid table_file argument that successfully creates a table test case - get_table.return_value = None - upload_file.return_value = RestObj - get_table_mock = {"tableName": "test_table"} - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - assert response - - # Valid table_name argument test case - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - assert response - - # Checking response with inputVariables in model 
elements - get_model_mock = { - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - "inputVariables": [ - {"name": "first"}, - {"name": "second"}, - {"name": "third"}, - ], - } - get_model.return_value = get_model_mock - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - assert response - assert post.call_count == 3 - - data = post.call_args - json_data = json.loads(data.kwargs["data"]) - assert json_data["mappings"] != [] + "sasctl._services.model_repository.ModelRepository.get_model_or_version" + ) as get_model_or_version: + with mock.patch( + "sasctl._services.score_definitions.ScoreDefinitions.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.score_definitions.ScoreDefinitions.post" + ) as post: + + # Invalid model id test case + get_model.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + # Valid model id but invalid table name with no table_file argument test case + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + } + get_model.return_value = get_model_mock + get_table.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + + # Invalid table name with a table_file argument that doesn't work test case + get_table.return_value = None + upload_file.return_value = None + get_table.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + + # Valid table_file argument that successfully creates a table test case + get_table.return_value = None + upload_file.return_value 
= RestObj + get_table_mock = {"tableName": "test_table"} + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + assert response + + # Valid table_name argument test case + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + assert response + + # Checking response with inputVariables in model elements + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + "inputVariables": [ + {"name": "first"}, + {"name": "second"}, + {"name": "third"}, + ], + } + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + assert response + assert post.call_count == 3 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert json_data["mappings"] != [] + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (latest)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (latest)" + ) + + # Model version dictionary with no model version name + with pytest.raises(ValueError): + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version={}, + ) + + # Model version as a model version name string, not UUID + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = False + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version="1.0", + ) + assert response + assert post.call_count == 4 + + data = post.call_args + json_data = 
json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) + + # Model version as a dict with modelVersionName key + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = False + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version={"modelVersionName": "1.0"}, + ) + assert response + assert post.call_count == 5 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) + + # Model version as a dictionary with model version name key + get_version_mock = { + "id": "3456", + "modelVersionName": "1.0", + } + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = True + get_model_or_version.return_value = get_version_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version="3456", + ) + assert response + assert post.call_count == 6 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) diff --git a/tests/unit/test_write_json_files.py b/tests/unit/test_write_json_files.py index b0a3c6a0..3321fc30 100644 --- a/tests/unit/test_write_json_files.py +++ b/tests/unit/test_write_json_files.py @@ -699,8 +699,9 @@ def test_create_requirements_json(change_dir): dtc = dtc.fit(x_train, y_train) with open(tmp_dir / "DecisionTreeClassifier.pickle", "wb") as pkl_file: pickle.dump(dtc, pkl_file) - jf.create_requirements_json(tmp_dir, 
Path(tmp_dir)) + jf.create_requirements_json(tmp_dir, Path(tmp_dir), True) assert (Path(tmp_dir) / "requirements.json").exists() + assert (Path(tmp_dir) / "requirements.txt").exists() json_dict = jf.create_requirements_json(tmp_dir) expected = [ @@ -709,13 +710,20 @@ def test_create_requirements_json(change_dir): "command": f"pip install numpy=={np.__version__}", }, { - "step": "install sklearn", - "command": f"pip install sklearn=={sk.__version__}", + "step": "install scikit-learn", + "command": f"pip install scikit-learn=={sk.__version__}", }, ] unittest.TestCase.maxDiff = None unittest.TestCase().assertCountEqual(json_dict, expected) + # Verify requirements.txt content + with open(Path(tmp_dir) / "requirements.txt", "r") as file: + requirements_content = [line.strip() for line in file.readlines()] + + assert f"numpy=={np.__version__}" in requirements_content + assert f"scikit-learn=={sk.__version__}" in requirements_content + class TestAssessBiasHelpers(unittest.TestCase): md_1 = pd.DataFrame({"Value": [0], "Base": ["A"], "Compare": ["C"]})