8000 refactor: use mixin class for tagging endpoints · tableau/server-client-python@f7b510e · GitHub
[go: up one dir, main page]

Skip to content

Commit f7b510e

Browse files
committed
refactor: use mixin class for tagging endpoints
1 parent 3759248 commit f7b510e

File tree

5 files changed

+150
-50
lines changed

5 files changed

+150
-50
lines changed

tableauserverclient/server/endpoint/resource_tagger.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1+
import abc
12
import copy
3+
from typing import Generic, Iterable, Set, TypeVar, Union
24
import urllib.parse
35

4-
from .endpoint import Endpoint
5-
from .exceptions import ServerResponseError
6-
from ..exceptions import EndpointUnavailableError
6+
from tableauserverclient.server.endpoint.endpoint import Endpoint
7+
from tableauserverclient.server.endpoint.exceptions import ServerResponseError
8+
from tableauserverclient.server.exceptions import EndpointUnavailableError
79
from tableauserverclient.server import RequestFactory
810
from tableauserverclient.models import TagItem
911

1012
from tableauserverclient.helpers.logging import logger
1113

12-
1314
class _ResourceTagger(Endpoint):
1415
# Add new tags to resource
1516
def _add_tags(self, baseurl, resource_id, tag_set):
16-
url = "{0}/{1}/tags".format(baseurl, resource_id)
17+
url = "{0}/{1}/tags".format(baseurl, resource_id
18+
)
1719
add_req = RequestFactory.Tag.add_req(tag_set)
1820

1921
try:
@@ -49,3 +51,56 @@ def update_tags(self, baseurl, resource_item):
4951
resource_item.tags = self._add_tags(baseurl, resource_item.id, add_set)
5052
resource_item._initial_tags = copy.copy(resource_item.tags)
5153
logger.info("Updated tags to {0}".format(resource_item.tags))
54+
55+
T = TypeVar("T")
56+
57+
class TaggingMixin(Generic[T]):
58+
59+
@abc.abstractmethod
60+
def baseurl(self) -> str:
61+
raise NotImplementedError("baseurl must be implemented.")
62+
63+
def add_tags(self, item: Union[T, str], tags: Union[Iterable[str], str]) -> Set[str]:
64+
item_id = getattr(item, "id", item)
65+
66+
if not isinstance(item_id, str):
67+
raise ValueError("ID not found.")
68+
69+
if isinstance(tags, str):
70+
tag_set = set([tags])
71+
else:
72+
tag_set = set(tags)
73+
74+
url = f"{self.baseurl}/{item_id}/tags"
75+
add_req = RequestFactory.Tag.add_req(tag_set)
76+
server_response = self.put_request(url, add_req)
77+
return TagItem.from_response(server_response.content, self.parent_srv.namespace)
78+
79+
def delete_tags(self, item: Union[T, str], tags: Union[Iterable[str], str]) -> None:
80+
item_id = getattr(item, "id", item)
81+
82+
if not isinstance(item_id, str):
83+
raise ValueError("ID not found.")
84+
85+
if isinstance(tags, str):
86+
tag_set = set([tags])
87+
else:
88+
tag_set = set(tags)
89+
90+
for tag in tag_set:
91+
encoded_tag_name = urllib.parse.quote(tag)
92+
url = f"{self.baseurl}/{item_id}/tags/{encoded_tag_name}"
93+
self.delete_request(url)
94+
95+
def update_tags(self, item: T) -> None:
96+
if item.tags == item._initial_tags:
97+
return
98+
99+
add_set = item.tags - item._initial_tags
100+
remove_set = item._initial_tags - item.tags
101+
self.delete_tags(item, remove_set)
102+
if add_set:
103+
item.tags = self.add_tags(item, add_set)
104+
item._initial_tags = copy.copy(item.tags)
105+
logger.info(f"Updated tags to {item.tags}")
106+

tableauserverclient/server/endpoint/workbooks_endpoint.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api, parameter_added_in
1212
from tableauserverclient.server.endpoint.exceptions import InternalServerError, MissingRequiredFieldError
1313
from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint
14-
from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger
14+
from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger, TaggingMixin
1515

