8000 Merge branch 'django-style-filters' · scuml/server-client-python@633813d · GitHub
[go: up one dir, main page]

Skip to content

Commit 633813d

Browse files
committed
Merge branch 'django-style-filters'
2 parents 5c6bae0 + a92f146 commit 633813d

File tree

8 files changed

+173
-17
lines changed

8 files changed

+173
-17
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ target/
7676
# pyenv
7777
.python-version
7878

79+
# poetry
80+
poetry.lock
81+
pyproject.toml
82+
7983
# celery beat schedule file
8084
celerybeat-schedule
8185

tableauserverclient/server/endpoint/datasources_endpoint.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from .endpoint import Endpoint, api, parameter_added_in
1+
from .endpoint import QuerysetEndpoint, api, parameter_added_in
22
from .exceptions import InternalServerError, MissingRequiredFieldError
3-
from .endpoint import api, parameter_added_in, Endpoint
43
from .permissions_endpoint import _PermissionsEndpoint
5-
from .exceptions import MissingRequiredFieldError
64
from .fileuploads_endpoint import Fileuploads
75
from .resource_tagger import _ResourceTagger
86
from .. import RequestFactory, DatasourceItem, PaginationItem, ConnectionItem
7+
from ..query import QuerySet
98
from ...filesys_helpers import to_filename, make_download_path
10-
from ...models.tag_item import TagItem
119
from ...models.job_item import JobItem
1210
import os
1311
import logging
@@ -23,7 +21,7 @@
2321
logger = logging.getLogger('tableau.endpoint.datasources')
2422

2523

26-
class Datasources(Endpoint):
24+
class Datasources(QuerysetEndpoint):
2725
def __init__(self, parent_srv):
2826
super(Datasources, self).__init__(parent_srv)
2927
self._resource_tagger = _ResourceTagger(parent_srv)

tableauserverclient/server/endpoint/endpoint.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .exceptions import ServerResponseError, InternalServerError, NonXMLResponseError
22
from functools import wraps
33
from xml.etree.ElementTree import ParseError
4-
4+
from ..query import QuerySet
55
import logging
66

77
try:
@@ -165,3 +165,25 @@ def wrapper(self, *args, **kwargs):
165165
return func(self, *args, **kwargs)
166166
return wrapper
167167
return _decorator
168+
169+
170+
class QuerysetEndpoint(Endpoint):
171+
@api(version="2.0")
172+
def all(self, *args, **kwargs):
173+
queryset = QuerySet(self)
174+
return queryset
175+
176+
@api(version="2.0")
177+
def filter(self, *args, **kwargs):
178+
queryset = QuerySet(self).filter(**kwargs)
179+
return queryset
180+
181+
@api(version="2.0")
182+
def order_by(self, *args, **kwargs):
183+
queryset = QuerySet(self).order_by(*args)
184+
return queryset
185+
186+
@api(version="2.0")
187+
def paginate(self, **kwargs):
188+
queryset = QuerySet(self).paginate(**kwargs)
189+
return queryset

tableauserverclient/server/endpoint/users_endpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .endpoint import Endpoint, api
1+
from .endpoint import QuerysetEndpoint, api
22
from .exceptions import MissingRequiredFieldError
33
from .. import RequestFactory, UserItem, WorkbookItem, PaginationItem
44
from ..pager import Pager
@@ -8,7 +8,7 @@
88
logger = logging.getLogger('tableau.endpoint.users')
99

1010

11-
class Users(Endpoint):
11+
class Users(QuerysetEndpoint):
1212
@property
1313
def baseurl(self):
1414
return "{0}/sites/{1}/users".format(self.parent_srv.baseurl, self.parent_srv.site_id)

tableauserverclient/server/endpoint/views_endpoint.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
from .endpoint import Endpoint, api
1+
from .endpoint import QuerysetEndpoint, api
22
from .exceptions import MissingRequiredFieldError
33
from .resource_tagger import _ResourceTagger
44
from .permissions_endpoint import _PermissionsEndpoint
5-
from .. import RequestFactory, ViewItem, PaginationItem
6-
from ...models.tag_item import TagItem
5+
from .. import ViewItem, PaginationItem
76
import logging
87
from contextlib import closing
98

109
logger = logging.getLogger('tableau.endpoint.views')
1110

