8000 Jorwoods/type hint flows by jorwoods · Pull Request #937 · tableau/server-client-python · GitHub
[go: up one dir, main page]

Skip to content

Jorwoods/type hint flows #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 28, 2022
Merged
54 changes: 31 additions & 23 deletions tableauserverclient/models/flow_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,28 @@
from ..datetime_helpers import parse_datetime
import copy

from typing import List, Optional, TYPE_CHECKING, Set

if TYPE_CHECKING:
import datetime
from .connection_item import ConnectionItem
from .permissions_item import Permission
from .dqw_item import DQWItem


class FlowItem(object):
def __init__(self, project_id, name=None):
self._webpage_url = None
self._created_at = None
self._id = None
self._initial_tags = set()
self._project_name = None
self._updated_at = None
self.name = name
self.owner_id = None
self.project_id = project_id
self.tags = set()
self.description = None
def __init__(self, project_id: str, name: Optional[str] = None) -> None:
self._webpage_url: Optional[str] = None
self._created_at: Optional["datetime.datetime"] = None
self._id: Optional[str] = None
self._initial_tags: Set[str] = set()
self._project_name: Optional[str] = None
self._updated_at: Optional["datetime.datetime"] = None
self.name: Optional[str] = name
self.owner_id: Optional[str] = None
self.project_id: str = project_id
self.tags: Set[str] = set()
self.description: Optional[str] = None

self._connections = None
self._permissions = None
Expand All @@ -39,11 +47,11 @@ def permissions(self):
return self._permissions()

@property
def webpage_url(self):
def webpage_url(self) -> Optional[str]:
return self._webpage_url

@property
def created_at(self):
def created_at(self) -> Optional["datetime.datetime"]:
return self._created_at

@property
Expand All @@ -54,36 +62,36 @@ def dqws(self):
return self._data_quality_warnings()

@property
def id(self):
def id(self) -> Optional[str]:
return self._id

@property
def project_id(self):
def project_id(self) -> str:
return self._project_id

@project_id.setter
@property_not_nullable
def project_id(self, value):
def project_id(self, value: str) -> None:
self._project_id = value

@property
def description(self):
def description(self) -> Optional[str]:
return self._description

@description.setter
def description(self, value):
def description(self, value: str) -> None:
self._description = value

@property
def project_name(self):
def project_name(self) -> Optional[str]:
return self._project_name

@property
def flow_type(self):
def flow_type(self): # What is this? It doesn't seem to get set anywhere.
return self._flow_type

@property
def updated_at(self):
def updated_at(self) -> Optional["datetime.datetime"]:
return self._updated_at

def _set_connections(self, connections):
Expand Down Expand Up @@ -161,7 +169,7 @@ def _set_values(
self.owner_id = owner_id

@classmethod
def from_response(cls, resp, ns):
def from_response(cls, resp, ns) -> List["FlowItem"]:
all_flow_items = list()
parsed_response = ET.fromstring(resp)
all_flow_xml = parsed_response.findall(".//t:flow", namespaces=ns)
Expand Down
53 changes: 30 additions & 23 deletions tableauserverclient/server/endpoint/flows_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,23 @@
import cgi
from contextlib import closing

from typing import Iterable, List, Optional, TYPE_CHECKING, Tuple, Union

# The maximum size of a file that can be published in a single request is 64MB
FILESIZE_LIMIT = 1024 * 1024 * 64 # 64MB

ALLOWED_FILE_EXTENSIONS = ["tfl", "tflx"]

logger = logging.getLogger("tableau.endpoint.flows")

if TYPE_CHECKING:
from .. import DQWItem
from ..request_options import RequestOptions
from ...models.permissions_item import Permission, PermissionsRule


FilePath = Union[str, os.PathLike]


class Flows(Endpoint):
def __init__(self, parent_srv):
Expand All @@ -29,12 +39,12 @@ def __init__(self, parent_srv):
self._data_quality_warnings = _DataQualityWarningEndpoint(self.parent_srv, "flow")

@property
def baseurl(self):
def baseurl(self) -> str:
return "{0}/sites/{1}/flows".format(self.parent_srv.baseurl, self.parent_srv.site_id)

# Get all flows
@api(version="3.3")
def get(self, req_options=None):
def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[FlowItem], PaginationItem]:
logger.info("Querying all flows on site")
url = self.baseurl
server_response = self.get_request(url, req_options)
Expand All @@ -44,7 +54,7 @@ def get(self, req_options=None):

# Get 1 flow by id
@api(version="3.3")
def get_by_id(self, flow_id):
def get_by_id(self, flow_id: str) -> FlowItem:
if not flow_id:
error = "Flow ID undefined."
raise ValueError(error)
Expand All @@ -55,7 +65,7 @@ def get_by_id(self, flow_id):

