diff --git a/tableauserverclient/models/group_item.py b/tableauserverclient/models/group_item.py index fdc06604b..08d0be4a0 100644 --- a/tableauserverclient/models/group_item.py +++ b/tableauserverclient/models/group_item.py @@ -4,78 +4,83 @@ from .reference_item import ResourceReference from .user_item import UserItem +from typing import Callable, List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ..server import Pager + class GroupItem(object): - tag_name = "group" + tag_name: str = "group" class LicenseMode: - onLogin = "onLogin" - onSync = "onSync" + onLogin: str = "onLogin" + onSync: str = "onSync" - def __init__(self, name=None, domain_name=None): - self._id = None - self._license_mode = None - self._minimum_site_role = None - self._users = None - self.name = name - self.domain_name = domain_name + def __init__(self, name=None, domain_name=None) -> None: + self._id: Optional[str] = None + self._license_mode: Optional[str] = None + self._minimum_site_role: Optional[str] = None + self._users: Optional[Callable[..., "Pager"]] = None + self.name: Optional[str] = name + self.domain_name: Optional[str] = domain_name @property - def domain_name(self): + def domain_name(self) -> Optional[str]: return self._domain_name @domain_name.setter - def domain_name(self, value): + def domain_name(self, value: str) -> None: self._domain_name = value @property - def id(self): + def id(self) -> Optional[str]: return self._id @property - def name(self): + def name(self) -> Optional[str]: return self._name @name.setter @property_not_empty - def name(self, value): + def name(self, value: str) -> None: self._name = value @property - def license_mode(self): + def license_mode(self) -> Optional[str]: return self._license_mode @license_mode.setter @property_is_enum(LicenseMode) - def license_mode(self, value): + def license_mode(self, value: str) -> None: self._license_mode = value @property - def minimum_site_role(self): + def minimum_site_role(self) -> Optional[str]: return self._minimum_site_role @minimum_site_role.setter @property_is_enum(UserItem.Roles) - def minimum_site_role(self, value): + def minimum_site_role(self, value: str) -> None: self._minimum_site_role = value @property - def users(self): + def users(self) -> "Pager": if self._users is None: error = "Group must be populated with users first." raise UnpopulatedPropertyError(error) # Each call to `.users` should create a new pager, this just runs the callable return self._users() - def to_reference(self): + def to_reference(self) -> ResourceReference: return ResourceReference(id_=self.id, tag_name=self.tag_name) - def _set_users(self, users): + def _set_users(self, users: Callable[..., "Pager"]) -> None: self._users = users @classmethod - def from_response(cls, resp, ns): + def from_response(cls, resp, ns) -> List["GroupItem"]: all_group_items = list() parsed_response = ET.fromstring(resp) all_group_xml = parsed_response.findall(".//t:group", namespaces=ns) @@ -100,5 +105,5 @@ def from_response(cls, resp, ns): return all_group_items @staticmethod - def as_reference(id_): + def as_reference(id_: str) -> ResourceReference: return ResourceReference(id_, GroupItem.tag_name) diff --git a/tableauserverclient/server/endpoint/groups_endpoint.py b/tableauserverclient/server/endpoint/groups_endpoint.py index b771e56d8..3f4fe827e 100644 --- a/tableauserverclient/server/endpoint/groups_endpoint.py +++ b/tableauserverclient/server/endpoint/groups_endpoint.py @@ -7,15 +7,20 @@ logger = logging.getLogger("tableau.endpoint.groups") +from typing import List, Optional, TYPE_CHECKING, Tuple, Union + +if TYPE_CHECKING: + from ..request_options import RequestOptions + class Groups(Endpoint): @property - def baseurl(self): + def baseurl(self) -> str: return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id) # Gets all groups @api(version="2.0") - def get(self, req_options=None): + def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[GroupItem], PaginationItem]: logger.info("Querying all groups on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -25,7 +30,7 @@ def get(self, req_options=None): # Gets all users in a given group @api(version="2.0") - def populate_users(self, group_item, req_options=None): + def populate_users(self, group_item, req_options: Optional["RequestOptions"] = None) -> None: if not group_item.id: error = "Group item missing ID. Group must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -40,7 +45,9 @@ def user_pager(): group_item._set_users(user_pager) - def _get_users_for_group(self, group_item, req_options=None): + def _get_users_for_group( + self, group_item, req_options: Optional["RequestOptions"] = None + ) -> Tuple[List[GroupItem], PaginationItem]: url = "{0}/{1}/users".format(self.baseurl, group_item.id) server_response = self.get_request(url, req_options) user_item = UserItem.from_response(server_response.content, self.parent_srv.namespace) @@ -50,7 +57,7 @@ def _get_users_for_group(self, group_item, req_options=None): # Deletes 1 group by id @api(version="2.0") - def delete(self, group_id): + def delete(self, group_id: str) -> None: if not group_id: error = "Group ID undefined." raise ValueError(error) @@ -59,7 +66,9 @@ def delete(self, group_id): logger.info("Deleted single group (ID: {0})".format(group_id)) @api(version="2.0") - def update(self, group_item, default_site_role=None, as_job=False): + def update( + self, group_item: GroupItem, default_site_role: Optional[str] = None, as_job: bool = False + ) -> Union[GroupItem, JobItem]: # (1/8/2021): Deprecated starting v0.15 if default_site_role is not None: import warnings @@ -90,7 +99,7 @@ def update(self, group_item, default_site_role=None, as_job=False): # Create a 'local' Tableau group @api(version="2.0") - def create(self, group_item): + def create(self, group_item: GroupItem) -> GroupItem: url = self.baseurl create_req = RequestFactory.Group.create_local_req(group_item) server_response = self.post_request(url, create_req) @@ -98,7 +107,7 @@ def create(self, group_item): # Create a group based on Active Directory @api(version="2.0") - def create_AD_group(self, group_item, asJob=False): + def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[GroupItem, JobItem]: asJobparameter = "?asJob=true" if asJob else "" url = self.baseurl + asJobparameter create_req = RequestFactory.Group.create_ad_req(group_item) @@ -110,7 +119,7 @@ def create_AD_group(self, group_item, asJob=False): # Removes 1 user from 1 group @api(version="2.0") - def remove_user(self, group_item, user_id): + def remove_user(self, group_item: GroupItem, user_id: str) -> None: if not group_item.id: error = "Group item missing ID." raise MissingRequiredFieldError(error) @@ -123,7 +132,7 @@ def remove_user(self, group_item, user_id): # Adds 1 user to 1 group @api(version="2.0") - def add_user(self, group_item, user_id): + def add_user(self, group_item: GroupItem, user_id: str) -> UserItem: if not group_item.id: error = "Group item missing ID." raise MissingRequiredFieldError(error) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index bdc7bbc38..c411a7aa7 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -12,7 +12,6 @@ from ..models import FlowItem from ..models import ConnectionItem - def _add_multipart(parts: Dict) -> Tuple[Any, str]: mime_multipart_parts = list() for name, (filename, data, content_type) in parts.items(): @@ -353,24 +352,30 @@ def publish_req_chunked(self, flow_item, connections=None) -> Tuple[Any, str]: class GroupRequest(object): - def add_user_req(self, user_id): + def add_user_req(self, user_id: str) -> bytes: xml_request = ET.Element("tsRequest") user_element = ET.SubElement(xml_request, "user") user_element.attrib["id"] = user_id return ET.tostring(xml_request) - def create_local_req(self, group_item): + def create_local_req(self, group_item: GroupItem) -> bytes: xml_request = ET.Element("tsRequest") group_element = ET.SubElement(xml_request, "group") - group_element.attrib["name"] = group_item.name + if group_item.name is not None: + group_element.attrib["name"] = group_item.name + else: + raise ValueError("Group name must be populated") if group_item.minimum_site_role is not None: group_element.attrib["minimumSiteRole"] = group_item.minimum_site_role return ET.tostring(xml_request) - def create_ad_req(self, group_item): + def create_ad_req(self, group_item: GroupItem) -> bytes: xml_request = ET.Element("tsRequest") group_element = ET.SubElement(xml_request, "group") - group_element.attrib["name"] = group_item.name + if group_item.name is not None: + group_element.attrib["name"] = group_item.name + else: + raise ValueError("Group name must be populated") import_element = ET.SubElement(group_element, "import") import_element.attrib["source"] = "ActiveDirectory" if group_item.domain_name is None: @@ -384,7 +389,7 @@ def create_ad_req(self, group_item): import_element.attrib["siteRole"] = group_item.minimum_site_role return ET.tostring(xml_request) - def update_req(self, group_item, default_site_role=None): + def update_req(self, group_item: GroupItem, default_site_role: Optional[str] = None) -> bytes: # (1/8/2021): Deprecated starting v0.15 if default_site_role is not None: import warnings @@ -399,13 +404,20 @@ def update_req(self, group_item, default_site_role=None): xml_request = ET.Element("tsRequest") group_element = ET.SubElement(xml_request, "group") - group_element.attrib["name"] = group_item.name + + if group_item.name is not None: + group_element.attrib["name"] = group_item.name + else: + raise ValueError("Group name must be populated") if group_item.domain_name is not None and group_item.domain_name != "local": # Import element is only accepted in the request for AD groups import_element = ET.SubElement(group_element, "import") import_element.attrib["source"] = "ActiveDirectory" import_element.attrib["domainName"] = group_item.domain_name - import_element.attrib["siteRole"] = group_item.minimum_site_role + if isinstance(group_item.minimum_site_role, str): + import_element.attrib["siteRole"] = group_item.minimum_site_role + else: + raise ValueError("Minimum site role must be provided.") if group_item.license_mode is not None: import_element.attrib["grantLicenseMode"] = group_item.license_mode else: diff --git a/test/test_group.py b/test/test_group.py index 63155c6ea..fbfb488f1 100644 --- a/test/test_group.py +++ b/test/test_group.py @@ -19,7 +19,7 @@ class GroupTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.server = TSC.Server("http://test") # Fake signin @@ -28,7 +28,7 @@ def setUp(self): self.baseurl = self.server.groups.baseurl - def test_get(self): + def test_get(self) -> None: with open(GET_XML, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -48,11 +48,11 @@ def test_get(self): self.assertEqual("TableauExample", all_groups[2].name) self.assertEqual("local", all_groups[2].domain_name) - def test_get_before_signin(self): + def test_get_before_signin(self) -> None: self.server._auth_token = None self.assertRaises(TSC.NotSignedInError, self.server.groups.get) - def test_populate_users(self): + def test_populate_users(self) -> None: with open(POPULATE_USERS, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -72,12 +72,12 @@ def test_populate_users(self): self.assertEqual("Publisher", user.site_role) self.assertEqual("2016-08-16T23:17:06Z", format_datetime(user.last_login)) - def test_delete(self): + def test_delete(self) -> None: with requests_mock.mock() as m: m.delete(self.baseurl + "/e7833b48-c6f7-47b5-a2a7-36e7dd232758", status_code=204) self.server.groups.delete("e7833b48-c6f7-47b5-a2a7-36e7dd232758") - def test_remove_user(self): + def test_remove_user(self) -> None: with open(POPULATE_USERS, "rb") as f: response_xml_populate = f.read().decode("utf-8") @@ -100,7 +100,7 @@ def test_remove_user(self): m.get(self.baseurl + "/e7833b48-c6f7-47b5-a2a7-36e7dd232758/users", text=response_xml_empty) self.assertEqual(0, len(list(single_group.users))) - def test_add_user(self): + def test_add_user(self) -> None: with open(ADD_USER, "rb") as f: response_xml_add = f.read().decode("utf-8") with open(ADD_USER_POPULATE, "rb") as f: @@ -119,7 +119,7 @@ def test_add_user(self): self.assertEqual("testuser", user.name) self.assertEqual("ServerAdministrator", user.site_role) - def test_add_user_before_populating(self): + def test_add_user_before_populating(self) -> None: with open(GET_XML, "rb") as f: get_xml_response = f.read().decode("utf-8") with open(ADD_USER, "rb") as f: @@ -135,7 +135,7 @@ def test_add_user_before_populating(self): single_group = all_groups[0] self.server.groups.add_user(single_group, "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7") - def test_add_user_missing_user_id(self): + def test_add_user_missing_user_id(self) -> None: with open(POPULATE_USERS, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -146,7 +146,8 @@ def test_add_user_missing_user_id(self): self.assertRaises(ValueError, self.server.groups.add_user, single_group, "") - def test_add_user_missing_group_id(self): + + def test_add_user_missing_group_id(self) -> None: single_group = TSC.GroupItem("test") single_group._users = [] self.assertRaises( @@ -156,7 +157,7 @@ def test_add_user_missing_group_id(self): "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7", ) - def test_remove_user_before_populating(self): + def test_remove_user_before_populating(self) -> None: with open(GET_XML, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -170,7 +171,7 @@ def test_remove_user_before_populating(self): single_group = all_groups[0] self.server.groups.remove_user(single_group, "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7") - def test_remove_user_missing_user_id(self): + def test_remove_user_missing_user_id(self) -> None: with open(POPULATE_USERS, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -181,7 +182,7 @@ def test_remove_user_missing_user_id(self): self.assertRaises(ValueError, self.server.groups.remove_user, single_group, "") - def test_remove_user_missing_group_id(self): + def test_remove_user_missing_group_id(self) -> None: single_group = TSC.GroupItem("test") single_group._users = [] self.assertRaises( @@ -191,7 +192,7 @@ def test_remove_user_missing_group_id(self): "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7", ) - def test_create_group(self): + def test_create_group(self) -> None: with open(CREATE_GROUP, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -201,7 +202,7 @@ def test_create_group(self): self.assertEqual(group.name, u"試供品") self.assertEqual(group.id, "3e4a9ea0-a07a-4fe6-b50f-c345c8c81034") - def test_create_ad_group(self): + def test_create_ad_group(self) -> None: with open(CREATE_GROUP_AD, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -214,7 +215,7 @@ def test_create_ad_group(self): self.assertEqual(group.minimum_site_role, "Creator") self.assertEqual(group.domain_name, "active-directory-domain-name") - def test_create_group_async(self): + def test_create_group_async(self) -> None: with open(CREATE_GROUP_ASYNC, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -225,7 +226,7 @@ def test_create_group_async(self): self.assertEqual(job.mode, "Asynchronous") self.assertEqual(job.type, "GroupImport") - def test_update(self): + def test_update(self) -> None: with open(UPDATE_XML, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: @@ -241,7 +242,7 @@ def test_update(self): self.assertEqual("onLogin", group.license_mode) # async update is not supported for local groups - def test_update_local_async(self): + def test_update_local_async(self) -> None: group = TSC.GroupItem("myGroup") group._id = "ef8b19c0-43b6-11e6-af50-63f5805dbe3c" self.assertRaises(ValueError, self.server.groups.update, group, as_job=True)