1211

13-
class Views(Endpoint):
12+
class Views(QuerysetEndpoint):
1413
def __init__(self, parent_srv):
1514
super(Views, self).__init__(parent_srv)
1615
self._resource_tagger = _ResourceTagger(parent_srv)

tableauserverclient/server/endpoint/workbooks_endpoint.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from .endpoint import Endpoint, api, parameter_added_in
1+
from .endpoint import QuerysetEndpoint, api, parameter_added_in
22
from .exceptions import InternalServerError, MissingRequiredFieldError
33
from .permissions_endpoint import _PermissionsEndpoint
4-
from .exceptions import MissingRequiredFieldError
54
from .fileuploads_endpoint import Fileuploads
65
from .resource_tagger import _ResourceTagger
76
from .. import RequestFactory, WorkbookItem, ConnectionItem, ViewItem, PaginationItem
@@ -23,7 +22,7 @@
2322
logger = logging.getLogger('tableau.endpoint.workbooks')
2423

2524

26-
class Workbooks(Endpoint):
25+
class Workbooks(QuerysetEndpoint):
2726
def __init__(self, parent_srv):
2827
super(Workbooks, self).__init__(parent_srv)
2928
self._resource_tagger = _ResourceTagger(parent_srv)
@@ -39,8 +38,10 @@ def get(self, req_options=None):
3938
logger.info('Querying all workbooks on site')
4039
url = self.baseurl
4140
server_response = self.get_request(url, req_options)
42-
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
43-
all_workbook_items = WorkbookItem.from_response(server_response.content, self.parent_srv.namespace)
41+
pagination_item = PaginationItem.from_response(
42+
server_response.content, self.parent_srv.namespace)
43+
all_workbook_items = WorkbookItem.from_response(
44+
server_response.content, self.parent_srv.namespace)
4445
return all_workbook_items, pagination_item
4546

4647
# Get 1 workbook

tableauserverclient/server/query.py

Lines changed: 89 additions & 0 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from .request_options import RequestOptions
2+
from .filter import Filter
3+
from .sort import Sort
4+
5+
6+
def to_camel_case(word):
7+
return word.split('_')[0] + ''.join(x.capitalize() or '_' for x in word.split('_')[1:])
8+
9+
10+
class QuerySet:
11+
12+
def __init__(self, model):
13+
self.model = model
14+
self.request_options = RequestOptions()
15+
self._result_cache = None
16+
self._pagination_item = None
17+
18+
def __iter__(self):
19+
self._fetch_all()
20+
return iter(self._result_cache)
21+
22+
def __getitem__(self, k):
23+
return list(self)[k]
24+
25+
def _fetch_all(self):
26+
"""
27+
Retrieve the data and store result and pagination item in cache
28+
"""
29+
if self._result_cache is None:
30+
self._result_cache, self._pagination_item = self.model.get(self.request_options)
31+
32+
@property
33+
def total_available(self):
34+
self._fetch_all()
35+
return self._pagination_item.total_available
36+
37+
@property
38+
def page_number(self):
39+
self._fetch_all()
40+
return self._pagination_item.page_number
41+
42+
@property
43+
def page_size(self):
44+
self._fetch_all()
45+
return self._pagination_item.page_size
46+
47+
def filter(self, **kwargs):
48+
for kwarg_key, value in kwargs.items():
49+
field_name, operator = self._parse_shorthand_filter(kwarg_key)
50+
self.request_options.filter.add(Filter(field_name, operator, value))
51+
return self
52+
53+
def order_by(self, *args):
54+
for arg in args:
55+
field_name, direction = self._parse_shorthand_sort(arg)
56+
self.request_options.sort.add(Sort(field_name, direction))
57+
return self
58+
59+
def paginate(self, **kwargs):
60+
if "page_number" in kwargs:
61+
self.request_options.pagenumber = kwargs["page_number"]
62+
if "page_size" in kwargs:
63+
self.request_options.pagesize = kwargs["page_size"]
64+
return self
65+
66+
def _parse_shorthand_filter(self, key):
67+
tokens = key.split("__", 1)
68+
if len(tokens) == 1:
69+
operator = RequestOptions.Operator.Equals
70+
else:
71+
operator = tokens[1]
72+
if operator not in RequestOptions.Operator.__dict__.values():
73+
raise ValueError("Operator `{}` is not valid.".format(operator))
74+
75+
field = to_camel_case(tokens[0])
76+
if field not in RequestOptions.Field.__dict__.values():
77+
raise ValueError("Field name `{}` is not valid.".format(field))
78+
return (field, operator)
79+
80+
def _parse_shorthand_sort(self, key):
81+
direction = RequestOptions.Direction.Asc
82+
if key.startswith("-"):
83+
direction = RequestOptions.Direction.Desc
84+
key = key[1:]
85+
86+
key = to_camel_case(key)
87+
if key not in RequestOptions.Field.__dict__.values():
88+
raise ValueError("Sort key name %s is not valid.", key)
89+
return (key, direction)