# Populate flow item's connections
@api(version="3.3")
def populate_connections(self, flow_item):
def populate_connections(self, flow_item: FlowItem) -> None:
if not flow_item.id:
error = "Flow item missing ID. Flow must be retrieved from server first."
raise MissingRequiredFieldError(error)
Expand All @@ -66,15 +76,15 @@ def connections_fetcher():
flow_item._set_connections(connections_fetcher)
logger.info("Populated connections for flow (ID: {0})".format(flow_item.id))

def _get_flow_connections(self, flow_item, req_options=None):
def _get_flow_connections(self, flow_item, req_options: Optional["RequestOptions"] = None) -> List[ConnectionItem]:
url = "{0}/{1}/connections".format(self.baseurl, flow_item.id)
server_response = self.get_request(url, req_options)
connections = ConnectionItem.from_response(server_response.content, self.parent_srv.namespace)
return connections

# Delete 1 flow by id
@api(version="3.3")
def delete(self, flow_id):
def delete(self, flow_id: str) -> None:
if not flow_id:
error = "Flow ID undefined."
raise ValueError(error)
Expand All @@ -84,7 +94,7 @@ def delete(self, flow_id):

# Download 1 flow by id
@api(version="3.3")
def download(self, flow_id, filepath=None):
def download(self, flow_id: str, filepath: FilePath = None) -> str:
if not flow_id:
error = "Flow ID undefined."
raise ValueError(error)
Expand All @@ -105,7 +115,7 @@ def download(self, flow_id, filepath=None):

# Update flow
@api(version="3.3")
def update(self, flow_item):
def update(self, flow_item: FlowItem) -> FlowItem:
if not flow_item.id:
error = "Flow item missing ID. Flow must be retrieved from server first."
raise MissingRequiredFieldError(error)
Expand All @@ -122,7 +132,7 @@ def update(self, flow_item):

# Update flow connections
@api(version="3.3")
def update_connection(self, flow_item, connection_item):
def update_connection(self, flow_item: FlowItem, connection_item: ConnectionItem) -> ConnectionItem:
url = "{0}/{1}/connections/{2}".format(self.baseurl, flow_item.id, connection_item.id)

update_req = RequestFactory.Connection.update_req(connection_item)
Expand All @@ -133,7 +143,7 @@ def update_connection(self, flow_item, connection_item):
return connection

