From 0a2f0cae9293366c079ac5327d7925fdd9bc6728 Mon Sep 17 00:00:00 2001 From: Jordan Woods <13803242+jorwoods@users.noreply.github.com> Date: Sun, 20 Apr 2025 11:08:40 -0500 Subject: [PATCH] chore: type hint database and table objects --- tableauserverclient/models/datasource_item.py | 4 +- tableauserverclient/models/flow_item.py | 4 +- tableauserverclient/models/table_item.py | 10 +- tableauserverclient/models/tableau_types.py | 14 +- .../server/endpoint/databases_endpoint.py | 119 +++++++++++-- .../server/endpoint/datasources_endpoint.py | 6 +- .../server/endpoint/dqw_endpoint.py | 22 ++- .../server/endpoint/tables_endpoint.py | 157 ++++++++++++++++-- 8 files changed, 284 insertions(+), 52 deletions(-) diff --git a/tableauserverclient/models/datasource_item.py b/tableauserverclient/models/datasource_item.py index 2005edf7e..7aab5fbaf 100644 --- a/tableauserverclient/models/datasource_item.py +++ b/tableauserverclient/models/datasource_item.py @@ -280,8 +280,8 @@ def _set_connections(self, connections) -> None: def _set_permissions(self, permissions): self._permissions = permissions - def _set_data_quality_warnings(self, dqws): - self._data_quality_warnings = dqws + def _set_data_quality_warnings(self, dqw): + self._data_quality_warnings = dqw def _set_revisions(self, revisions): self._revisions = revisions diff --git a/tableauserverclient/models/flow_item.py b/tableauserverclient/models/flow_item.py index 0083776bb..063897e41 100644 --- a/tableauserverclient/models/flow_item.py +++ b/tableauserverclient/models/flow_item.py @@ -146,8 +146,8 @@ def _set_connections(self, connections): def _set_permissions(self, permissions): self._permissions = permissions - def _set_data_quality_warnings(self, dqws): - self._data_quality_warnings = dqws + def _set_data_quality_warnings(self, dqw): + self._data_quality_warnings = dqw def _parse_common_elements(self, flow_xml, ns): if not isinstance(flow_xml, ET.Element): diff --git a/tableauserverclient/models/table_item.py b/tableauserverclient/models/table_item.py index 0afdd4df3..541f84360 100644 --- a/tableauserverclient/models/table_item.py +++ b/tableauserverclient/models/table_item.py @@ -1,8 +1,12 @@ +from typing import Callable, Optional, TYPE_CHECKING from defusedxml.ElementTree import fromstring from .exceptions import UnpopulatedPropertyError from .property_decorators import property_not_empty, property_is_boolean +if TYPE_CHECKING: + from tableauserverclient.models import DQWItem + class TableItem: def __init__(self, name, description=None): @@ -40,7 +44,7 @@ def dqws(self): return self._data_quality_warnings() @property - def id(self): + def id(self) -> Optional[str]: return self._id @property @@ -100,8 +104,8 @@ def columns(self): def _set_columns(self, columns): self._columns = columns - def _set_data_quality_warnings(self, dqws): - self._data_quality_warnings = dqws + def _set_data_quality_warnings(self, dqw: Callable[[], list["DQWItem"]]) -> None: + self._data_quality_warnings = dqw def _set_values(self, table_values): if "id" in table_values: diff --git a/tableauserverclient/models/tableau_types.py b/tableauserverclient/models/tableau_types.py index 01ee3d3a9..e69d02a06 100644 --- a/tableauserverclient/models/tableau_types.py +++ b/tableauserverclient/models/tableau_types.py @@ -1,8 +1,10 @@ from typing import Union +from tableauserverclient.models.database_item import DatabaseItem from tableauserverclient.models.datasource_item import DatasourceItem from tableauserverclient.models.flow_item import FlowItem from tableauserverclient.models.project_item import ProjectItem +from tableauserverclient.models.table_item import TableItem from tableauserverclient.models.view_item import ViewItem from tableauserverclient.models.workbook_item import WorkbookItem from tableauserverclient.models.metric_item import MetricItem @@ -25,7 +27,17 @@ class Resource: # resource types that have permissions, can be renamed, etc # todo: refactoring: should actually define TableauItem as an interface and let all these implement it -TableauItem = Union[DatasourceItem, FlowItem, MetricItem, ProjectItem, ViewItem, WorkbookItem, VirtualConnectionItem] +TableauItem = Union[ + DatasourceItem, + FlowItem, + MetricItem, + ProjectItem, + ViewItem, + WorkbookItem, + VirtualConnectionItem, + DatabaseItem, + TableItem, +] def plural_type(content_type: Union[Resource, str]) -> str: diff --git a/tableauserverclient/server/endpoint/databases_endpoint.py b/tableauserverclient/server/endpoint/databases_endpoint.py index c0e106eb2..dc88ceaa5 100644 --- a/tableauserverclient/server/endpoint/databases_endpoint.py +++ b/tableauserverclient/server/endpoint/databases_endpoint.py @@ -1,7 +1,8 @@ import logging -from typing import Union +from typing import TYPE_CHECKING, Optional, Union from collections.abc import Iterable +from tableauserverclient.models.permissions_item import PermissionsRule from tableauserverclient.server.endpoint.default_permissions_endpoint import _DefaultPermissionsEndpoint from tableauserverclient.server.endpoint.dqw_endpoint import _DataQualityWarningEndpoint from tableauserverclient.server.endpoint.endpoint import api, Endpoint @@ -13,6 +14,10 @@ from tableauserverclient.helpers.logging import logger +if TYPE_CHECKING: + from tableauserverclient.models.dqw_item import DQWItem + from tableauserverclient.server.request_options import RequestOptions + class Databases(Endpoint, TaggingMixin): def __init__(self, parent_srv): @@ -23,11 +28,29 @@ def __init__(self, parent_srv): self._data_quality_warnings = _DataQualityWarningEndpoint(parent_srv, Resource.Database) @property - def baseurl(self): + def baseurl(self) -> str: return f"{self.parent_srv.baseurl}/sites/{self.parent_srv.site_id}/databases" @api(version="3.5") - def get(self, req_options=None): + def get(self, req_options: Optional["RequestOptions"] = None) -> tuple[list[DatabaseItem], PaginationItem]: + """ + Get information about all databases on the site. Endpoint is paginated, + and will return a default of 100 items per page. Use the `req_options` + parameter to customize the request. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#query_databases + + Parameters + ---------- + req_options : RequestOptions, optional + Options to customize the request. If not provided, defaults to None. + + Returns + ------- + tuple[list[DatabaseItem], PaginationItem] + A tuple containing a list of DatabaseItem objects and a + PaginationItem object. + """ logger.info("Querying all databases on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -37,7 +60,27 @@ def get(self, req_options=None): # Get 1 database @api(version="3.5") - def get_by_id(self, database_id): + def get_by_id(self, database_id: str) -> DatabaseItem: + """ + Get information about a single database asset on the site. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#query_database + + Parameters + ---------- + database_id : str + The ID of the database to retrieve. + + Returns + ------- + DatabaseItem + A DatabaseItem object representing the database. + + Raises + ------ + ValueError + If the database ID is undefined. + """ if not database_id: error = "database ID undefined." raise ValueError(error) @@ -47,7 +90,24 @@ def get_by_id(self, database_id): return DatabaseItem.from_response(server_response.content, self.parent_srv.namespace)[0] @api(version="3.5") - def delete(self, database_id): + def delete(self, database_id: str) -> None: + """ + Deletes a single database asset from the server. + + Parameters + ---------- + database_id : str + The ID of the database to delete. + + Returns + ------- + None + + Raises + ------ + ValueError + If the database ID is undefined. + """ if not database_id: error = "Database ID undefined." raise ValueError(error) @@ -56,7 +116,28 @@ def delete(self, database_id): logger.info(f"Deleted single database (ID: {database_id})") @api(version="3.5") - def update(self, database_item): + def update(self, database_item: DatabaseItem) -> DatabaseItem: + """ + Update the database description, certify the database, set permissions, + or assign a User as the database contact. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#update_database + + Parameters + ---------- + database_item : DatabaseItem + The DatabaseItem object to update. + + Returns + ------- + DatabaseItem + The updated DatabaseItem object. + + Raises + ------ + MissingRequiredFieldError + If the database item is missing an ID. + """ if not database_item.id: error = "Database item missing ID." raise MissingRequiredFieldError(error) @@ -88,43 +169,45 @@ def _get_tables_for_database(self, database_item): return tables @api(version="3.5") - def populate_permissions(self, item): + def populate_permissions(self, item: DatabaseItem) -> None: self._permissions.populate(item) @api(version="3.5") - def update_permissions(self, item, rules): + def update_permissions(self, item: DatabaseItem, rules: list[PermissionsRule]) -> list[PermissionsRule]: return self._permissions.update(item, rules) @api(version="3.5") - def delete_permission(self, item, rules): + def delete_permission(self, item: DatabaseItem, rules: list[PermissionsRule]) -> None: self._permissions.delete(item, rules) @api(version="3.5") - def populate_table_default_permissions(self, item): + def populate_table_default_permissions(self, item: DatabaseItem): self._default_permissions.populate_default_permissions(item, Resource.Table) @api(version="3.5") - def update_table_default_permissions(self, item): - return self._default_permissions.update_default_permissions(item, Resource.Table) + def update_table_default_permissions( + self, item: DatabaseItem, rules: list[PermissionsRule] + ) -> list[PermissionsRule]: + return self._default_permissions.update_default_permissions(item, rules, Resource.Table) @api(version="3.5") - def delete_table_default_permissions(self, item): - self._default_permissions.delete_default_permission(item, Resource.Table) + def delete_table_default_permissions(self, rule: PermissionsRule, item: DatabaseItem) -> None: + self._default_permissions.delete_default_permission(item, rule, Resource.Table) @api(version="3.5") - def populate_dqw(self, item): + def populate_dqw(self, item: DatabaseItem) -> None: self._data_quality_warnings.populate(item) @api(version="3.5") - def update_dqw(self, item, warning): + def update_dqw(self, item: DatabaseItem, warning: "DQWItem") -> list["DQWItem"]: return self._data_quality_warnings.update(item, warning) @api(version="3.5") - def add_dqw(self, item, warning): + def add_dqw(self, item: DatabaseItem, warning: "DQWItem") -> list["DQWItem"]: return self._data_quality_warnings.add(item, warning) @api(version="3.5") - def delete_dqw(self, item): + def delete_dqw(self, item: DatabaseItem) -> None: self._data_quality_warnings.clear(item) @api(version="3.9") diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index 69913a724..168446974 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -733,7 +733,7 @@ def populate_dqw(self, item) -> None: self._data_quality_warnings.populate(item) @api(version="3.5") - def update_dqw(self, item, warning): + def update_dqw(self, item: DatasourceItem, warning: "DQWItem") -> list["DQWItem"]: """ Update the warning type, status, and message of a data quality warning. @@ -755,7 +755,7 @@ def update_dqw(self, item, warning): return self._data_quality_warnings.update(item, warning) @api(version="3.5") - def add_dqw(self, item, warning): + def add_dqw(self, item: DatasourceItem, warning: "DQWItem") -> list["DQWItem"]: """ Add a data quality warning to a datasource. @@ -786,7 +786,7 @@ def add_dqw(self, item, warning): return self._data_quality_warnings.add(item, warning) @api(version="3.5") - def delete_dqw(self, item): + def delete_dqw(self, item: DatasourceItem) -> None: """ Delete a data quality warnings from an asset. diff --git a/tableauserverclient/server/endpoint/dqw_endpoint.py b/tableauserverclient/server/endpoint/dqw_endpoint.py index 90e31483b..d2ad517ee 100644 --- a/tableauserverclient/server/endpoint/dqw_endpoint.py +++ b/tableauserverclient/server/endpoint/dqw_endpoint.py @@ -1,4 +1,5 @@ import logging +from typing import Callable, Optional, Protocol, TYPE_CHECKING from .endpoint import Endpoint from .exceptions import MissingRequiredFieldError @@ -7,6 +8,15 @@ from tableauserverclient.helpers.logging import logger +if TYPE_CHECKING: + from tableauserverclient.server.request_options import RequestOptions + + +class HasId(Protocol): + @property + def id(self) -> Optional[str]: ... + def _set_data_quality_warnings(self, dqw: Callable[[], list[DQWItem]]): ... + class _DataQualityWarningEndpoint(Endpoint): def __init__(self, parent_srv, resource_type): @@ -14,12 +24,12 @@ def __init__(self, parent_srv, resource_type): self.resource_type = resource_type @property - def baseurl(self): + def baseurl(self) -> str: return "{}/sites/{}/dataQualityWarnings/{}".format( self.parent_srv.baseurl, self.parent_srv.site_id, self.resource_type ) - def add(self, resource, warning): + def add(self, resource: HasId, warning: DQWItem) -> list[DQWItem]: url = f"{self.baseurl}/{resource.id}" add_req = RequestFactory.DQW.add_req(warning) response = self.post_request(url, add_req) @@ -28,7 +38,7 @@ def add(self, resource, warning): return warnings - def update(self, resource, warning): + def update(self, resource: HasId, warning: DQWItem) -> list[DQWItem]: url = f"{self.baseurl}/{resource.id}" add_req = RequestFactory.DQW.update_req(warning) response = self.put_request(url, add_req) @@ -37,11 +47,11 @@ def update(self, resource, warning): return warnings - def clear(self, resource): + def clear(self, resource: HasId) -> None: url = f"{self.baseurl}/{resource.id}" return self.delete_request(url) - def populate(self, item): + def populate(self, item: HasId) -> None: if not item.id: error = "Server item is missing ID. Item must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -52,7 +62,7 @@ def dqw_fetcher(): item._set_data_quality_warnings(dqw_fetcher) logger.info(f"Populated permissions for item (ID: {item.id})") - def _get_data_quality_warnings(self, item, req_options=None): + def _get_data_quality_warnings(self, item: HasId, req_options: Optional["RequestOptions"] = None) -> list[DQWItem]: url = f"{self.baseurl}/{item.id}" server_response = self.get_request(url, req_options) dqws = DQWItem.from_response(server_response.content, self.parent_srv.namespace) diff --git a/tableauserverclient/server/endpoint/tables_endpoint.py b/tableauserverclient/server/endpoint/tables_endpoint.py index 120d3ba9c..ad80e7d0e 100644 --- a/tableauserverclient/server/endpoint/tables_endpoint.py +++ b/tableauserverclient/server/endpoint/tables_endpoint.py @@ -1,7 +1,8 @@ import logging -from typing import Union +from typing import Optional, Union, TYPE_CHECKING from collections.abc import Iterable +from tableauserverclient.models.permissions_item import PermissionsRule from tableauserverclient.server.endpoint.dqw_endpoint import _DataQualityWarningEndpoint from tableauserverclient.server.endpoint.endpoint import api, Endpoint from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError @@ -12,6 +13,10 @@ from tableauserverclient.server.pager import Pager from tableauserverclient.helpers.logging import logger +from tableauserverclient.server.request_options import RequestOptions + +if TYPE_CHECKING: + from tableauserverclient.models import DQWItem, PermissionsRule class Tables(Endpoint, TaggingMixin[TableItem]): @@ -22,11 +27,29 @@ def __init__(self, parent_srv): self._data_quality_warnings = _DataQualityWarningEndpoint(self.parent_srv, "table") @property - def baseurl(self): + def baseurl(self) -> str: return f"{self.parent_srv.baseurl}/sites/{self.parent_srv.site_id}/tables" @api(version="3.5") - def get(self, req_options=None): + def get(self, req_options: Optional[RequestOptions] = None) -> tuple[list[TableItem], PaginationItem]: + """ + Get information about all tables on the site. Endpoint is paginated, and + will return a default of 100 items per page. Use the `req_options` + parameter to customize the request. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#query_tables + + Parameters + ---------- + req_options : RequestOptions, optional + Options to customize the request. If not provided, defaults to None. + + Returns + ------- + tuple[list[TableItem], PaginationItem] + A tuple containing a list of TableItem objects and a PaginationItem + object. + """ logger.info("Querying all tables on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -36,7 +59,27 @@ def get(self, req_options=None): # Get 1 table @api(version="3.5") - def get_by_id(self, table_id): + def get_by_id(self, table_id: str) -> TableItem: + """ + Get information about a single table on the site. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#query_table + + Parameters + ---------- + table_id : str + The ID of the table to retrieve. + + Returns + ------- + TableItem + A TableItem object representing the table. + + Raises + ------ + ValueError + If the table ID is not provided. + """ if not table_id: error = "table ID undefined." raise ValueError(error) @@ -46,7 +89,24 @@ def get_by_id(self, table_id): return TableItem.from_response(server_response.content, self.parent_srv.namespace)[0] @api(version="3.5") - def delete(self, table_id): + def delete(self, table_id: str) -> None: + """ + Delete a single table from the server. + + Parameters + ---------- + table_id : str + The ID of the table to delete. + + Returns + ------- + None + + Raises + ------ + ValueError + If the table ID is not provided. + """ if not table_id: error = "Database ID undefined." raise ValueError(error) @@ -55,7 +115,27 @@ def delete(self, table_id): logger.info(f"Deleted single table (ID: {table_id})") @api(version="3.5") - def update(self, table_item): + def update(self, table_item: TableItem) -> TableItem: + """ + Update a table on the server. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#update_table + + Parameters + ---------- + table_item : TableItem + The TableItem object to update. + + Returns + ------- + TableItem + The updated TableItem object. + + Raises + ------ + MissingRequiredFieldError + If the table item is missing an ID. + """ if not table_item.id: error = "table item missing ID." raise MissingRequiredFieldError(error) @@ -69,21 +149,46 @@ def update(self, table_item): # Get all columns of the table @api(version="3.5") - def populate_columns(self, table_item, req_options=None): + def populate_columns(self, table_item: TableItem, req_options: Optional[RequestOptions] = None) -> None: + """ + Populate the columns of a table item. Sets a fetcher function to + retrieve the columns when needed. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#query_columns + + Parameters + ---------- + table_item : TableItem + The TableItem object to populate columns for. + + req_options : RequestOptions, optional + Options to customize the request. If not provided, defaults to None. + + Returns + ------- + None + + Raises + ------ + MissingRequiredFieldError + If the table item is missing an ID. + """ if not table_item.id: error = "Table item missing ID. table must be retrieved from server first." raise MissingRequiredFieldError(error) def column_fetcher(): return Pager( - lambda options: self._get_columns_for_table(table_item, options), + lambda options: self._get_columns_for_table(table_item, options), # type: ignore req_options, ) table_item._set_columns(column_fetcher) logger.info(f"Populated columns for table (ID: {table_item.id}") - def _get_columns_for_table(self, table_item, req_options=None): + def _get_columns_for_table( + self, table_item: TableItem, req_options: Optional[RequestOptions] = None + ) -> tuple[list[ColumnItem], PaginationItem]: url = f"{self.baseurl}/{table_item.id}/columns" server_response = self.get_request(url, req_options) columns = ColumnItem.from_response(server_response.content, self.parent_srv.namespace) @@ -91,7 +196,25 @@ def _get_columns_for_table(self, table_item, req_options=None): return columns, pagination_item @api(version="3.5") - def update_column(self, table_item, column_item): + def update_column(self, table_item: TableItem, column_item: ColumnItem) -> ColumnItem: + """ + Update the description of a column in a table. + + REST API: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_metadata.htm#update_column + + Parameters + ---------- + table_item : TableItem + The TableItem object representing the table. + + column_item : ColumnItem + The ColumnItem object representing the column to update. + + Returns + ------- + ColumnItem + The updated ColumnItem object. + """ url = f"{self.baseurl}/{table_item.id}/columns/{column_item.id}" update_req = RequestFactory.Column.update_req(column_item) server_response = self.put_request(url, update_req) @@ -101,31 +224,31 @@ def update_column(self, table_item, column_item): return column @api(version="3.5") - def populate_permissions(self, item): + def populate_permissions(self, item: TableItem) -> None: self._permissions.populate(item) @api(version="3.5") - def update_permissions(self, item, rules): + def update_permissions(self, item: TableItem, rules: list[PermissionsRule]) -> list[PermissionsRule]: return self._permissions.update(item, rules) @api(version="3.5") - def delete_permission(self, item, rules): + def delete_permission(self, item: TableItem, rules: list[PermissionsRule]) -> None: return self._permissions.delete(item, rules) @api(version="3.5") - def populate_dqw(self, item): + def populate_dqw(self, item: TableItem) -> None: self._data_quality_warnings.populate(item) @api(version="3.5") - def update_dqw(self, item, warning): + def update_dqw(self, item: TableItem, warning: "DQWItem") -> list["DQWItem"]: return self._data_quality_warnings.update(item, warning) @api(version="3.5") - def add_dqw(self, item, warning): + def add_dqw(self, item: TableItem, warning: "DQWItem") -> list["DQWItem"]: return self._data_quality_warnings.add(item, warning) @api(version="3.5") - def delete_dqw(self, item): + def delete_dqw(self, item: TableItem) -> None: self._data_quality_warnings.clear(item) @api(version="3.9")