test/test_request_option.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ def test_filter_equals(self):
7676
self.assertEqual('RESTAPISample', matching_workbooks[0].name)
7777
self.assertEqual('RESTAPISample', matching_workbooks[1].name)
7878

79+
def test_filter_equals_shorthand(self):
80+
with open(FILTER_EQUALS, 'rb') as f:
81+
response_xml = f.read().decode('utf-8')
82+
with requests_mock.mock() as m:
83+
m.get(self.baseurl + '/workbooks?filter=name:eq:RESTAPISample', text=response_xml)
84+
matching_workbooks = self.server.workbooks.filter(name='RESTAPISample').order_by("name")
85+
86+
self.assertEqual(2, matching_workbooks.total_available)
87+
self.assertEqual('RESTAPISample', matching_workbooks[0].name)
88+
self.assertEqual('RESTAPISample', matching_workbooks[1].name)
89+
7990
def test_filter_tags_in(self):
8091
with open(FILTER_TAGS_IN, 'rb') as f:
8192
response_xml = f.read().decode('utf-8')
@@ -91,6 +102,22 @@ def test_filter_tags_in(self):
91102
self.assertEqual(set(['safari']), matching_workbooks[1].tags)
92103
self.assertEqual(set(['sample']), matching_workbooks[2].tags)
93104

105+
def test_filter_tags_in_shorthand(self):
106+
with open(FILTER_TAGS_IN, 'rb') as f:
107+
response_xml = f.read().decode('utf-8')
108+
with requests_mock.mock() as m:
109+
m.get(self.baseurl + '/workbooks?filter=tags:in:[sample,safari,weather]', text=response_xml)
110+
matching_workbooks = self.server.workbooks.filter(tags__in=['sample', 'safari', 'weather'])
111+
112+
self.assertEqual(3, matching_workbooks.total_available)
113+
self.assertEqual(set(['weather']), matching_workbooks[0].tags)
114+
self.assertEqual(set(['safari']), matching_workbooks[1].tags)
115+
self.assertEqual(set(['sample']), matching_workbooks[2].tags)
116+
117+
def test_invalid_shorthand_option(self):
118+
with self.assertRaises(ValueError):
119+
self.server.workbooks.filter(nonexistant__in=['sample', 'safari'])
120+
94121
def test_multiple_filter_options(self):
95122
with open(FILTER_MULTIPLE, 'rb') as f:
96123
response_xml = f.read().decode('utf-8')
@@ -107,3 +134,19 @@ def test_multiple_filter_options(self):
107134
for _ in range(100):
108135
matching_workbooks, pagination_item = self.server.workbooks.get(req_option)
109136
self.assertEqual(3, pagination_item.total_available)
137+
138+
def test_multiple_filter_options_shorthand(self):
139+
with open(FILTER_MULTIPLE, 'rb') as f:
140+
response_xml = f.read().decode('utf-8')
141+
# To ensure that this is deterministic, run this a few times
142+
with requests_mock.mock() as m:
143+
# Sometimes pep8 requires you to do things you might not otherwise do
144+
url = ''.join((self.baseurl, '/workbooks?pageNumber=1&pageSize=100&',
145+
'filter=name:eq:foo,tags:in:[sample,safari,weather]'))
146+
m.get(url, text=response_xml)
147+
148+
for _ in range(100):
149+
matching_workbooks = self.server.workbooks.filter(
150+
tags__in=['sample', 'safari', 'weather'], name='foo'
151+
)
152+
self.assertEqual(3, matching_workbooks.total_available)

0 commit comments

Comments
 (0)
0