@api(version=&quo 57AE t;3.3")
def refresh(self, flow_item):
def refresh(self, flow_item: FlowItem) -> JobItem:
url = "{0}/{1}/run".format(self.baseurl, flow_item.id)
empty_req = RequestFactory.Empty.empty_req()
server_response = self.post_request(url, empty_req)
Expand All @@ -142,7 +152,9 @@ def refresh(self, flow_item):

# Publish flow
@api(version="3.3")
def publish(self, flow_item, file_path, mode, connections=None):
def publish(
self, flow_item: FlowItem, file_path: FilePath, mode: str, connections: Optional[List[ConnectionItem]] = None
) -> FlowItem:
if not os.path.isfile(file_path):
error = "File path does not lead to an existing file."
raise IOError(error)
Expand Down Expand Up @@ -189,13 +201,8 @@ def publish(self, flow_item, file_path, mode, connections=None):
logger.info("Published {0} (ID: {1})".format(filename, new_flow.id))
return new_flow

server_response = self.post_request(url, xml_request, content_type)
new_flow = FlowItem.from_response(server_response.content, self.parent_srv.namespace)[0]
logger.info("Published {0} (ID: {1})".format(filename, new_flow.id))
return new_flow

@api(version="3.3")
def populate_permissions(self, item):
def populate_permissions(self, item: FlowItem) -> None:
self._permissions.populate(item)

@api(version="3.3")
Expand All @@ -209,25 +216,25 @@ def update_permission(self, item, permission_item):
self._permissions.update(item, permission_item)

@api(version="3.3")
def update_permissions(self, item, permission_item):
def update_permissions(self, item: FlowItem, permission_item: Iterable["PermissionsRule"]) -> None:
self._permissions.update(item, permission_item)

@api(version="3.3")
def delete_permission(self, item, capability_item):
def delete_permission(self, item: FlowItem, capability_item: "PermissionsRule") -> None:
self._permissions.delete(item, capability_item)

@api(version="3.5")
def populate_dqw(self, item):
def populate_dqw(self, item: FlowItem) -> None:
self._data_quality_warnings.populate(item)

@api(version="3.5")
def update_dqw(self, item, warning):
def update_dqw(self, item: FlowItem, warning: "DQWItem") -> None:
return self._data_quality_warnings.update(item, warning)

@api(version="3.5")
def add_dqw(self, item, warning):
def add_dqw(self, item: FlowItem, warning: "DQWItem") -> None:
return self._data_quality_warnings.add(item, warning)

@api(version="3.5")
def delete_dqw(self, item):
def delete_dqw(self, item: FlowItem) -> None:
self._data_quality_warnings.clear(item)
23 changes: 16 additions & 7 deletions tableauserverclient/server/request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

from ..models import TaskItem, UserItem, GroupItem, PermissionsRule, FavoriteItem

from typing import Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple

if TYPE_CHECKING:
from ..models import DataAlertItem
from ..models import FlowItem
from ..models import ConnectionItem


def _add_multipart(parts):
def _add_multipart(parts: Dict) -> Tuple[Any, str]:
mime_multipart_parts = list()
for name, (filename, data, content_type) in parts.items():
multipart_part = RequestField(name=name, data=data, filename=filename)
Expand Down Expand Up @@ -302,10 +304,11 @@ def chunk_req(self, chunk):


class FlowRequest(object):
def _generate_xml(self, flow_item, connections=None):
def _generate_xml(self, flow_item: "FlowItem", connections: Optional[List["ConnectionItem"]] = None) -> bytes:
xml_request = ET.Element("tsRequest")
flow_element = ET.SubElement(xml_request, "flow")
flow_element.attrib["name"] = flow_item.name
if flow_item.name is not None:
flow_element.attrib["name"] = flow_item.name
project_element = ET.SubElement(flow_element, "project")
project_element.attrib["id"] = flow_item.project_id

Expand All @@ -315,7 +318,7 @@ def _generate_xml(self, flow_item, connections=None):
_add_connections_element(connections_element, connection)
return ET.tostring(xml_request)

def update_req(self, flow_item):
def update_req(self, flow_item: "FlowItem") -> bytes:
xml_request = ET.Element("tsRequest")
flow_element = ET.SubElement(xml_request, "flow")
if flow_item.project_id:
Expand All @@ -327,7 +330,13 @@ def update_req(self, flow_item):

return ET.tostring(xml_request)

def publish_req(self, flow_item, filename, file_contents, connections=None):
def publish_req(
self,
flow_item: "FlowItem",
filename: str,
file_contents: bytes,
connections: Optional[List["ConnectionItem"]] = None,
) -> Tuple[Any, str]:
xml_request = self._generate_xml(flow_item, connections)

parts = {
Expand All @@ -336,7 +345,7 @@ def publish_req(self, flow_item, filename, file_contents, connections=None):
}
return _add_multipart(parts)

def publish_req_chunked(self, flow_item, connections=None):
def publish_req_chunked(self, flow_item, connections=None) -> Tuple[Any, str]:
xml_request = self._generate_xml(flow_item, connections)

parts = {"request_payload": ("", xml_request, "text/xml")}
Expand Down
10 changes: 5 additions & 5 deletions test/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class FlowTests(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.server = TSC.Server("http://test")

# Fake signin
Expand All @@ -26,7 +26,7 @@ def setUp(self):

self.baseurl = self.server.flows.baseurl

def test_get(self):
def test_get(self) -> None:
response_xml = read_xml_asset(GET_XML)
with requests_mock.mock() as m:
m.get(self.baseurl, text=response_xml)
Expand All @@ -53,7 +53,7 @@ def test_get(self):
self.assertEqual("aa23f4ac-906f-11e9-86fb-3f0f71412e77", all_flows[1].project_id)
self.assertEqual("9127d03f-d996-405f-b392-631b25183a0f", all_flows[1].owner_id)

def test_update(self):
def test_update(self) -> None:
response_xml = read_xml_asset(UPDATE_XML)
with requests_mock.mock() as m:
m.put(self.baseurl + "/587daa37-b84d-4400-a9a2-aa90e0be7837", text=response_xml)
Expand All @@ -68,7 +68,7 @@ def test_update(self):
self.assertEqual("7ebb3f20-0fd2-4f27-a2f6-c539470999e2", single_datasource.owner_id)
self.assertEqual("So fun to see", single_datasource.description)

def test_populate_connections(self):
def test_populate_connections(self) -> None:
response_xml = read_xml_asset(POPULATE_CONNECTIONS_XML)
with requests_mock.mock() as m:
m.get(self.baseurl + "/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/connections", text=response_xml)
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_populate_connections(self):
self.assertEqual("sally", conn3.username)
self.assertEqual(True, conn3.embed_password)

def test_populate_permissions(self):
def test_populate_permissions(self) -> None:
with open(asset(POPULATE_PERMISSIONS_XML), "rb") as f:
response_xml = f.read().decode("utf-8")
with requests_mock.mock() as m:
Expand Down
0