1616
from tableauserverclient.filesys_helpers import (
1717
to_filename,
@@ -58,7 +58,7 @@
5858
PathOrFileW = Union[FilePath, FileObjectW]
5959

6060

61-
class Workbooks(QuerysetEndpoint[WorkbookItem]):
61+
class Workbooks(QuerysetEndpoint[WorkbookItem], TaggingMixin[WorkbookItem]):
6262
def __init__(self, parent_srv: "Server") -> None:
6363
super(Workbooks, self).__init__(parent_srv)
6464
self._resource_tagger = _ResourceTagger(parent_srv)
@@ -501,31 +501,6 @@ def schedule_extract_refresh(
501501
) -> List["AddResponse"]: # actually should return a task
502502
return self.parent_srv.schedules.add_to_schedule(schedule_id, workbook=item)
503503

504-
@api(version="1.0")
505-
def add_tags(self, workbook: Union[WorkbookItem, str], tags: Union[Iterable[str], str]) -> Set[str]:
506-
workbook = getattr(workbook, "id", workbook)
507-
508-
if not isinstance(workbook, str):
509-
raise ValueError("Workbook ID not found.")
510-
511-
if isinstance(tags, str):
512-
tag_set = set([tags])
513-
else:
514-
tag_set = set(tags)
515-
516-
return self._resource_tagger._add_tags(self.baseurl, workbook, tag_set)
517-
518-
@api(version="1.0")
519-
def delete_tags(self, workbook: Union[WorkbookItem, str], tags: Union[Iterable[str], str]) -> None:
520-
workbook = getattr(workbook, "id", workbook)
521-
522-
if not isinstance(workbook, str):
523-
raise ValueError("Workbook ID not found.")
524-
525-
if isinstance(tags, str):
526-
tag_set = set([tags])
527-
else:
528-
tag_set = set(tags)
529-
530-
for tag in tag_set:
531-
self._resource_tagger._delete_tag(self.baseurl, workbook, tag)
504+
Workbooks.add_tags = api(version="1.0")(Workbooks.add_tags)
505+
Workbooks.delete_tags = api(version="1.0")(Workbooks.delete_tags)
506+
Workbooks.update_tags = api(version="1.0")(Workbooks.update_tags)

test/assets/workbook_add_tag.xml

Lines changed: 0 additions & 6 deletions
This file was deleted.

test/assets/workbook_add_tags.xml

Lines changed: 0 additions & 9 deletions
This file was deleted.

test/test_tagging.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import re
2+
from typing import Iterable
3+
from xml.etree import ElementTree as ET
4+
5+
import pytest
6+
import requests_mock
7+
import tableauserverclient as TSC
8+
9+
@pytest.fixture
10+
def get_server() -> TSC.Server:
11+
server = TSC.Server("http://test", False)
12+
13+
# Fake sign in
14+
server._site_id = "dad65087-b08b-4603-af4e-2887b8aafc67"
15+
server._auth_token = "j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM"
16+
server.version = "3.28"
17+
return server
18+
19+
def xml_response_factory(tags: Iterable[str]) -> str:
20+
root = ET.Element("tsResponse")
21+
tags_element = ET.SubElement(root, "tags")
22+
for tag in tags:
23+
tag_element = ET.SubElement(tags_element, "tag")
24+
tag_element.attrib["label"] = tag
25+
root.attrib["xmlns"] = "http://tableau.com/api"
26+
return ET.tostring(root, encoding="utf-8").decode("utf-8")
27+
28+
def make_workbook() -> TSC.WorkbookItem:
29+
workbook = TSC.WorkbookItem("project", "test")
30+
workbook._id = "06b944d2-959d-4604-9305-12323c95e70e"
31+
return workbook
32+
33+
@pytest.mark.parametrize("endpoint_type, item", [
34+
("workbooks", make_workbook()),
35+
])
36+
@pytest.mark.parametrize("tags", [
37+
"a",
38+
["a", "b"],
39+
])
40+
def test_add_tags(get_server, endpoint_type, item, tags) -> None:
41+
add_tags_xml = xml_response_factory(tags)
42+
endpoint = getattr(get_server, endpoint_type)
43+
id_ = getattr(item, "id", item)
44+
45+
with requests_mock.mock() as m:
46+
m.put(
47+
f"{endpoint.baseurl}/{id_}/tags",
48+
status_code=200,
49+
text=add_tags_xml,
50+
)
51+
tag_result = endpoint.add_tags(item, tags)
52+
53+
if isinstance(tags, str):
54+
tags = [tags]
55+
assert set(tag_result) == set(tags)
56+
57+
@pytest.mark.parametrize("endpoint_type, item", [
58+
("workbooks", make_workbook()),
59+
])
60+
@pytest.mark.parametrize("tags", [
61+
"a",
62+
["a", "b"],
63+
])
64+
def test_delete_tags(get_server, endpoint_type, item, tags) -> None:
65+
add_tags_xml = xml_response_factory(tags)
66+
endpoint = getattr(get_server, endpoint_type)
67+
id_ = getattr(item, "id", item)
68+
69+
if isinstance(tags, str):
70+
tags = [tags]
71+
tag_paths = "|".join(tags)
72+
tag_paths = f"({tag_paths})"
73+
matcher = re.compile(rf"{endpoint.baseurl}\/{id_}\/tags\/{tag_paths}")
74+
with requests_mock.mock() as m:
75+
m.delete(
76+
matcher,
77+
status_code=200,
78+
text=add_tags_xml,
79+
)
80+
endpoint.delete_tags(item, tags)
81+
history = m.request_history
82+
83+
assert len(history) == len(tags)
84+
urls = sorted([r.url.split("/")[-1] for r in history])
85+
assert set(urls) == set(tags)

0 commit comments

Comments
 (0)
0